Ausgabe der neuen DB Einträge

This commit is contained in:
hubobel 2022-01-02 21:50:48 +01:00
parent bad48e1627
commit cfbbb9ee3d
2399 changed files with 843193 additions and 43 deletions

View file

@ -0,0 +1,12 @@
# -*- test-case-name: twisted.web.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Web: HTTP clients and servers, plus tools for implementing them.
Contains a L{web server<twisted.web.server>} (including an
L{HTTP implementation<twisted.web.http>}, a
L{resource model<twisted.web.resource>}), and
a L{web client<twisted.web.client>}.
"""

View file

@ -0,0 +1,7 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP header-based authentication migrated from web2
"""

View file

@ -0,0 +1,61 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP BASIC authentication.
@see: U{http://tools.ietf.org/html/rfc1945}
@see: U{http://tools.ietf.org/html/rfc2616}
@see: U{http://tools.ietf.org/html/rfc2617}
"""
from __future__ import division, absolute_import
import binascii
from zope.interface import implementer
from twisted.cred import credentials, error
from twisted.web.iweb import ICredentialFactory
@implementer(ICredentialFactory)
class BasicCredentialFactory(object):
"""
Credential Factory for HTTP Basic Authentication
@type authenticationRealm: L{bytes}
@ivar authenticationRealm: The HTTP authentication realm which will be issued in
challenges.
"""
scheme = b'basic'
def __init__(self, authenticationRealm):
self.authenticationRealm = authenticationRealm
def getChallenge(self, request):
"""
Return a challenge including the HTTP authentication realm with which
this factory was created.
"""
return {'realm': self.authenticationRealm}
def decode(self, response, request):
"""
Parse the base64-encoded, colon-separated username and password into a
L{credentials.UsernamePassword} instance.
"""
try:
creds = binascii.a2b_base64(response + b'===')
except binascii.Error:
raise error.LoginFailed('Invalid credentials')
creds = creds.split(b':', 1)
if len(creds) == 2:
return credentials.UsernamePassword(*creds)
else:
raise error.LoginFailed('Invalid credentials')

View file

@ -0,0 +1,56 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of RFC2617: HTTP Digest Authentication
@see: U{http://www.faqs.org/rfcs/rfc2617.html}
"""
from __future__ import division, absolute_import
from zope.interface import implementer
from twisted.cred import credentials
from twisted.web.iweb import ICredentialFactory
@implementer(ICredentialFactory)
class DigestCredentialFactory(object):
"""
Wrapper for L{digest.DigestCredentialFactory} that implements the
L{ICredentialFactory} interface.
"""
scheme = b'digest'
def __init__(self, algorithm, authenticationRealm):
"""
Create the digest credential factory that this object wraps.
"""
self.digest = credentials.DigestCredentialFactory(algorithm,
authenticationRealm)
def getChallenge(self, request):
"""
Generate the challenge for use in the WWW-Authenticate header
@param request: The L{IRequest} to with access was denied and for the
response to which this challenge is being generated.
@return: The L{dict} that can be used to generate a WWW-Authenticate
header.
"""
return self.digest.getChallenge(request.getClientAddress().host)
def decode(self, response, request):
"""
Create a L{twisted.cred.credentials.DigestedCredentials} object
from the given response and request.
@see: L{ICredentialFactory.decode}
"""
return self.digest.decode(response,
request.method,
request.getClientAddress().host)

View file

@ -0,0 +1,236 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A guard implementation which supports HTTP header-based authentication
schemes.
If no I{Authorization} header is supplied, an anonymous login will be
attempted by using a L{Anonymous} credentials object. If such a header is
supplied and does not contain allowed credentials, or if anonymous login is
denied, a 401 will be sent in the response along with I{WWW-Authenticate}
headers for each of the allowed authentication schemes.
"""
from __future__ import absolute_import, division
from twisted.cred import error
from twisted.cred.credentials import Anonymous
from twisted.python.compat import unicode
from twisted.python.components import proxyForInterface
from twisted.web import util
from twisted.web.resource import ErrorPage, IResource
from twisted.logger import Logger
from zope.interface import implementer
@implementer(IResource)
class UnauthorizedResource(object):
"""
Simple IResource to escape Resource dispatch
"""
isLeaf = True
def __init__(self, factories):
self._credentialFactories = factories
def render(self, request):
"""
Send www-authenticate headers to the client
"""
def ensureBytes(s):
return s.encode('ascii') if isinstance(s, unicode) else s
def generateWWWAuthenticate(scheme, challenge):
l = []
for k, v in challenge.items():
k = ensureBytes(k)
v = ensureBytes(v)
l.append(k + b"=" + quoteString(v))
return b" ".join([scheme, b", ".join(l)])
def quoteString(s):
return b'"' + s.replace(b'\\', b'\\\\').replace(b'"', b'\\"') + b'"'
request.setResponseCode(401)
for fact in self._credentialFactories:
challenge = fact.getChallenge(request)
request.responseHeaders.addRawHeader(
b'www-authenticate',
generateWWWAuthenticate(fact.scheme, challenge))
if request.method == b'HEAD':
return b''
return b'Unauthorized'
def getChildWithDefault(self, path, request):
"""
Disable resource dispatch
"""
return self
@implementer(IResource)
class HTTPAuthSessionWrapper(object):
"""
Wrap a portal, enforcing supported header-based authentication schemes.
@ivar _portal: The L{Portal} which will be used to retrieve L{IResource}
avatars.
@ivar _credentialFactories: A list of L{ICredentialFactory} providers which
will be used to decode I{Authorization} headers into L{ICredentials}
providers.
"""
isLeaf = False
_log = Logger()
def __init__(self, portal, credentialFactories):
"""
Initialize a session wrapper
@type portal: C{Portal}
@param portal: The portal that will authenticate the remote client
@type credentialFactories: C{Iterable}
@param credentialFactories: The portal that will authenticate the
remote client based on one submitted C{ICredentialFactory}
"""
self._portal = portal
self._credentialFactories = credentialFactories
def _authorizedResource(self, request):
"""
Get the L{IResource} which the given request is authorized to receive.
If the proper authorization headers are present, the resource will be
requested from the portal. If not, an anonymous login attempt will be
made.
"""
authheader = request.getHeader(b'authorization')
if not authheader:
return util.DeferredResource(self._login(Anonymous()))
factory, respString = self._selectParseHeader(authheader)
if factory is None:
return UnauthorizedResource(self._credentialFactories)
try:
credentials = factory.decode(respString, request)
except error.LoginFailed:
return UnauthorizedResource(self._credentialFactories)
except:
self._log.failure("Unexpected failure from credentials factory")
return ErrorPage(500, None, None)
else:
return util.DeferredResource(self._login(credentials))
def render(self, request):
"""
Find the L{IResource} avatar suitable for the given request, if
possible, and render it. Otherwise, perhaps render an error page
requiring authorization or describing an internal server failure.
"""
return self._authorizedResource(request).render(request)
def getChildWithDefault(self, path, request):
"""
Inspect the Authorization HTTP header, and return a deferred which,
when fired after successful authentication, will return an authorized
C{Avatar}. On authentication failure, an C{UnauthorizedResource} will
be returned, essentially halting further dispatch on the wrapped
resource and all children
"""
# Don't consume any segments of the request - this class should be
# transparent!
request.postpath.insert(0, request.prepath.pop())
return self._authorizedResource(request)
def _login(self, credentials):
"""
Get the L{IResource} avatar for the given credentials.
@return: A L{Deferred} which will be called back with an L{IResource}
avatar or which will errback if authentication fails.
"""
d = self._portal.login(credentials, None, IResource)
d.addCallbacks(self._loginSucceeded, self._loginFailed)
return d
def _loginSucceeded(self, args):
"""
Handle login success by wrapping the resulting L{IResource} avatar
so that the C{logout} callback will be invoked when rendering is
complete.
"""
interface, avatar, logout = args
class ResourceWrapper(proxyForInterface(IResource, 'resource')):
"""
Wrap an L{IResource} so that whenever it or a child of it
completes rendering, the cred logout hook will be invoked.
An assumption is made here that exactly one L{IResource} from
among C{avatar} and all of its children will be rendered. If
more than one is rendered, C{logout} will be invoked multiple
times and probably earlier than desired.
"""
def getChildWithDefault(self, name, request):
"""
Pass through the lookup to the wrapped resource, wrapping
the result in L{ResourceWrapper} to ensure C{logout} is
called when rendering of the child is complete.
"""
return ResourceWrapper(self.resource.getChildWithDefault(name, request))
def render(self, request):
"""
Hook into response generation so that when rendering has
finished completely (with or without error), C{logout} is
called.
"""
request.notifyFinish().addBoth(lambda ign: logout())
return super(ResourceWrapper, self).render(request)
return ResourceWrapper(avatar)
def _loginFailed(self, result):
"""
Handle login failure by presenting either another challenge (for
expected authentication/authorization-related failures) or a server
error page (for anything else).
"""
if result.check(error.Unauthorized, error.LoginFailed):
return UnauthorizedResource(self._credentialFactories)
else:
self._log.failure(
"HTTPAuthSessionWrapper.getChildWithDefault encountered "
"unexpected error",
failure=result,
)
return ErrorPage(500, None, None)
def _selectParseHeader(self, header):
"""
Choose an C{ICredentialFactory} from C{_credentialFactories}
suitable to use to decode the given I{Authenticate} header.
@return: A two-tuple of a factory and the remaining portion of the
header value to be decoded or a two-tuple of L{None} if no
factory can decode the header value.
"""
elements = header.split(b' ')
scheme = elements[0].lower()
for fact in self._credentialFactories:
if fact.scheme == scheme:
return (fact, b' '.join(elements[1:]))
return (None, None)

View file

@ -0,0 +1,185 @@
# -*- test-case-name: twisted.web.test.test_template -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import division, absolute_import
from zope.interface import implementer
from twisted.web.iweb import IRenderable
from twisted.web.error import MissingRenderMethod, UnexposedMethodError
from twisted.web.error import MissingTemplateLoader
class Expose(object):
"""
Helper for exposing methods for various uses using a simple decorator-style
callable.
Instances of this class can be called with one or more functions as
positional arguments. The names of these functions will be added to a list
on the class object of which they are methods.
@ivar attributeName: The attribute with which exposed methods will be
tracked.
"""
def __init__(self, doc=None):
self.doc = doc
def __call__(self, *funcObjs):
"""
Add one or more functions to the set of exposed functions.
This is a way to declare something about a class definition, similar to
L{zope.interface.declarations.implementer}. Use it like this::
magic = Expose('perform extra magic')
class Foo(Bar):
def twiddle(self, x, y):
...
def frob(self, a, b):
...
magic(twiddle, frob)
Later you can query the object::
aFoo = Foo()
magic.get(aFoo, 'twiddle')(x=1, y=2)
The call to C{get} will fail if the name it is given has not been
exposed using C{magic}.
@param funcObjs: One or more function objects which will be exposed to
the client.
@return: The first of C{funcObjs}.
"""
if not funcObjs:
raise TypeError("expose() takes at least 1 argument (0 given)")
for fObj in funcObjs:
fObj.exposedThrough = getattr(fObj, 'exposedThrough', [])
fObj.exposedThrough.append(self)
return funcObjs[0]
_nodefault = object()
def get(self, instance, methodName, default=_nodefault):
"""
Retrieve an exposed method with the given name from the given instance.
@raise UnexposedMethodError: Raised if C{default} is not specified and
there is no exposed method with the given name.
@return: A callable object for the named method assigned to the given
instance.
"""
method = getattr(instance, methodName, None)
exposedThrough = getattr(method, 'exposedThrough', [])
if self not in exposedThrough:
if default is self._nodefault:
raise UnexposedMethodError(self, methodName)
return default
return method
@classmethod
def _withDocumentation(cls, thunk):
"""
Slight hack to make users of this class appear to have a docstring to
documentation generators, by defining them with a decorator. (This hack
should be removed when epydoc can be convinced to use some other method
for documenting.)
"""
return cls(thunk.__doc__)
# Avoid exposing the ugly, private classmethod name in the docs. Luckily this
# namespace is private already so this doesn't leak further.
exposer = Expose._withDocumentation
@exposer
def renderer():
"""
Decorate with L{renderer} to use methods as template render directives.
For example::
class Foo(Element):
@renderer
def twiddle(self, request, tag):
return tag('Hello, world.')
<div xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">
<span t:render="twiddle" />
</div>
Will result in this final output::
<div>
<span>Hello, world.</span>
</div>
"""
@implementer(IRenderable)
class Element(object):
"""
Base for classes which can render part of a page.
An Element is a renderer that can be embedded in a stan document and can
hook its template (from the loader) up to render methods.
An Element might be used to encapsulate the rendering of a complex piece of
data which is to be displayed in multiple different contexts. The Element
allows the rendering logic to be easily re-used in different ways.
Element returns render methods which are registered using
L{twisted.web._element.renderer}. For example::
class Menu(Element):
@renderer
def items(self, request, tag):
....
Render methods are invoked with two arguments: first, the
L{twisted.web.http.Request} being served and second, the tag object which
"invoked" the render method.
@type loader: L{ITemplateLoader} provider
@ivar loader: The factory which will be used to load documents to
return from C{render}.
"""
loader = None
def __init__(self, loader=None):
if loader is not None:
self.loader = loader
def lookupRenderMethod(self, name):
"""
Look up and return the named render method.
"""
method = renderer.get(self, name, None)
if method is None:
raise MissingRenderMethod(self, name)
return method
def render(self, request):
"""
Implement L{IRenderable} to allow one L{Element} to be embedded in
another's template or rendering output.
(This will simply load the template from the C{loader}; when used in a
template, the flattening engine will keep track of this object
separately as the object to lookup renderers on and call
L{Element.renderer} to look them up. The resulting object from this
method is not directly associated with this L{Element}.)
"""
loader = self.loader
if loader is None:
raise MissingTemplateLoader(self)
return loader.load()

View file

@ -0,0 +1,421 @@
# -*- test-case-name: twisted.web.test.test_flatten -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Context-free flattener/serializer for rendering Python objects, possibly
complex or arbitrarily nested, as strings.
"""
from __future__ import division, absolute_import
from io import BytesIO
from sys import exc_info
from types import GeneratorType
from traceback import extract_tb
try:
from inspect import iscoroutine
except ImportError:
def iscoroutine(*args, **kwargs):
return False
from twisted.python.compat import unicode, nativeString, iteritems
from twisted.internet.defer import Deferred, ensureDeferred
from twisted.web._stan import Tag, slot, voidElements, Comment, CDATA, CharRef
from twisted.web.error import UnfilledSlot, UnsupportedType, FlattenerError
from twisted.web.iweb import IRenderable
def escapeForContent(data):
"""
Escape some character or UTF-8 byte data for inclusion in an HTML or XML
document, by replacing metacharacters (C{&<>}) with their entity
equivalents (C{&amp;&lt;&gt;}).
This is used as an input to L{_flattenElement}'s C{dataEscaper} parameter.
@type data: C{bytes} or C{unicode}
@param data: The string to escape.
@rtype: C{bytes}
@return: The quoted form of C{data}. If C{data} is unicode, return a utf-8
encoded string.
"""
if isinstance(data, unicode):
data = data.encode('utf-8')
data = data.replace(b'&', b'&amp;'
).replace(b'<', b'&lt;'
).replace(b'>', b'&gt;')
return data
def attributeEscapingDoneOutside(data):
"""
Escape some character or UTF-8 byte data for inclusion in the top level of
an attribute. L{attributeEscapingDoneOutside} actually passes the data
through unchanged, because L{writeWithAttributeEscaping} handles the
quoting of the text within attributes outside the generator returned by
L{_flattenElement}; this is used as the C{dataEscaper} argument to that
L{_flattenElement} call so that that generator does not redundantly escape
its text output.
@type data: C{bytes} or C{unicode}
@param data: The string to escape.
@return: The string, unchanged, except for encoding.
@rtype: C{bytes}
"""
if isinstance(data, unicode):
return data.encode("utf-8")
return data
def writeWithAttributeEscaping(write):
"""
Decorate a C{write} callable so that all output written is properly quoted
for inclusion within an XML attribute value.
If a L{Tag <twisted.web.template.Tag>} C{x} is flattened within the context
of the contents of another L{Tag <twisted.web.template.Tag>} C{y}, the
metacharacters (C{<>&"}) delimiting C{x} should be passed through
unchanged, but the textual content of C{x} should still be quoted, as
usual. For example: C{<y><x>&amp;</x></y>}. That is the default behavior
of L{_flattenElement} when L{escapeForContent} is passed as the
C{dataEscaper}.
However, when a L{Tag <twisted.web.template.Tag>} C{x} is flattened within
the context of an I{attribute} of another L{Tag <twisted.web.template.Tag>}
C{y}, then the metacharacters delimiting C{x} should be quoted so that it
can be parsed from the attribute's value. In the DOM itself, this is not a
valid thing to do, but given that renderers and slots may be freely moved
around in a L{twisted.web.template} template, it is a condition which may
arise in a document and must be handled in a way which produces valid
output. So, for example, you should be able to get C{<y attr="&lt;x /&gt;"
/>}. This should also be true for other XML/HTML meta-constructs such as
comments and CDATA, so if you were to serialize a L{comment
<twisted.web.template.Comment>} in an attribute you should get C{<y
attr="&lt;-- comment --&gt;" />}. Therefore in order to capture these
meta-characters, flattening is done with C{write} callable that is wrapped
with L{writeWithAttributeEscaping}.
The final case, and hopefully the much more common one as compared to
serializing L{Tag <twisted.web.template.Tag>} and arbitrary L{IRenderable}
objects within an attribute, is to serialize a simple string, and those
should be passed through for L{writeWithAttributeEscaping} to quote
without applying a second, redundant level of quoting.
@param write: A callable which will be invoked with the escaped L{bytes}.
@return: A callable that writes data with escaping.
"""
def _write(data):
write(escapeForContent(data).replace(b'"', b'&quot;'))
return _write
def escapedCDATA(data):
"""
Escape CDATA for inclusion in a document.
@type data: L{str} or L{unicode}
@param data: The string to escape.
@rtype: L{str}
@return: The quoted form of C{data}. If C{data} is unicode, return a utf-8
encoded string.
"""
if isinstance(data, unicode):
data = data.encode('utf-8')
return data.replace(b']]>', b']]]]><![CDATA[>')
def escapedComment(data):
"""
Escape a comment for inclusion in a document.
@type data: L{str} or L{unicode}
@param data: The string to escape.
@rtype: C{str}
@return: The quoted form of C{data}. If C{data} is unicode, return a utf-8
encoded string.
"""
if isinstance(data, unicode):
data = data.encode('utf-8')
data = data.replace(b'--', b'- - ').replace(b'>', b'&gt;')
if data and data[-1:] == b'-':
data += b' '
return data
def _getSlotValue(name, slotData, default=None):
"""
Find the value of the named slot in the given stack of slot data.
"""
for slotFrame in slotData[::-1]:
if slotFrame is not None and name in slotFrame:
return slotFrame[name]
else:
if default is not None:
return default
raise UnfilledSlot(name)
def _flattenElement(request, root, write, slotData, renderFactory,
dataEscaper):
"""
Make C{root} slightly more flat by yielding all its immediate contents as
strings, deferreds or generators that are recursive calls to itself.
@param request: A request object which will be passed to
L{IRenderable.render}.
@param root: An object to be made flatter. This may be of type C{unicode},
L{str}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple}, L{list},
L{types.GeneratorType}, L{Deferred}, or an object that implements
L{IRenderable}.
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@param slotData: A L{list} of L{dict} mapping L{str} slot names to data
with which those slots will be replaced.
@param renderFactory: If not L{None}, an object that provides
L{IRenderable}.
@param dataEscaper: A 1-argument callable which takes L{bytes} or
L{unicode} and returns L{bytes}, quoted as appropriate for the
rendering context. This is really only one of two values:
L{attributeEscapingDoneOutside} or L{escapeForContent}, depending on
whether the rendering context is within an attribute or not. See the
explanation in L{writeWithAttributeEscaping}.
@return: An iterator that eventually yields L{bytes} that should be written
to the output. However it may also yield other iterators or
L{Deferred}s; if it yields another iterator, the caller will iterate
it; if it yields a L{Deferred}, the result of that L{Deferred} will
either be L{bytes}, in which case it's written, or another generator,
in which case it is iterated. See L{_flattenTree} for the trampoline
that consumes said values.
@rtype: An iterator which yields L{bytes}, L{Deferred}, and more iterators
of the same type.
"""
def keepGoing(newRoot, dataEscaper=dataEscaper,
renderFactory=renderFactory, write=write):
return _flattenElement(request, newRoot, write, slotData,
renderFactory, dataEscaper)
if isinstance(root, (bytes, unicode)):
write(dataEscaper(root))
elif isinstance(root, slot):
slotValue = _getSlotValue(root.name, slotData, root.default)
yield keepGoing(slotValue)
elif isinstance(root, CDATA):
write(b'<![CDATA[')
write(escapedCDATA(root.data))
write(b']]>')
elif isinstance(root, Comment):
write(b'<!--')
write(escapedComment(root.data))
write(b'-->')
elif isinstance(root, Tag):
slotData.append(root.slotData)
if root.render is not None:
rendererName = root.render
rootClone = root.clone(False)
rootClone.render = None
renderMethod = renderFactory.lookupRenderMethod(rendererName)
result = renderMethod(request, rootClone)
yield keepGoing(result)
slotData.pop()
return
if not root.tagName:
yield keepGoing(root.children)
return
write(b'<')
if isinstance(root.tagName, unicode):
tagName = root.tagName.encode('ascii')
else:
tagName = root.tagName
write(tagName)
for k, v in iteritems(root.attributes):
if isinstance(k, unicode):
k = k.encode('ascii')
write(b' ' + k + b'="')
# Serialize the contents of the attribute, wrapping the results of
# that serialization so that _everything_ is quoted.
yield keepGoing(
v,
attributeEscapingDoneOutside,
write=writeWithAttributeEscaping(write))
write(b'"')
if root.children or nativeString(tagName) not in voidElements:
write(b'>')
# Regardless of whether we're in an attribute or not, switch back
# to the escapeForContent dataEscaper. The contents of a tag must
# be quoted no matter what; in the top-level document, just so
# they're valid, and if they're within an attribute, they have to
# be quoted so that after applying the *un*-quoting required to re-
# parse the tag within the attribute, all the quoting is still
# correct.
yield keepGoing(root.children, escapeForContent)
write(b'</' + tagName + b'>')
else:
write(b' />')
elif isinstance(root, (tuple, list, GeneratorType)):
for element in root:
yield keepGoing(element)
elif isinstance(root, CharRef):
escaped = '&#%d;' % (root.ordinal,)
write(escaped.encode('ascii'))
elif isinstance(root, Deferred):
yield root.addCallback(lambda result: (result, keepGoing(result)))
elif iscoroutine(root):
d = ensureDeferred(root)
yield d.addCallback(lambda result: (result, keepGoing(result)))
elif IRenderable.providedBy(root):
result = root.render(request)
yield keepGoing(result, renderFactory=root)
else:
raise UnsupportedType(root)
def _flattenTree(request, root, write):
"""
Make C{root} into an iterable of L{bytes} and L{Deferred} by doing a depth
first traversal of the tree.
@param request: A request object which will be passed to
L{IRenderable.render}.
@param root: An object to be made flatter. This may be of type C{unicode},
L{bytes}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple},
L{list}, L{types.GeneratorType}, L{Deferred}, or something providing
L{IRenderable}.
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@return: An iterator which yields objects of type L{bytes} and L{Deferred}.
A L{Deferred} is only yielded when one is encountered in the process of
flattening C{root}. The returned iterator must not be iterated again
until the L{Deferred} is called back.
"""
stack = [_flattenElement(request, root, write, [], None, escapeForContent)]
while stack:
try:
frame = stack[-1].gi_frame
element = next(stack[-1])
except StopIteration:
stack.pop()
except Exception as e:
stack.pop()
roots = []
for generator in stack:
roots.append(generator.gi_frame.f_locals['root'])
roots.append(frame.f_locals['root'])
raise FlattenerError(e, roots, extract_tb(exc_info()[2]))
else:
if isinstance(element, Deferred):
def cbx(originalAndToFlatten):
original, toFlatten = originalAndToFlatten
stack.append(toFlatten)
return original
yield element.addCallback(cbx)
else:
stack.append(element)
def _writeFlattenedData(state, write, result):
"""
Take strings from an iterator and pass them to a writer function.
@param state: An iterator of L{str} and L{Deferred}. L{str} instances will
be passed to C{write}. L{Deferred} instances will be waited on before
resuming iteration of C{state}.
@param write: A callable which will be invoked with each L{str}
produced by iterating C{state}.
@param result: A L{Deferred} which will be called back when C{state} has
been completely flattened into C{write} or which will be errbacked if
an exception in a generator passed to C{state} or an errback from a
L{Deferred} from state occurs.
@return: L{None}
"""
while True:
try:
element = next(state)
except StopIteration:
result.callback(None)
except:
result.errback()
else:
def cby(original):
_writeFlattenedData(state, write, result)
return original
element.addCallbacks(cby, result.errback)
break
def flatten(request, root, write):
"""
Incrementally write out a string representation of C{root} using C{write}.
In order to create a string representation, C{root} will be decomposed into
simpler objects which will themselves be decomposed and so on until strings
or objects which can easily be converted to strings are encountered.
@param request: A request object which will be passed to the C{render}
method of any L{IRenderable} provider which is encountered.
@param root: An object to be made flatter. This may be of type L{unicode},
L{bytes}, L{slot}, L{Tag <twisted.web.template.Tag>}, L{tuple},
L{list}, L{types.GeneratorType}, L{Deferred}, or something that provides
L{IRenderable}.
@param write: A callable which will be invoked with each L{bytes} produced
by flattening C{root}.
@return: A L{Deferred} which will be called back when C{root} has been
completely flattened into C{write} or which will be errbacked if an
unexpected exception occurs.
"""
result = Deferred()
state = _flattenTree(request, root, write)
_writeFlattenedData(state, write, result)
return result
def flattenString(request, root):
"""
Collate a string representation of C{root} into a single string.
This is basically gluing L{flatten} to an L{io.BytesIO} and returning
the results. See L{flatten} for the exact meanings of C{request} and
C{root}.
@return: A L{Deferred} which will be called back with a single string as
its result when C{root} has been completely flattened into C{write} or
which will be errbacked if an unexpected exception occurs.
"""
io = BytesIO()
d = flatten(request, root, io.write)
d.addCallback(lambda _: io.getvalue())
return d

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,114 @@
# -*- test-case-name: twisted.web.test.test_http -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP response code definitions.
"""
from __future__ import division, absolute_import
_CONTINUE = 100
SWITCHING = 101
OK = 200
CREATED = 201
ACCEPTED = 202
NON_AUTHORITATIVE_INFORMATION = 203
NO_CONTENT = 204
RESET_CONTENT = 205
PARTIAL_CONTENT = 206
MULTI_STATUS = 207
MULTIPLE_CHOICE = 300
MOVED_PERMANENTLY = 301
FOUND = 302
SEE_OTHER = 303
NOT_MODIFIED = 304
USE_PROXY = 305
TEMPORARY_REDIRECT = 307
BAD_REQUEST = 400
UNAUTHORIZED = 401
PAYMENT_REQUIRED = 402
FORBIDDEN = 403
NOT_FOUND = 404
NOT_ALLOWED = 405
NOT_ACCEPTABLE = 406
PROXY_AUTH_REQUIRED = 407
REQUEST_TIMEOUT = 408
CONFLICT = 409
GONE = 410
LENGTH_REQUIRED = 411
PRECONDITION_FAILED = 412
REQUEST_ENTITY_TOO_LARGE = 413
REQUEST_URI_TOO_LONG = 414
UNSUPPORTED_MEDIA_TYPE = 415
REQUESTED_RANGE_NOT_SATISFIABLE = 416
EXPECTATION_FAILED = 417
INTERNAL_SERVER_ERROR = 500
NOT_IMPLEMENTED = 501
BAD_GATEWAY = 502
SERVICE_UNAVAILABLE = 503
GATEWAY_TIMEOUT = 504
HTTP_VERSION_NOT_SUPPORTED = 505
INSUFFICIENT_STORAGE_SPACE = 507
NOT_EXTENDED = 510
RESPONSES = {
# 100
_CONTINUE: b"Continue",
SWITCHING: b"Switching Protocols",
# 200
OK: b"OK",
CREATED: b"Created",
ACCEPTED: b"Accepted",
NON_AUTHORITATIVE_INFORMATION: b"Non-Authoritative Information",
NO_CONTENT: b"No Content",
RESET_CONTENT: b"Reset Content.",
PARTIAL_CONTENT: b"Partial Content",
MULTI_STATUS: b"Multi-Status",
# 300
MULTIPLE_CHOICE: b"Multiple Choices",
MOVED_PERMANENTLY: b"Moved Permanently",
FOUND: b"Found",
SEE_OTHER: b"See Other",
NOT_MODIFIED: b"Not Modified",
USE_PROXY: b"Use Proxy",
# 306 not defined??
TEMPORARY_REDIRECT: b"Temporary Redirect",
# 400
BAD_REQUEST: b"Bad Request",
UNAUTHORIZED: b"Unauthorized",
PAYMENT_REQUIRED: b"Payment Required",
FORBIDDEN: b"Forbidden",
NOT_FOUND: b"Not Found",
NOT_ALLOWED: b"Method Not Allowed",
NOT_ACCEPTABLE: b"Not Acceptable",
PROXY_AUTH_REQUIRED: b"Proxy Authentication Required",
REQUEST_TIMEOUT: b"Request Time-out",
CONFLICT: b"Conflict",
GONE: b"Gone",
LENGTH_REQUIRED: b"Length Required",
PRECONDITION_FAILED: b"Precondition Failed",
REQUEST_ENTITY_TOO_LARGE: b"Request Entity Too Large",
REQUEST_URI_TOO_LONG: b"Request-URI Too Long",
UNSUPPORTED_MEDIA_TYPE: b"Unsupported Media Type",
REQUESTED_RANGE_NOT_SATISFIABLE: b"Requested Range not satisfiable",
EXPECTATION_FAILED: b"Expectation Failed",
# 500
INTERNAL_SERVER_ERROR: b"Internal Server Error",
NOT_IMPLEMENTED: b"Not Implemented",
BAD_GATEWAY: b"Bad Gateway",
SERVICE_UNAVAILABLE: b"Service Unavailable",
GATEWAY_TIMEOUT: b"Gateway Time-out",
HTTP_VERSION_NOT_SUPPORTED: b"HTTP Version not supported",
INSUFFICIENT_STORAGE_SPACE: b"Insufficient Storage Space",
NOT_EXTENDED: b"Not Extended"
}

View file

@ -0,0 +1,330 @@
# -*- test-case-name: twisted.web.test.test_stan -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An s-expression-like syntax for expressing xml in pure python.
Stan tags allow you to build XML documents using Python.
Stan is a DOM, or Document Object Model, implemented using basic Python types
and functions called "flatteners". A flattener is a function that knows how to
turn an object of a specific type into something that is closer to an HTML
string. Stan differs from the W3C DOM by not being as cumbersome and heavy
weight. Since the object model is built using simple python types such as lists,
strings, and dictionaries, the API is simpler and constructing a DOM less
cumbersome.
@var voidElements: the names of HTML 'U{void
elements<http://www.whatwg.org/specs/web-apps/current-work/multipage/syntax.html#void-elements>}';
those which can't have contents and can therefore be self-closing in the
output.
"""
from __future__ import absolute_import, division
from twisted.python.compat import iteritems
class slot(object):
"""
Marker for markup insertion in a template.
@type name: C{str}
@ivar name: The name of this slot. The key which must be used in
L{Tag.fillSlots} to fill it.
@type children: C{list}
@ivar children: The L{Tag} objects included in this L{slot}'s template.
@type default: anything flattenable, or L{None}
@ivar default: The default contents of this slot, if it is left unfilled.
If this is L{None}, an L{UnfilledSlot} will be raised, rather than
L{None} actually being used.
@type filename: C{str} or L{None}
@ivar filename: The name of the XML file from which this tag was parsed.
If it was not parsed from an XML file, L{None}.
@type lineNumber: C{int} or L{None}
@ivar lineNumber: The line number on which this tag was encountered in the
XML file from which it was parsed. If it was not parsed from an XML
file, L{None}.
@type columnNumber: C{int} or L{None}
@ivar columnNumber: The column number at which this tag was encountered in
the XML file from which it was parsed. If it was not parsed from an
XML file, L{None}.
"""
def __init__(self, name, default=None, filename=None, lineNumber=None,
columnNumber=None):
self.name = name
self.children = []
self.default = default
self.filename = filename
self.lineNumber = lineNumber
self.columnNumber = columnNumber
def __repr__(self):
return "slot(%r)" % (self.name,)
class Tag(object):
"""
A L{Tag} represents an XML tags with a tag name, attributes, and children.
A L{Tag} can be constructed using the special L{twisted.web.template.tags}
object, or it may be constructed directly with a tag name. L{Tag}s have a
special method, C{__call__}, which makes representing trees of XML natural
using pure python syntax.
@ivar tagName: The name of the represented element. For a tag like
C{<div></div>}, this would be C{"div"}.
@type tagName: C{str}
@ivar attributes: The attributes of the element.
@type attributes: C{dict} mapping C{str} to renderable objects.
@ivar children: The child L{Tag}s of this C{Tag}.
@type children: C{list} of renderable objects.
@ivar render: The name of the render method to use for this L{Tag}. This
name will be looked up at render time by the
L{twisted.web.template.Element} doing the rendering, via
L{twisted.web.template.Element.lookupRenderMethod}, to determine which
method to call.
@type render: C{str}
@type filename: C{str} or L{None}
@ivar filename: The name of the XML file from which this tag was parsed.
If it was not parsed from an XML file, L{None}.
@type lineNumber: C{int} or L{None}
@ivar lineNumber: The line number on which this tag was encountered in the
XML file from which it was parsed. If it was not parsed from an XML
file, L{None}.
@type columnNumber: C{int} or L{None}
@ivar columnNumber: The column number at which this tag was encountered in
the XML file from which it was parsed. If it was not parsed from an
XML file, L{None}.
@type slotData: C{dict} or L{None}
@ivar slotData: The data which can fill slots. If present, a dictionary
mapping slot names to renderable values. The values in this dict might
be anything that can be present as the child of a L{Tag}; strings,
lists, L{Tag}s, generators, etc.
"""
slotData = None
filename = None
lineNumber = None
columnNumber = None
def __init__(self, tagName, attributes=None, children=None, render=None,
filename=None, lineNumber=None, columnNumber=None):
self.tagName = tagName
self.render = render
if attributes is None:
self.attributes = {}
else:
self.attributes = attributes
if children is None:
self.children = []
else:
self.children = children
if filename is not None:
self.filename = filename
if lineNumber is not None:
self.lineNumber = lineNumber
if columnNumber is not None:
self.columnNumber = columnNumber
def fillSlots(self, **slots):
"""
Remember the slots provided at this position in the DOM.
During the rendering of children of this node, slots with names in
C{slots} will be rendered as their corresponding values.
@return: C{self}. This enables the idiom C{return tag.fillSlots(...)} in
renderers.
"""
if self.slotData is None:
self.slotData = {}
self.slotData.update(slots)
return self
def __call__(self, *children, **kw):
"""
Add children and change attributes on this tag.
This is implemented using __call__ because it then allows the natural
syntax::
table(tr1, tr2, width="100%", height="50%", border="1")
Children may be other tag instances, strings, functions, or any other
object which has a registered flatten.
Attributes may be 'transparent' tag instances (so that
C{a(href=transparent(data="foo", render=myhrefrenderer))} works),
strings, functions, or any other object which has a registered
flattener.
If the attribute is a python keyword, such as 'class', you can add an
underscore to the name, like 'class_'.
There is one special keyword argument, 'render', which will be used as
the name of the renderer and saved as the 'render' attribute of this
instance, rather than the DOM 'render' attribute in the attributes
dictionary.
"""
self.children.extend(children)
for k, v in iteritems(kw):
if k[-1] == '_':
k = k[:-1]
if k == 'render':
self.render = v
else:
self.attributes[k] = v
return self
def _clone(self, obj, deep):
"""
Clone an arbitrary object; used by L{Tag.clone}.
@param obj: an object with a clone method, a list or tuple, or something
which should be immutable.
@param deep: whether to continue cloning child objects; i.e. the
contents of lists, the sub-tags within a tag.
@return: a clone of C{obj}.
"""
if hasattr(obj, 'clone'):
return obj.clone(deep)
elif isinstance(obj, (list, tuple)):
return [self._clone(x, deep) for x in obj]
else:
return obj
def clone(self, deep=True):
"""
Return a clone of this tag. If deep is True, clone all of this tag's
children. Otherwise, just shallow copy the children list without copying
the children themselves.
"""
if deep:
newchildren = [self._clone(x, True) for x in self.children]
else:
newchildren = self.children[:]
newattrs = self.attributes.copy()
for key in newattrs.keys():
newattrs[key] = self._clone(newattrs[key], True)
newslotdata = None
if self.slotData:
newslotdata = self.slotData.copy()
for key in newslotdata:
newslotdata[key] = self._clone(newslotdata[key], True)
newtag = Tag(
self.tagName,
attributes=newattrs,
children=newchildren,
render=self.render,
filename=self.filename,
lineNumber=self.lineNumber,
columnNumber=self.columnNumber)
newtag.slotData = newslotdata
return newtag
def clear(self):
"""
Clear any existing children from this tag.
"""
self.children = []
return self
def __repr__(self):
rstr = ''
if self.attributes:
rstr += ', attributes=%r' % self.attributes
if self.children:
rstr += ', children=%r' % self.children
return "Tag(%r%s)" % (self.tagName, rstr)
voidElements = ('img', 'br', 'hr', 'base', 'meta', 'link', 'param', 'area',
'input', 'col', 'basefont', 'isindex', 'frame', 'command',
'embed', 'keygen', 'source', 'track', 'wbs')
class CDATA(object):
"""
A C{<![CDATA[]]>} block from a template. Given a separate representation in
the DOM so that they may be round-tripped through rendering without losing
information.
@ivar data: The data between "C{<![CDATA[}" and "C{]]>}".
@type data: C{unicode}
"""
def __init__(self, data):
self.data = data
def __repr__(self):
return 'CDATA(%r)' % (self.data,)
class Comment(object):
"""
A C{<!-- -->} comment from a template. Given a separate representation in
the DOM so that they may be round-tripped through rendering without losing
information.
@ivar data: The data between "C{<!--}" and "C{-->}".
@type data: C{unicode}
"""
def __init__(self, data):
self.data = data
def __repr__(self):
return 'Comment(%r)' % (self.data,)
class CharRef(object):
"""
A numeric character reference. Given a separate representation in the DOM
so that non-ASCII characters may be output as pure ASCII.
@ivar ordinal: The ordinal value of the unicode character to which this is
object refers.
@type ordinal: C{int}
@since: 12.0
"""
def __init__(self, ordinal):
self.ordinal = ordinal
def __repr__(self):
return "CharRef(%d)" % (self.ordinal,)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,26 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I am a simple test resource.
"""
from __future__ import absolute_import, division
from twisted.web import static
class Test(static.Data):
isLeaf = True
def __init__(self):
static.Data.__init__(
self,
b"""
<html>
<head><title>Twisted Web Demo</title><head>
<body>
Hello! This is a Twisted Web test page.
</body>
</html>
""",
"text/html")

View file

@ -0,0 +1,386 @@
# -*- test-case-name: twisted.web.test.test_distrib -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Distributed web servers.
This is going to have to be refactored so that argument parsing is done
by each subprocess and not by the main web server (i.e. GET, POST etc.).
"""
# System Imports
import os, copy
try:
import pwd
except ImportError:
pwd = None
from io import BytesIO
from xml.dom.minidom import getDOMImplementation
# Twisted Imports
from twisted.spread import pb
from twisted.spread.banana import SIZE_LIMIT
from twisted.web import http, resource, server, util, static
from twisted.web.http_headers import Headers
from twisted.persisted import styles
from twisted.internet import address, reactor
from twisted.logger import Logger
class _ReferenceableProducerWrapper(pb.Referenceable):
def __init__(self, producer):
self.producer = producer
def remote_resumeProducing(self):
self.producer.resumeProducing()
def remote_pauseProducing(self):
self.producer.pauseProducing()
def remote_stopProducing(self):
self.producer.stopProducing()
class Request(pb.RemoteCopy, server.Request):
"""
A request which was received by a L{ResourceSubscription} and sent via
PB to a distributed node.
"""
def setCopyableState(self, state):
"""
Initialize this L{twisted.web.distrib.Request} based on the copied
state so that it closely resembles a L{twisted.web.server.Request}.
"""
for k in 'host', 'client':
tup = state[k]
addrdesc = {'INET': 'TCP', 'UNIX': 'UNIX'}[tup[0]]
addr = {'TCP': lambda: address.IPv4Address(addrdesc,
tup[1], tup[2]),
'UNIX': lambda: address.UNIXAddress(tup[1])}[addrdesc]()
state[k] = addr
state['requestHeaders'] = Headers(dict(state['requestHeaders']))
pb.RemoteCopy.setCopyableState(self, state)
# Emulate the local request interface --
self.content = BytesIO(self.content_data)
self.finish = self.remote.remoteMethod('finish')
self.setHeader = self.remote.remoteMethod('setHeader')
self.addCookie = self.remote.remoteMethod('addCookie')
self.setETag = self.remote.remoteMethod('setETag')
self.setResponseCode = self.remote.remoteMethod('setResponseCode')
self.setLastModified = self.remote.remoteMethod('setLastModified')
# To avoid failing if a resource tries to write a very long string
# all at once, this one will be handled slightly differently.
self._write = self.remote.remoteMethod('write')
def write(self, bytes):
"""
Write the given bytes to the response body.
@param bytes: The bytes to write. If this is longer than 640k, it
will be split up into smaller pieces.
"""
start = 0
end = SIZE_LIMIT
while True:
self._write(bytes[start:end])
start += SIZE_LIMIT
end += SIZE_LIMIT
if start >= len(bytes):
break
def registerProducer(self, producer, streaming):
self.remote.callRemote("registerProducer",
_ReferenceableProducerWrapper(producer),
streaming).addErrback(self.fail)
def unregisterProducer(self):
self.remote.callRemote("unregisterProducer").addErrback(self.fail)
def fail(self, failure):
self._log.failure('', failure=failure)
pb.setUnjellyableForClass(server.Request, Request)
class Issue:
_log = Logger()
def __init__(self, request):
self.request = request
def finished(self, result):
if result is not server.NOT_DONE_YET:
assert isinstance(result, str), "return value not a string"
self.request.write(result)
self.request.finish()
def failed(self, failure):
#XXX: Argh. FIXME.
failure = str(failure)
self.request.write(
resource.ErrorPage(http.INTERNAL_SERVER_ERROR,
"Server Connection Lost",
"Connection to distributed server lost:" +
util._PRE(failure)).
render(self.request))
self.request.finish()
self._log.info(failure)
class ResourceSubscription(resource.Resource):
isLeaf = 1
waiting = 0
_log = Logger()
def __init__(self, host, port):
resource.Resource.__init__(self)
self.host = host
self.port = port
self.pending = []
self.publisher = None
def __getstate__(self):
"""Get persistent state for this ResourceSubscription.
"""
# When I unserialize,
state = copy.copy(self.__dict__)
# Publisher won't be connected...
state['publisher'] = None
# I won't be making a connection
state['waiting'] = 0
# There will be no pending requests.
state['pending'] = []
return state
def connected(self, publisher):
"""I've connected to a publisher; I'll now send all my requests.
"""
self._log.info('connected to publisher')
publisher.broker.notifyOnDisconnect(self.booted)
self.publisher = publisher
self.waiting = 0
for request in self.pending:
self.render(request)
self.pending = []
def notConnected(self, msg):
"""I can't connect to a publisher; I'll now reply to all pending
requests.
"""
self._log.info(
"could not connect to distributed web service: {msg}",
msg=msg
)
self.waiting = 0
self.publisher = None
for request in self.pending:
request.write("Unable to connect to distributed server.")
request.finish()
self.pending = []
def booted(self):
self.notConnected("connection dropped")
def render(self, request):
"""Render this request, from my server.
This will always be asynchronous, and therefore return NOT_DONE_YET.
It spins off a request to the pb client, and either adds it to the list
of pending issues or requests it immediately, depending on if the
client is already connected.
"""
if not self.publisher:
self.pending.append(request)
if not self.waiting:
self.waiting = 1
bf = pb.PBClientFactory()
timeout = 10
if self.host == "unix":
reactor.connectUNIX(self.port, bf, timeout)
else:
reactor.connectTCP(self.host, self.port, bf, timeout)
d = bf.getRootObject()
d.addCallbacks(self.connected, self.notConnected)
else:
i = Issue(request)
self.publisher.callRemote('request', request).addCallbacks(i.finished, i.failed)
return server.NOT_DONE_YET
class ResourcePublisher(pb.Root, styles.Versioned):
"""
L{ResourcePublisher} exposes a remote API which can be used to respond
to request.
@ivar site: The site which will be used for resource lookup.
@type site: L{twisted.web.server.Site}
"""
_log = Logger()
def __init__(self, site):
self.site = site
persistenceVersion = 2
def upgradeToVersion2(self):
self.application.authorizer.removeIdentity("web")
del self.application.services[self.serviceName]
del self.serviceName
del self.application
del self.perspectiveName
def getPerspectiveNamed(self, name):
return self
def remote_request(self, request):
"""
Look up the resource for the given request and render it.
"""
res = self.site.getResourceFor(request)
self._log.info(request)
result = res.render(request)
if result is not server.NOT_DONE_YET:
request.write(result)
request.finish()
return server.NOT_DONE_YET
class UserDirectory(resource.Resource):
"""
A resource which lists available user resources and serves them as
children.
@ivar _pwd: An object like L{pwd} which is used to enumerate users and
their home directories.
"""
userDirName = 'public_html'
userSocketName = '.twistd-web-pb'
template = """
<html>
<head>
<title>twisted.web.distrib.UserDirectory</title>
<style>
a
{
font-family: Lucida, Verdana, Helvetica, Arial, sans-serif;
color: #369;
text-decoration: none;
}
th
{
font-family: Lucida, Verdana, Helvetica, Arial, sans-serif;
font-weight: bold;
text-decoration: none;
text-align: left;
}
pre, code
{
font-family: "Courier New", Courier, monospace;
}
p, body, td, ol, ul, menu, blockquote, div
{
font-family: Lucida, Verdana, Helvetica, Arial, sans-serif;
color: #000;
}
</style>
</head>
<body>
<h1>twisted.web.distrib.UserDirectory</h1>
%(users)s
</body>
</html>
"""
def __init__(self, userDatabase=None):
resource.Resource.__init__(self)
if userDatabase is None:
userDatabase = pwd
self._pwd = userDatabase
def _users(self):
"""
Return a list of two-tuples giving links to user resources and text to
associate with those links.
"""
users = []
for user in self._pwd.getpwall():
name, passwd, uid, gid, gecos, dir, shell = user
realname = gecos.split(',')[0]
if not realname:
realname = name
if os.path.exists(os.path.join(dir, self.userDirName)):
users.append((name, realname + ' (file)'))
twistdsock = os.path.join(dir, self.userSocketName)
if os.path.exists(twistdsock):
linkName = name + '.twistd'
users.append((linkName, realname + ' (twistd)'))
return users
def render_GET(self, request):
"""
Render as HTML a listing of all known users with links to their
personal resources.
"""
domImpl = getDOMImplementation()
newDoc = domImpl.createDocument(None, "ul", None)
listing = newDoc.documentElement
for link, text in self._users():
linkElement = newDoc.createElement('a')
linkElement.setAttribute('href', link + '/')
textNode = newDoc.createTextNode(text)
linkElement.appendChild(textNode)
item = newDoc.createElement('li')
item.appendChild(linkElement)
listing.appendChild(item)
htmlDoc = self.template % ({'users': listing.toxml()})
return htmlDoc.encode("utf-8")
def getChild(self, name, request):
if name == '':
return self
td = '.twistd'
if name[-len(td):] == td:
username = name[:-len(td)]
sub = 1
else:
username = name
sub = 0
try:
pw_name, pw_passwd, pw_uid, pw_gid, pw_gecos, pw_dir, pw_shell \
= self._pwd.getpwnam(username)
except KeyError:
return resource.NoResource()
if sub:
twistdsock = os.path.join(pw_dir, self.userSocketName)
rs = ResourceSubscription('unix',twistdsock)
self.putChild(name, rs)
return rs
else:
path = os.path.join(pw_dir, self.userDirName)
if not os.path.exists(path):
return resource.NoResource()
return static.File(path)

View file

@ -0,0 +1,272 @@
# -*- test-case-name: twisted.web.test.test_domhelpers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A library for performing interesting tasks with DOM objects.
"""
from io import StringIO
from twisted.web import microdom
from twisted.web.microdom import getElementsByTagName, escape, unescape
# These modules are imported here as a shortcut.
escape
getElementsByTagName
class NodeLookupError(Exception):
pass
def substitute(request, node, subs):
"""
Look through the given node's children for strings, and
attempt to do string substitution with the given parameter.
"""
for child in node.childNodes:
if hasattr(child, 'nodeValue') and child.nodeValue:
child.replaceData(0, len(child.nodeValue), child.nodeValue % subs)
substitute(request, child, subs)
def _get(node, nodeId, nodeAttrs=('id','class','model','pattern')):
"""
(internal) Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes.
"""
if hasattr(node, 'hasAttributes') and node.hasAttributes():
for nodeAttr in nodeAttrs:
if (str (node.getAttribute(nodeAttr)) == nodeId):
return node
if node.hasChildNodes():
if hasattr(node.childNodes, 'length'):
length = node.childNodes.length
else:
length = len(node.childNodes)
for childNum in range(length):
result = _get(node.childNodes[childNum], nodeId)
if result: return result
def get(node, nodeId):
"""
Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes. If there is no such node, raise
L{NodeLookupError}.
"""
result = _get(node, nodeId)
if result: return result
raise NodeLookupError(nodeId)
def getIfExists(node, nodeId):
"""
Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes. If there is no such node, return
L{None}.
"""
return _get(node, nodeId)
def getAndClear(node, nodeId):
"""Get a node with the specified C{nodeId} as any of the C{class},
C{id} or C{pattern} attributes. If there is no such node, raise
L{NodeLookupError}. Remove all child nodes before returning.
"""
result = get(node, nodeId)
if result:
clearNode(result)
return result
def clearNode(node):
"""
Remove all children from the given node.
"""
node.childNodes[:] = []
def locateNodes(nodeList, key, value, noNesting=1):
"""
Find subnodes in the given node where the given attribute
has the given value.
"""
returnList = []
if not isinstance(nodeList, type([])):
return locateNodes(nodeList.childNodes, key, value, noNesting)
for childNode in nodeList:
if not hasattr(childNode, 'getAttribute'):
continue
if str(childNode.getAttribute(key)) == value:
returnList.append(childNode)
if noNesting:
continue
returnList.extend(locateNodes(childNode, key, value, noNesting))
return returnList
def superSetAttribute(node, key, value):
if not hasattr(node, 'setAttribute'): return
node.setAttribute(key, value)
if node.hasChildNodes():
for child in node.childNodes:
superSetAttribute(child, key, value)
def superPrependAttribute(node, key, value):
if not hasattr(node, 'setAttribute'): return
old = node.getAttribute(key)
if old:
node.setAttribute(key, value+'/'+old)
else:
node.setAttribute(key, value)
if node.hasChildNodes():
for child in node.childNodes:
superPrependAttribute(child, key, value)
def superAppendAttribute(node, key, value):
if not hasattr(node, 'setAttribute'): return
old = node.getAttribute(key)
if old:
node.setAttribute(key, old + '/' + value)
else:
node.setAttribute(key, value)
if node.hasChildNodes():
for child in node.childNodes:
superAppendAttribute(child, key, value)
def gatherTextNodes(iNode, dounescape=0, joinWith=""):
"""Visit each child node and collect its text data, if any, into a string.
For example::
>>> doc=microdom.parseString('<a>1<b>2<c>3</c>4</b></a>')
>>> gatherTextNodes(doc.documentElement)
'1234'
With dounescape=1, also convert entities back into normal characters.
@return: the gathered nodes as a single string
@rtype: str
"""
gathered=[]
gathered_append=gathered.append
slice=[iNode]
while len(slice)>0:
c=slice.pop(0)
if hasattr(c, 'nodeValue') and c.nodeValue is not None:
if dounescape:
val=unescape(c.nodeValue)
else:
val=c.nodeValue
gathered_append(val)
slice[:0]=c.childNodes
return joinWith.join(gathered)
class RawText(microdom.Text):
"""This is an evil and horrible speed hack. Basically, if you have a big
chunk of XML that you want to insert into the DOM, but you don't want to
incur the cost of parsing it, you can construct one of these and insert it
into the DOM. This will most certainly only work with microdom as the API
for converting nodes to xml is different in every DOM implementation.
This could be improved by making this class a Lazy parser, so if you
inserted this into the DOM and then later actually tried to mutate this
node, it would be parsed then.
"""
def writexml(self, writer, indent="", addindent="", newl="", strip=0, nsprefixes=None, namespace=None):
writer.write("%s%s%s" % (indent, self.data, newl))
def findNodes(parent, matcher, accum=None):
if accum is None:
accum = []
if not parent.hasChildNodes():
return accum
for child in parent.childNodes:
# print child, child.nodeType, child.nodeName
if matcher(child):
accum.append(child)
findNodes(child, matcher, accum)
return accum
def findNodesShallowOnMatch(parent, matcher, recurseMatcher, accum=None):
if accum is None:
accum = []
if not parent.hasChildNodes():
return accum
for child in parent.childNodes:
# print child, child.nodeType, child.nodeName
if matcher(child):
accum.append(child)
if recurseMatcher(child):
findNodesShallowOnMatch(child, matcher, recurseMatcher, accum)
return accum
def findNodesShallow(parent, matcher, accum=None):
if accum is None:
accum = []
if not parent.hasChildNodes():
return accum
for child in parent.childNodes:
if matcher(child):
accum.append(child)
else:
findNodes(child, matcher, accum)
return accum
def findElementsWithAttributeShallow(parent, attribute):
"""
Return an iterable of the elements which are direct children of C{parent}
and which have the C{attribute} attribute.
"""
return findNodesShallow(parent,
lambda n: getattr(n, 'tagName', None) is not None and
n.hasAttribute(attribute))
def findElements(parent, matcher):
"""
Return an iterable of the elements which are children of C{parent} for
which the predicate C{matcher} returns true.
"""
return findNodes(
parent,
lambda n, matcher=matcher: getattr(n, 'tagName', None) is not None and
matcher(n))
def findElementsWithAttribute(parent, attribute, value=None):
if value:
return findElements(
parent,
lambda n, attribute=attribute, value=value:
n.hasAttribute(attribute) and n.getAttribute(attribute) == value)
else:
return findElements(
parent,
lambda n, attribute=attribute: n.hasAttribute(attribute))
def findNodesNamed(parent, name):
return findNodes(parent, lambda n, name=name: n.nodeName == name)
def writeNodeData(node, oldio):
for subnode in node.childNodes:
if hasattr(subnode, 'data'):
oldio.write(u"" + subnode.data)
else:
writeNodeData(subnode, oldio)
def getNodeText(node):
oldio = StringIO()
writeNodeData(node, oldio)
return oldio.getvalue()
def getParents(node):
l = []
while node:
l.append(node)
node = node.parentNode
return l
def namedChildren(parent, nodeName):
"""namedChildren(parent, nodeName) -> children (not descendants) of parent
that have tagName == nodeName
"""
return [n for n in parent.childNodes if getattr(n, 'tagName', '')==nodeName]

View file

@ -0,0 +1,407 @@
# -*- test-case-name: twisted.web.test.test_error -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Exception definitions for L{twisted.web}.
"""
from __future__ import division, absolute_import
try:
from future_builtins import ascii
except ImportError:
pass
__all__ = [
'Error', 'PageRedirect', 'InfiniteRedirection', 'RenderError',
'MissingRenderMethod', 'MissingTemplateLoader', 'UnexposedMethodError',
'UnfilledSlot', 'UnsupportedType', 'FlattenerError',
'RedirectWithNoLocation',
]
from twisted.web._responses import RESPONSES
from twisted.python.compat import unicode, nativeString, intToBytes, Sequence
def _codeToMessage(code):
"""
Returns the response message corresponding to an HTTP code, or None
if the code is unknown or unrecognized.
@type code: L{bytes}
@param code: Refers to an HTTP status code, for example C{http.NOT_FOUND}.
@return: A string message or none
@rtype: L{bytes}
"""
try:
return RESPONSES.get(int(code))
except (ValueError, AttributeError):
return None
class Error(Exception):
"""
A basic HTTP error.
@type status: L{bytes}
@ivar status: Refers to an HTTP status code, for example C{http.NOT_FOUND}.
@type message: L{bytes}
@param message: A short error message, for example "NOT FOUND".
@type response: L{bytes}
@ivar response: A complete HTML document for an error page.
"""
def __init__(self, code, message=None, response=None):
"""
Initializes a basic exception.
@type code: L{bytes} or L{int}
@param code: Refers to an HTTP status code (for example, 200) either as
an integer or a bytestring representing such. If no C{message} is
given, C{code} is mapped to a descriptive bytestring that is used
instead.
@type message: L{bytes}
@param message: A short error message, for example "NOT FOUND".
@type response: L{bytes}
@param response: A complete HTML document for an error page.
"""
message = message or _codeToMessage(code)
Exception.__init__(self, code, message, response)
if isinstance(code, int):
# If we're given an int, convert it to a bytestring
# downloadPage gives a bytes, Agent gives an int, and it worked by
# accident previously, so just make it keep working.
code = intToBytes(code)
self.status = code
self.message = message
self.response = response
def __str__(self):
return nativeString(self.status + b" " + self.message)
class PageRedirect(Error):
"""
A request resulted in an HTTP redirect.
@type location: L{bytes}
@ivar location: The location of the redirect which was not followed.
"""
def __init__(self, code, message=None, response=None, location=None):
"""
Initializes a page redirect exception.
@type code: L{bytes}
@param code: Refers to an HTTP status code, for example
C{http.NOT_FOUND}. If no C{message} is given, C{code} is mapped to a
descriptive string that is used instead.
@type message: L{bytes}
@param message: A short error message, for example "NOT FOUND".
@type response: L{bytes}
@param response: A complete HTML document for an error page.
@type location: L{bytes}
@param location: The location response-header field value. It is an
absolute URI used to redirect the receiver to a location other than
the Request-URI so the request can be completed.
"""
Error.__init__(self, code, message, response)
if self.message and location:
self.message = self.message + b" to " + location
self.location = location
class InfiniteRedirection(Error):
"""
HTTP redirection is occurring endlessly.
@type location: L{bytes}
@ivar location: The first URL in the series of redirections which was
not followed.
"""
def __init__(self, code, message=None, response=None, location=None):
"""
Initializes an infinite redirection exception.
@type code: L{bytes}
@param code: Refers to an HTTP status code, for example
C{http.NOT_FOUND}. If no C{message} is given, C{code} is mapped to a
descriptive string that is used instead.
@type message: L{bytes}
@param message: A short error message, for example "NOT FOUND".
@type response: L{bytes}
@param response: A complete HTML document for an error page.
@type location: L{bytes}
@param location: The location response-header field value. It is an
absolute URI used to redirect the receiver to a location other than
the Request-URI so the request can be completed.
"""
Error.__init__(self, code, message, response)
if self.message and location:
self.message = self.message + b" to " + location
self.location = location
class RedirectWithNoLocation(Error):
"""
Exception passed to L{ResponseFailed} if we got a redirect without a
C{Location} header field.
@type uri: L{bytes}
@ivar uri: The URI which failed to give a proper location header
field.
@since: 11.1
"""
def __init__(self, code, message, uri):
"""
Initializes a page redirect exception when no location is given.
@type code: L{bytes}
@param code: Refers to an HTTP status code, for example
C{http.NOT_FOUND}. If no C{message} is given, C{code} is mapped to
a descriptive string that is used instead.
@type message: L{bytes}
@param message: A short error message.
@type uri: L{bytes}
@param uri: The URI which failed to give a proper location header
field.
"""
Error.__init__(self, code, message)
self.message = self.message + b" to " + uri
self.uri = uri
class UnsupportedMethod(Exception):
"""
Raised by a resource when faced with a strange request method.
RFC 2616 (HTTP 1.1) gives us two choices when faced with this situation:
If the type of request is known to us, but not allowed for the requested
resource, respond with NOT_ALLOWED. Otherwise, if the request is something
we don't know how to deal with in any case, respond with NOT_IMPLEMENTED.
When this exception is raised by a Resource's render method, the server
will make the appropriate response.
This exception's first argument MUST be a sequence of the methods the
resource *does* support.
"""
allowedMethods = ()
def __init__(self, allowedMethods, *args):
Exception.__init__(self, allowedMethods, *args)
self.allowedMethods = allowedMethods
if not isinstance(allowedMethods, Sequence):
raise TypeError(
"First argument must be a sequence of supported methods, "
"but my first argument is not a sequence.")
def __str__(self):
return "Expected one of %r" % (self.allowedMethods,)
class SchemeNotSupported(Exception):
"""
The scheme of a URI was not one of the supported values.
"""
class RenderError(Exception):
"""
Base exception class for all errors which can occur during template
rendering.
"""
class MissingRenderMethod(RenderError):
"""
Tried to use a render method which does not exist.
@ivar element: The element which did not have the render method.
@ivar renderName: The name of the renderer which could not be found.
"""
def __init__(self, element, renderName):
RenderError.__init__(self, element, renderName)
self.element = element
self.renderName = renderName
def __repr__(self):
return '%r: %r had no render method named %r' % (
self.__class__.__name__, self.element, self.renderName)
class MissingTemplateLoader(RenderError):
"""
L{MissingTemplateLoader} is raised when trying to render an Element without
a template loader, i.e. a C{loader} attribute.
@ivar element: The Element which did not have a document factory.
"""
def __init__(self, element):
RenderError.__init__(self, element)
self.element = element
def __repr__(self):
return '%r: %r had no loader' % (self.__class__.__name__,
self.element)
class UnexposedMethodError(Exception):
"""
Raised on any attempt to get a method which has not been exposed.
"""
class UnfilledSlot(Exception):
"""
During flattening, a slot with no associated data was encountered.
"""
class UnsupportedType(Exception):
"""
During flattening, an object of a type which cannot be flattened was
encountered.
"""
class ExcessiveBufferingError(Exception):
"""
The HTTP/2 protocol has been forced to buffer an excessive amount of
outbound data, and has therefore closed the connection and dropped all
outbound data.
"""
class FlattenerError(Exception):
"""
An error occurred while flattening an object.
@ivar _roots: A list of the objects on the flattener's stack at the time
the unflattenable object was encountered. The first element is least
deeply nested object and the last element is the most deeply nested.
"""
def __init__(self, exception, roots, traceback):
self._exception = exception
self._roots = roots
self._traceback = traceback
Exception.__init__(self, exception, roots, traceback)
def _formatRoot(self, obj):
"""
Convert an object from C{self._roots} to a string suitable for
inclusion in a render-traceback (like a normal Python traceback, but
can include "frame" source locations which are not in Python source
files).
@param obj: Any object which can be a render step I{root}.
Typically, L{Tag}s, strings, and other simple Python types.
@return: A string representation of C{obj}.
@rtype: L{str}
"""
# There's a circular dependency between this class and 'Tag', although
# only for an isinstance() check.
from twisted.web.template import Tag
if isinstance(obj, (bytes, str, unicode)):
# It's somewhat unlikely that there will ever be a str in the roots
# list. However, something like a MemoryError during a str.replace
# call (eg, replacing " with &quot;) could possibly cause this.
# Likewise, UTF-8 encoding a unicode string to a byte string might
# fail like this.
if len(obj) > 40:
if isinstance(obj, unicode):
ellipsis = u'<...>'
else:
ellipsis = b'<...>'
return ascii(obj[:20] + ellipsis + obj[-20:])
else:
return ascii(obj)
elif isinstance(obj, Tag):
if obj.filename is None:
return 'Tag <' + obj.tagName + '>'
else:
return "File \"%s\", line %d, column %d, in \"%s\"" % (
obj.filename, obj.lineNumber,
obj.columnNumber, obj.tagName)
else:
return ascii(obj)
def __repr__(self):
"""
Present a string representation which includes a template traceback, so
we can tell where this error occurred in the template, as well as in
Python.
"""
# Avoid importing things unnecessarily until we actually need them;
# since this is an 'error' module we should be extra paranoid about
# that.
from traceback import format_list
if self._roots:
roots = ' ' + '\n '.join([
self._formatRoot(r) for r in self._roots]) + '\n'
else:
roots = ''
if self._traceback:
traceback = '\n'.join([
line
for entry in format_list(self._traceback)
for line in entry.splitlines()]) + '\n'
else:
traceback = ''
return (
'Exception while flattening:\n' +
roots + traceback +
self._exception.__class__.__name__ + ': ' +
str(self._exception) + '\n')
def __str__(self):
return repr(self)
class UnsupportedSpecialHeader(Exception):
"""
A HTTP/2 request was received that contained a HTTP/2 pseudo-header field
that is not recognised by Twisted.
"""

View file

@ -0,0 +1,20 @@
# -*- test-case-name: twisted.web.test.test_httpauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Resource traversal integration with L{twisted.cred} to allow for
authentication and authorization of HTTP requests.
"""
from __future__ import division, absolute_import
# Expose HTTP authentication classes here.
from twisted.web._auth.wrapper import HTTPAuthSessionWrapper
from twisted.web._auth.basic import BasicCredentialFactory
from twisted.web._auth.digest import DigestCredentialFactory
__all__ = [
"HTTPAuthSessionWrapper",
"BasicCredentialFactory", "DigestCredentialFactory"]

View file

@ -0,0 +1,57 @@
# -*- test-case-name: twisted.web.test.test_html -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""I hold HTML generation helpers.
"""
from twisted.python import log
from twisted.python.compat import NativeStringIO as StringIO, escape
from twisted.python.deprecate import deprecated
from incremental import Version
@deprecated(Version('Twisted', 15, 3, 0), replacement='twisted.web.template')
def PRE(text):
"Wrap <pre> tags around some text and HTML-escape it."
return "<pre>"+escape(text)+"</pre>"
@deprecated(Version('Twisted', 15, 3, 0), replacement='twisted.web.template')
def UL(lst):
io = StringIO()
io.write("<ul>\n")
for el in lst:
io.write("<li> %s</li>\n" % el)
io.write("</ul>")
return io.getvalue()
@deprecated(Version('Twisted', 15, 3, 0), replacement='twisted.web.template')
def linkList(lst):
io = StringIO()
io.write("<ul>\n")
for hr, el in lst:
io.write('<li> <a href="%s">%s</a></li>\n' % (hr, el))
io.write("</ul>")
return io.getvalue()
@deprecated(Version('Twisted', 15, 3, 0), replacement='twisted.web.template')
def output(func, *args, **kw):
"""output(func, *args, **kw) -> html string
Either return the result of a function (which presumably returns an
HTML-legal string) or a sparse HTMLized error message and a message
in the server log.
"""
try:
return func(*args, **kw)
except:
log.msg("Error calling %r:" % (func,))
log.err()
return PRE("An error occurred.")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,294 @@
# -*- test-case-name: twisted.web.test.test_http_headers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An API for storing HTTP header names and values.
"""
from __future__ import division, absolute_import
from twisted.python.compat import comparable, cmp, unicode
def _dashCapitalize(name):
"""
Return a byte string which is capitalized using '-' as a word separator.
@param name: The name of the header to capitalize.
@type name: L{bytes}
@return: The given header capitalized using '-' as a word separator.
@rtype: L{bytes}
"""
return b'-'.join([word.capitalize() for word in name.split(b'-')])
def _sanitizeLinearWhitespace(headerComponent):
r"""
Replace linear whitespace (C{\n}, C{\r\n}, C{\r}) in a header key
or value with a single space. If C{headerComponent} is not
L{bytes}, it is passed through unchanged.
@param headerComponent: The header key or value to sanitize.
@type headerComponent: L{bytes}
@return: The sanitized header key or value.
@rtype: L{bytes}
"""
return b' '.join(headerComponent.splitlines())
@comparable
class Headers(object):
"""
Stores HTTP headers in a key and multiple value format.
Most methods accept L{bytes} and L{unicode}, with an internal L{bytes}
representation. When passed L{unicode}, header names (e.g. 'Content-Type')
are encoded using ISO-8859-1 and header values (e.g.
'text/html;charset=utf-8') are encoded using UTF-8. Some methods that return
values will return them in the same type as the name given.
If the header keys or values cannot be encoded or decoded using the rules
above, using just L{bytes} arguments to the methods of this class will
ensure no decoding or encoding is done, and L{Headers} will treat the keys
and values as opaque byte strings.
@cvar _caseMappings: A L{dict} that maps lowercase header names
to their canonicalized representation.
@ivar _rawHeaders: A L{dict} mapping header names as L{bytes} to L{list}s of
header values as L{bytes}.
"""
_caseMappings = {
b'content-md5': b'Content-MD5',
b'dnt': b'DNT',
b'etag': b'ETag',
b'p3p': b'P3P',
b'te': b'TE',
b'www-authenticate': b'WWW-Authenticate',
b'x-xss-protection': b'X-XSS-Protection'}
def __init__(self, rawHeaders=None):
self._rawHeaders = {}
if rawHeaders is not None:
for name, values in rawHeaders.items():
self.setRawHeaders(name, values)
def __repr__(self):
"""
Return a string fully describing the headers set on this object.
"""
return '%s(%r)' % (self.__class__.__name__, self._rawHeaders,)
def __cmp__(self, other):
"""
Define L{Headers} instances as being equal to each other if they have
the same raw headers.
"""
if isinstance(other, Headers):
return cmp(
sorted(self._rawHeaders.items()),
sorted(other._rawHeaders.items()))
return NotImplemented
def _encodeName(self, name):
"""
Encode the name of a header (eg 'Content-Type') to an ISO-8859-1 encoded
bytestring if required.
@param name: A HTTP header name
@type name: L{unicode} or L{bytes}
@return: C{name}, encoded if required, lowercased
@rtype: L{bytes}
"""
if isinstance(name, unicode):
return name.lower().encode('iso-8859-1')
return name.lower()
def _encodeValue(self, value):
"""
Encode a single header value to a UTF-8 encoded bytestring if required.
@param value: A single HTTP header value.
@type value: L{bytes} or L{unicode}
@return: C{value}, encoded if required
@rtype: L{bytes}
"""
if isinstance(value, unicode):
return value.encode('utf8')
return value
def _encodeValues(self, values):
"""
Encode a L{list} of header values to a L{list} of UTF-8 encoded
bytestrings if required.
@param values: A list of HTTP header values.
@type values: L{list} of L{bytes} or L{unicode} (mixed types allowed)
@return: C{values}, with each item encoded if required
@rtype: L{list} of L{bytes}
"""
newValues = []
for value in values:
newValues.append(self._encodeValue(value))
return newValues
def _decodeValues(self, values):
"""
Decode a L{list} of header values into a L{list} of Unicode strings.
@param values: A list of HTTP header values.
@type values: L{list} of UTF-8 encoded L{bytes}
@return: C{values}, with each item decoded
@rtype: L{list} of L{unicode}
"""
newValues = []
for value in values:
newValues.append(value.decode('utf8'))
return newValues
def copy(self):
"""
Return a copy of itself with the same headers set.
@return: A new L{Headers}
"""
return self.__class__(self._rawHeaders)
def hasHeader(self, name):
"""
Check for the existence of a given header.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to check for.
@rtype: L{bool}
@return: C{True} if the header exists, otherwise C{False}.
"""
return self._encodeName(name) in self._rawHeaders
def removeHeader(self, name):
"""
Remove the named header from this header object.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to remove.
@return: L{None}
"""
self._rawHeaders.pop(self._encodeName(name), None)
def setRawHeaders(self, name, values):
"""
Sets the raw representation of the given header.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to set the values for.
@type values: L{list} of L{bytes} or L{unicode} strings
@param values: A list of strings each one being a header value of
the given name.
@return: L{None}
"""
if not isinstance(values, list):
raise TypeError("Header entry %r should be list but found "
"instance of %r instead" % (name, type(values)))
name = _sanitizeLinearWhitespace(self._encodeName(name))
encodedValues = [_sanitizeLinearWhitespace(v)
for v in self._encodeValues(values)]
self._rawHeaders[name] = self._encodeValues(encodedValues)
def addRawHeader(self, name, value):
"""
Add a new raw value for the given header.
@type name: L{bytes} or L{unicode}
@param name: The name of the header for which to set the value.
@type value: L{bytes} or L{unicode}
@param value: The value to set for the named header.
"""
values = self.getRawHeaders(name)
if values is not None:
values.append(value)
else:
values = [value]
self.setRawHeaders(name, values)
def getRawHeaders(self, name, default=None):
"""
Returns a list of headers matching the given name as the raw string
given.
@type name: L{bytes} or L{unicode}
@param name: The name of the HTTP header to get the values of.
@param default: The value to return if no header with the given C{name}
exists.
@rtype: L{list} of strings, same type as C{name} (except when
C{default} is returned).
@return: If the named header is present, a L{list} of its
values. Otherwise, C{default}.
"""
encodedName = self._encodeName(name)
values = self._rawHeaders.get(encodedName, default)
if isinstance(name, unicode) and values is not default:
return self._decodeValues(values)
return values
def getAllRawHeaders(self):
"""
Return an iterator of key, value pairs of all headers contained in this
object, as L{bytes}. The keys are capitalized in canonical
capitalization.
"""
for k, v in self._rawHeaders.items():
yield self._canonicalNameCaps(k), v
def _canonicalNameCaps(self, name):
"""
Return the canonical name for the given header.
@type name: L{bytes}
@param name: The all-lowercase header name to capitalize in its
canonical form.
@rtype: L{bytes}
@return: The canonical name of the header.
"""
return self._caseMappings.get(name, _dashCapitalize(name))
__all__ = ['Headers']

View file

@ -0,0 +1,828 @@
# -*- test-case-name: twisted.web.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Interface definitions for L{twisted.web}.
@var UNKNOWN_LENGTH: An opaque object which may be used as the value of
L{IBodyProducer.length} to indicate that the length of the entity
body is not known in advance.
"""
from zope.interface import Interface, Attribute
from twisted.internet.interfaces import IPushProducer
from twisted.cred.credentials import IUsernameDigestHash
class IRequest(Interface):
"""
An HTTP request.
@since: 9.0
"""
method = Attribute("A L{bytes} giving the HTTP method that was used.")
uri = Attribute(
"A L{bytes} giving the full encoded URI which was requested (including"
" query arguments).")
path = Attribute(
"A L{bytes} giving the encoded query path of the request URI (not "
"including query arguments).")
args = Attribute(
"A mapping of decoded query argument names as L{bytes} to "
"corresponding query argument values as L{list}s of L{bytes}. "
"For example, for a URI with C{foo=bar&foo=baz&quux=spam} "
"for its query part, C{args} will be C{{b'foo': [b'bar', b'baz'], "
"b'quux': [b'spam']}}.")
prepath = Attribute(
"The URL path segments which have been processed during resource "
"traversal, as a list of {bytes}.")
postpath = Attribute(
"The URL path segments which have not (yet) been processed "
"during resource traversal, as a list of L{bytes}.")
requestHeaders = Attribute(
"A L{http_headers.Headers} instance giving all received HTTP request "
"headers.")
content = Attribute(
"A file-like object giving the request body. This may be a file on "
"disk, an L{io.BytesIO}, or some other type. The implementation is "
"free to decide on a per-request basis.")
responseHeaders = Attribute(
"A L{http_headers.Headers} instance holding all HTTP response "
"headers to be sent.")
def getHeader(key):
"""
Get an HTTP request header.
@type key: L{bytes} or L{str}
@param key: The name of the header to get the value of.
@rtype: L{bytes} or L{str} or L{None}
@return: The value of the specified header, or L{None} if that header
was not present in the request. The string type of the result
matches the type of C{key}.
"""
def getCookie(key):
"""
Get a cookie that was sent from the network.
@type key: L{bytes}
@param key: The name of the cookie to get.
@rtype: L{bytes} or L{None}
@returns: The value of the specified cookie, or L{None} if that cookie
was not present in the request.
"""
def getAllHeaders():
"""
Return dictionary mapping the names of all received headers to the last
value received for each.
Since this method does not return all header information,
C{requestHeaders.getAllRawHeaders()} may be preferred.
"""
def getRequestHostname():
"""
Get the hostname that the user passed in to the request.
This will either use the Host: header (if it is available) or the
host we are listening on if the header is unavailable.
@returns: the requested hostname
@rtype: L{str}
"""
def getHost():
"""
Get my originally requesting transport's host.
@return: An L{IAddress<twisted.internet.interfaces.IAddress>}.
"""
def getClientAddress():
"""
Return the address of the client who submitted this request.
The address may not be a network address. Callers must check
its type before using it.
@since: 18.4
@return: the client's address.
@rtype: an L{IAddress} provider.
"""
def getClientIP():
"""
Return the IP address of the client who submitted this request.
This method is B{deprecated}. See L{getClientAddress} instead.
@returns: the client IP address or L{None} if the request was submitted
over a transport where IP addresses do not make sense.
@rtype: L{str} or L{None}
"""
def getUser():
"""
Return the HTTP user sent with this request, if any.
If no user was supplied, return the empty string.
@returns: the HTTP user, if any
@rtype: L{str}
"""
def getPassword():
"""
Return the HTTP password sent with this request, if any.
If no password was supplied, return the empty string.
@returns: the HTTP password, if any
@rtype: L{str}
"""
def isSecure():
"""
Return True if this request is using a secure transport.
Normally this method returns True if this request's HTTPChannel
instance is using a transport that implements ISSLTransport.
This will also return True if setHost() has been called
with ssl=True.
@returns: True if this request is secure
@rtype: C{bool}
"""
def getSession(sessionInterface=None):
"""
Look up the session associated with this request or create a new one if
there is not one.
@return: The L{Session} instance identified by the session cookie in
the request, or the C{sessionInterface} component of that session
if C{sessionInterface} is specified.
"""
def URLPath():
"""
@return: A L{URLPath<twisted.python.urlpath.URLPath>} instance
which identifies the URL for which this request is.
"""
def prePathURL():
"""
At any time during resource traversal or resource rendering,
returns an absolute URL to the most nested resource which has
yet been reached.
@see: {twisted.web.server.Request.prepath}
@return: An absolute URL.
@type: L{bytes}
"""
def rememberRootURL():
"""
Remember the currently-processed part of the URL for later
recalling.
"""
def getRootURL():
"""
Get a previously-remembered URL.
@return: An absolute URL.
@type: L{bytes}
"""
# Methods for outgoing response
def finish():
"""
Indicate that the response to this request is complete.
"""
def write(data):
"""
Write some data to the body of the response to this request. Response
headers are written the first time this method is called, after which
new response headers may not be added.
@param data: Bytes of the response body.
@type data: L{bytes}
"""
def addCookie(k, v, expires=None, domain=None, path=None, max_age=None, comment=None, secure=None):
"""
Set an outgoing HTTP cookie.
In general, you should consider using sessions instead of cookies, see
L{twisted.web.server.Request.getSession} and the
L{twisted.web.server.Session} class for details.
"""
def setResponseCode(code, message=None):
"""
Set the HTTP response code.
@type code: L{int}
@type message: L{bytes}
"""
def setHeader(k, v):
"""
Set an HTTP response header. Overrides any previously set values for
this header.
@type k: L{bytes} or L{str}
@param k: The name of the header for which to set the value.
@type v: L{bytes} or L{str}
@param v: The value to set for the named header. A L{str} will be
UTF-8 encoded, which may not interoperable with other
implementations. Avoid passing non-ASCII characters if possible.
"""
def redirect(url):
"""
Utility function that does a redirect.
The request should have finish() called after this.
"""
def setLastModified(when):
"""
Set the C{Last-Modified} time for the response to this request.
If I am called more than once, I ignore attempts to set Last-Modified
earlier, only replacing the Last-Modified time if it is to a later
value.
If I am a conditional request, I may modify my response code to
L{NOT_MODIFIED<http.NOT_MODIFIED>} if appropriate for the time given.
@param when: The last time the resource being returned was modified, in
seconds since the epoch.
@type when: L{int}, L{long} or L{float}
@return: If I am a C{If-Modified-Since} conditional request and the time
given is not newer than the condition, I return
L{CACHED<http.CACHED>} to indicate that you should write no body.
Otherwise, I return a false value.
"""
def setETag(etag):
"""
Set an C{entity tag} for the outgoing response.
That's "entity tag" as in the HTTP/1.1 I{ETag} header, "used for
comparing two or more entities from the same requested resource."
If I am a conditional request, I may modify my response code to
L{NOT_MODIFIED<http.NOT_MODIFIED>} or
L{PRECONDITION_FAILED<http.PRECONDITION_FAILED>}, if appropriate for the
tag given.
@param etag: The entity tag for the resource being returned.
@type etag: L{str}
@return: If I am a C{If-None-Match} conditional request and the tag
matches one in the request, I return L{CACHED<http.CACHED>} to
indicate that you should write no body. Otherwise, I return a
false value.
"""
def setHost(host, port, ssl=0):
"""
Change the host and port the request thinks it's using.
This method is useful for working with reverse HTTP proxies (e.g. both
Squid and Apache's mod_proxy can do this), when the address the HTTP
client is using is different than the one we're listening on.
For example, Apache may be listening on https://www.example.com, and
then forwarding requests to http://localhost:8080, but we don't want
HTML produced by Twisted to say 'http://localhost:8080', they should
say 'https://www.example.com', so we do::
request.setHost('www.example.com', 443, ssl=1)
"""
class INonQueuedRequestFactory(Interface):
"""
A factory of L{IRequest} objects that does not take a ``queued`` parameter.
"""
def __call__(channel):
"""
Create an L{IRequest} that is operating on the given channel. There
must only be one L{IRequest} object processing at any given time on a
channel.
@param channel: A L{twisted.web.http.HTTPChannel} object.
@type channel: L{twisted.web.http.HTTPChannel}
@return: A request object.
@rtype: L{IRequest}
"""
class IAccessLogFormatter(Interface):
"""
An object which can represent an HTTP request as a line of text for
inclusion in an access log file.
"""
def __call__(timestamp, request):
"""
Generate a line for the access log.
@param timestamp: The time at which the request was completed in the
standard format for access logs.
@type timestamp: L{unicode}
@param request: The request object about which to log.
@type request: L{twisted.web.server.Request}
@return: One line describing the request without a trailing newline.
@rtype: L{unicode}
"""
class ICredentialFactory(Interface):
"""
A credential factory defines a way to generate a particular kind of
authentication challenge and a way to interpret the responses to these
challenges. It creates
L{ICredentials<twisted.cred.credentials.ICredentials>} providers from
responses. These objects will be used with L{twisted.cred} to authenticate
an authorize requests.
"""
scheme = Attribute(
"A L{str} giving the name of the authentication scheme with which "
"this factory is associated. For example, C{'basic'} or C{'digest'}.")
def getChallenge(request):
"""
Generate a new challenge to be sent to a client.
@type peer: L{twisted.web.http.Request}
@param peer: The request the response to which this challenge will be
included.
@rtype: L{dict}
@return: A mapping from L{str} challenge fields to associated L{str}
values.
"""
def decode(response, request):
"""
Create a credentials object from the given response.
@type response: L{str}
@param response: scheme specific response string
@type request: L{twisted.web.http.Request}
@param request: The request being processed (from which the response
was taken).
@raise twisted.cred.error.LoginFailed: If the response is invalid.
@rtype: L{twisted.cred.credentials.ICredentials} provider
@return: The credentials represented by the given response.
"""
class IBodyProducer(IPushProducer):
"""
Objects which provide L{IBodyProducer} write bytes to an object which
provides L{IConsumer<twisted.internet.interfaces.IConsumer>} by calling its
C{write} method repeatedly.
L{IBodyProducer} providers may start producing as soon as they have an
L{IConsumer<twisted.internet.interfaces.IConsumer>} provider. That is, they
should not wait for a C{resumeProducing} call to begin writing data.
L{IConsumer.unregisterProducer<twisted.internet.interfaces.IConsumer.unregisterProducer>}
must not be called. Instead, the
L{Deferred<twisted.internet.defer.Deferred>} returned from C{startProducing}
must be fired when all bytes have been written.
L{IConsumer.write<twisted.internet.interfaces.IConsumer.write>} may
synchronously invoke any of C{pauseProducing}, C{resumeProducing}, or
C{stopProducing}. These methods must be implemented with this in mind.
@since: 9.0
"""
# Despite the restrictions above and the additional requirements of
# stopProducing documented below, this interface still needs to be an
# IPushProducer subclass. Providers of it will be passed to IConsumer
# providers which only know about IPushProducer and IPullProducer, not
# about this interface. This interface needs to remain close enough to one
# of those interfaces for consumers to work with it.
length = Attribute(
"""
C{length} is a L{int} indicating how many bytes in total this
L{IBodyProducer} will write to the consumer or L{UNKNOWN_LENGTH}
if this is not known in advance.
""")
def startProducing(consumer):
"""
Start producing to the given
L{IConsumer<twisted.internet.interfaces.IConsumer>} provider.
@return: A L{Deferred<twisted.internet.defer.Deferred>} which stops
production of data when L{Deferred.cancel} is called, and which
fires with L{None} when all bytes have been produced or with a
L{Failure<twisted.python.failure.Failure>} if there is any problem
before all bytes have been produced.
"""
def stopProducing():
"""
In addition to the standard behavior of
L{IProducer.stopProducing<twisted.internet.interfaces.IProducer.stopProducing>}
(stop producing data), make sure the
L{Deferred<twisted.internet.defer.Deferred>} returned by
C{startProducing} is never fired.
"""
class IRenderable(Interface):
"""
An L{IRenderable} is an object that may be rendered by the
L{twisted.web.template} templating system.
"""
def lookupRenderMethod(name):
"""
Look up and return the render method associated with the given name.
@type name: L{str}
@param name: The value of a render directive encountered in the
document returned by a call to L{IRenderable.render}.
@return: A two-argument callable which will be invoked with the request
being responded to and the tag object on which the render directive
was encountered.
"""
def render(request):
"""
Get the document for this L{IRenderable}.
@type request: L{IRequest} provider or L{None}
@param request: The request in response to which this method is being
invoked.
@return: An object which can be flattened.
"""
class ITemplateLoader(Interface):
"""
A loader for templates; something usable as a value for
L{twisted.web.template.Element}'s C{loader} attribute.
"""
def load():
"""
Load a template suitable for rendering.
@return: a L{list} of L{list}s, L{unicode} objects, C{Element}s and
other L{IRenderable} providers.
"""
class IResponse(Interface):
"""
An object representing an HTTP response received from an HTTP server.
@since: 11.1
"""
version = Attribute(
"A three-tuple describing the protocol and protocol version "
"of the response. The first element is of type L{str}, the second "
"and third are of type L{int}. For example, C{(b'HTTP', 1, 1)}.")
code = Attribute("The HTTP status code of this response, as a L{int}.")
phrase = Attribute(
"The HTTP reason phrase of this response, as a L{str}.")
headers = Attribute("The HTTP response L{Headers} of this response.")
length = Attribute(
"The L{int} number of bytes expected to be in the body of this "
"response or L{UNKNOWN_LENGTH} if the server did not indicate how "
"many bytes to expect. For I{HEAD} responses, this will be 0; if "
"the response includes a I{Content-Length} header, it will be "
"available in C{headers}.")
request = Attribute(
"The L{IClientRequest} that resulted in this response.")
previousResponse = Attribute(
"The previous L{IResponse} from a redirect, or L{None} if there was no "
"previous response. This can be used to walk the response or request "
"history for redirections.")
def deliverBody(protocol):
"""
Register an L{IProtocol<twisted.internet.interfaces.IProtocol>} provider
to receive the response body.
The protocol will be connected to a transport which provides
L{IPushProducer}. The protocol's C{connectionLost} method will be
called with:
- ResponseDone, which indicates that all bytes from the response
have been successfully delivered.
- PotentialDataLoss, which indicates that it cannot be determined
if the entire response body has been delivered. This only occurs
when making requests to HTTP servers which do not set
I{Content-Length} or a I{Transfer-Encoding} in the response.
- ResponseFailed, which indicates that some bytes from the response
were lost. The C{reasons} attribute of the exception may provide
more specific indications as to why.
"""
def setPreviousResponse(response):
"""
Set the reference to the previous L{IResponse}.
The value of the previous response can be read via
L{IResponse.previousResponse}.
"""
class _IRequestEncoder(Interface):
"""
An object encoding data passed to L{IRequest.write}, for example for
compression purpose.
@since: 12.3
"""
def encode(data):
"""
Encode the data given and return the result.
@param data: The content to encode.
@type data: L{str}
@return: The encoded data.
@rtype: L{str}
"""
def finish():
"""
Callback called when the request is closing.
@return: If necessary, the pending data accumulated from previous
C{encode} calls.
@rtype: L{str}
"""
class _IRequestEncoderFactory(Interface):
"""
A factory for returing L{_IRequestEncoder} instances.
@since: 12.3
"""
def encoderForRequest(request):
"""
If applicable, returns a L{_IRequestEncoder} instance which will encode
the request.
"""
class IClientRequest(Interface):
"""
An object representing an HTTP request to make to an HTTP server.
@since: 13.1
"""
method = Attribute(
"The HTTP method for this request, as L{bytes}. For example: "
"C{b'GET'}, C{b'HEAD'}, C{b'POST'}, etc.")
absoluteURI = Attribute(
"The absolute URI of the requested resource, as L{bytes}; or L{None} "
"if the absolute URI cannot be determined.")
headers = Attribute(
"Headers to be sent to the server, as "
"a L{twisted.web.http_headers.Headers} instance.")
class IAgent(Interface):
"""
An agent makes HTTP requests.
The way in which requests are issued is left up to each implementation.
Some may issue them directly to the server indicated by the net location
portion of the request URL. Others may use a proxy specified by system
configuration.
Processing of responses is also left very widely specified. An
implementation may perform no special handling of responses, or it may
implement redirect following or content negotiation, it may implement a
cookie store or automatically respond to authentication challenges. It may
implement many other unforeseen behaviors as well.
It is also intended that L{IAgent} implementations be composable. An
implementation which provides cookie handling features should re-use an
implementation that provides connection pooling and this combination could
be used by an implementation which adds content negotiation functionality.
Some implementations will be completely self-contained, such as those which
actually perform the network operations to send and receive requests, but
most or all other implementations should implement a small number of new
features (perhaps one new feature) and delegate the rest of the
request/response machinery to another implementation.
This allows for great flexibility in the behavior an L{IAgent} will
provide. For example, an L{IAgent} with web browser-like behavior could be
obtained by combining a number of (hypothetical) implementations::
baseAgent = Agent(reactor)
redirect = BrowserLikeRedirectAgent(baseAgent, limit=10)
authenticate = AuthenticateAgent(
redirect, [diskStore.credentials, GtkAuthInterface()])
cookie = CookieAgent(authenticate, diskStore.cookie)
decode = ContentDecoderAgent(cookie, [(b"gzip", GzipDecoder())])
cache = CacheAgent(decode, diskStore.cache)
doSomeRequests(cache)
"""
def request(method, uri, headers=None, bodyProducer=None):
"""
Request the resource at the given location.
@param method: The request method to use, such as C{"GET"}, C{"HEAD"},
C{"PUT"}, C{"POST"}, etc.
@type method: L{bytes}
@param uri: The location of the resource to request. This should be an
absolute URI but some implementations may support relative URIs
(with absolute or relative paths). I{HTTP} and I{HTTPS} are the
schemes most likely to be supported but others may be as well.
@type uri: L{bytes}
@param headers: The headers to send with the request (or L{None} to
send no extra headers). An implementation may add its own headers
to this (for example for client identification or content
negotiation).
@type headers: L{Headers} or L{None}
@param bodyProducer: An object which can generate bytes to make up the
body of this request (for example, the properly encoded contents of
a file for a file upload). Or, L{None} if the request is to have
no body.
@type bodyProducer: L{IBodyProducer} provider
@return: A L{Deferred} that fires with an L{IResponse} provider when
the header of the response has been received (regardless of the
response status code) or with a L{Failure} if there is any problem
which prevents that response from being received (including
problems that prevent the request from being sent).
@rtype: L{Deferred}
"""
class IPolicyForHTTPS(Interface):
"""
An L{IPolicyForHTTPS} provides a policy for verifying the certificates of
HTTPS connections, in the form of a L{client connection creator
<twisted.internet.interfaces.IOpenSSLClientConnectionCreator>} per network
location.
@since: 14.0
"""
def creatorForNetloc(hostname, port):
"""
Create a L{client connection creator
<twisted.internet.interfaces.IOpenSSLClientConnectionCreator>}
appropriate for the given URL "netloc"; i.e. hostname and port number
pair.
@param hostname: The name of the requested remote host.
@type hostname: L{bytes}
@param port: The number of the requested remote port.
@type port: L{int}
@return: A client connection creator expressing the security
requirements for the given remote host.
@rtype: L{client connection creator
<twisted.internet.interfaces.IOpenSSLClientConnectionCreator>}
"""
class IAgentEndpointFactory(Interface):
"""
An L{IAgentEndpointFactory} provides a way of constructing an endpoint
used for outgoing Agent requests. This is useful in the case of needing to
proxy outgoing connections, or to otherwise vary the transport used.
@since: 15.0
"""
def endpointForURI(uri):
"""
Construct and return an L{IStreamClientEndpoint} for the outgoing
request's connection.
@param uri: The URI of the request.
@type uri: L{twisted.web.client.URI}
@return: An endpoint which will have its C{connect} method called to
issue the request.
@rtype: an L{IStreamClientEndpoint} provider
@raises twisted.internet.error.SchemeNotSupported: If the given
URI's scheme cannot be handled by this factory.
"""
UNKNOWN_LENGTH = u"twisted.web.iweb.UNKNOWN_LENGTH"
__all__ = [
"IUsernameDigestHash", "ICredentialFactory", "IRequest",
"IBodyProducer", "IRenderable", "IResponse", "_IRequestEncoder",
"_IRequestEncoderFactory", "IClientRequest",
"UNKNOWN_LENGTH"]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,303 @@
# -*- test-case-name: twisted.web.test.test_proxy -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Simplistic HTTP proxy support.
This comes in two main variants - the Proxy and the ReverseProxy.
When a Proxy is in use, a browser trying to connect to a server (say,
www.yahoo.com) will be intercepted by the Proxy, and the proxy will covertly
connect to the server, and return the result.
When a ReverseProxy is in use, the client connects directly to the ReverseProxy
(say, www.yahoo.com) which farms off the request to one of a pool of servers,
and returns the result.
Normally, a Proxy is used on the client end of an Internet connection, while a
ReverseProxy is used on the server end.
"""
from __future__ import absolute_import, division
from twisted.python.compat import urllib_parse, urlquote
from twisted.internet import reactor
from twisted.internet.protocol import ClientFactory
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET
from twisted.web.http import HTTPClient, Request, HTTPChannel, _QUEUED_SENTINEL
class ProxyClient(HTTPClient):
"""
Used by ProxyClientFactory to implement a simple web proxy.
@ivar _finished: A flag which indicates whether or not the original request
has been finished yet.
"""
_finished = False
def __init__(self, command, rest, version, headers, data, father):
self.father = father
self.command = command
self.rest = rest
if b"proxy-connection" in headers:
del headers[b"proxy-connection"]
headers[b"connection"] = b"close"
headers.pop(b'keep-alive', None)
self.headers = headers
self.data = data
def connectionMade(self):
self.sendCommand(self.command, self.rest)
for header, value in self.headers.items():
self.sendHeader(header, value)
self.endHeaders()
self.transport.write(self.data)
def handleStatus(self, version, code, message):
self.father.setResponseCode(int(code), message)
def handleHeader(self, key, value):
# t.web.server.Request sets default values for these headers in its
# 'process' method. When these headers are received from the remote
# server, they ought to override the defaults, rather than append to
# them.
if key.lower() in [b'server', b'date', b'content-type']:
self.father.responseHeaders.setRawHeaders(key, [value])
else:
self.father.responseHeaders.addRawHeader(key, value)
def handleResponsePart(self, buffer):
self.father.write(buffer)
def handleResponseEnd(self):
"""
Finish the original request, indicating that the response has been
completely written to it, and disconnect the outgoing transport.
"""
if not self._finished:
self._finished = True
self.father.finish()
self.transport.loseConnection()
class ProxyClientFactory(ClientFactory):
"""
Used by ProxyRequest to implement a simple web proxy.
"""
protocol = ProxyClient
def __init__(self, command, rest, version, headers, data, father):
self.father = father
self.command = command
self.rest = rest
self.headers = headers
self.data = data
self.version = version
def buildProtocol(self, addr):
return self.protocol(self.command, self.rest, self.version,
self.headers, self.data, self.father)
def clientConnectionFailed(self, connector, reason):
"""
Report a connection failure in a response to the incoming request as
an error.
"""
self.father.setResponseCode(501, b"Gateway error")
self.father.responseHeaders.addRawHeader(b"Content-Type", b"text/html")
self.father.write(b"<H1>Could not connect</H1>")
self.father.finish()
class ProxyRequest(Request):
"""
Used by Proxy to implement a simple web proxy.
@ivar reactor: the reactor used to create connections.
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
"""
protocols = {b'http': ProxyClientFactory}
ports = {b'http': 80}
def __init__(self, channel, queued=_QUEUED_SENTINEL, reactor=reactor):
Request.__init__(self, channel, queued)
self.reactor = reactor
def process(self):
parsed = urllib_parse.urlparse(self.uri)
protocol = parsed[0]
host = parsed[1].decode('ascii')
port = self.ports[protocol]
if ':' in host:
host, port = host.split(':')
port = int(port)
rest = urllib_parse.urlunparse((b'', b'') + parsed[2:])
if not rest:
rest = rest + b'/'
class_ = self.protocols[protocol]
headers = self.getAllHeaders().copy()
if b'host' not in headers:
headers[b'host'] = host.encode('ascii')
self.content.seek(0, 0)
s = self.content.read()
clientFactory = class_(self.method, rest, self.clientproto, headers,
s, self)
self.reactor.connectTCP(host, port, clientFactory)
class Proxy(HTTPChannel):
"""
This class implements a simple web proxy.
Since it inherits from L{twisted.web.http.HTTPChannel}, to use it you
should do something like this::
from twisted.web import http
f = http.HTTPFactory()
f.protocol = Proxy
Make the HTTPFactory a listener on a port as per usual, and you have
a fully-functioning web proxy!
"""
requestFactory = ProxyRequest
class ReverseProxyRequest(Request):
"""
Used by ReverseProxy to implement a simple reverse proxy.
@ivar proxyClientFactoryClass: a proxy client factory class, used to create
new connections.
@type proxyClientFactoryClass: L{ClientFactory}
@ivar reactor: the reactor used to create connections.
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
"""
proxyClientFactoryClass = ProxyClientFactory
def __init__(self, channel, queued=_QUEUED_SENTINEL, reactor=reactor):
Request.__init__(self, channel, queued)
self.reactor = reactor
def process(self):
"""
Handle this request by connecting to the proxied server and forwarding
it there, then forwarding the response back as the response to this
request.
"""
self.requestHeaders.setRawHeaders(b"host",
[self.factory.host.encode('ascii')])
clientFactory = self.proxyClientFactoryClass(
self.method, self.uri, self.clientproto, self.getAllHeaders(),
self.content.read(), self)
self.reactor.connectTCP(self.factory.host, self.factory.port,
clientFactory)
class ReverseProxy(HTTPChannel):
"""
Implements a simple reverse proxy.
For details of usage, see the file examples/reverse-proxy.py.
"""
requestFactory = ReverseProxyRequest
class ReverseProxyResource(Resource):
"""
Resource that renders the results gotten from another server
Put this resource in the tree to cause everything below it to be relayed
to a different server.
@ivar proxyClientFactoryClass: a proxy client factory class, used to create
new connections.
@type proxyClientFactoryClass: L{ClientFactory}
@ivar reactor: the reactor used to create connections.
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
"""
proxyClientFactoryClass = ProxyClientFactory
def __init__(self, host, port, path, reactor=reactor):
"""
@param host: the host of the web server to proxy.
@type host: C{str}
@param port: the port of the web server to proxy.
@type port: C{port}
@param path: the base path to fetch data from. Note that you shouldn't
put any trailing slashes in it, it will be added automatically in
request. For example, if you put B{/foo}, a request on B{/bar} will
be proxied to B{/foo/bar}. Any required encoding of special
characters (such as " " or "/") should have been done already.
@type path: C{bytes}
"""
Resource.__init__(self)
self.host = host
self.port = port
self.path = path
self.reactor = reactor
def getChild(self, path, request):
"""
Create and return a proxy resource with the same proxy configuration
as this one, except that its path also contains the segment given by
C{path} at the end.
"""
return ReverseProxyResource(
self.host, self.port, self.path + b'/' + urlquote(path, safe=b"").encode('utf-8'),
self.reactor)
def render(self, request):
"""
Render a request by forwarding it to the proxied server.
"""
# RFC 2616 tells us that we can omit the port if it's the default port,
# but we have to provide it otherwise
if self.port == 80:
host = self.host
else:
host = u"%s:%d" % (self.host, self.port)
request.requestHeaders.setRawHeaders(b"host", [host.encode('ascii')])
request.content.seek(0, 0)
qs = urllib_parse.urlparse(request.uri)[4]
if qs:
rest = self.path + b'?' + qs
else:
rest = self.path
clientFactory = self.proxyClientFactoryClass(
request.method, rest, request.clientproto,
request.getAllHeaders(), request.content.read(), request)
self.reactor.connectTCP(self.host, self.port, clientFactory)
return NOT_DONE_YET

View file

@ -0,0 +1,422 @@
# -*- test-case-name: twisted.web.test.test_web -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of the lowest-level Resource class.
"""
from __future__ import division, absolute_import
__all__ = [
'IResource', 'getChildForRequest',
'Resource', 'ErrorPage', 'NoResource', 'ForbiddenResource',
'EncodingResourceWrapper']
import warnings
from zope.interface import Attribute, Interface, implementer
from twisted.python.compat import nativeString, unicode
from twisted.python.reflect import prefixedMethodNames
from twisted.python.components import proxyForInterface
from twisted.web._responses import FORBIDDEN, NOT_FOUND
from twisted.web.error import UnsupportedMethod
class IResource(Interface):
"""
A web resource.
"""
isLeaf = Attribute(
"""
Signal if this IResource implementor is a "leaf node" or not. If True,
getChildWithDefault will not be called on this Resource.
""")
def getChildWithDefault(name, request):
"""
Return a child with the given name for the given request.
This is the external interface used by the Resource publishing
machinery. If implementing IResource without subclassing
Resource, it must be provided. However, if subclassing Resource,
getChild overridden instead.
@param name: A single path component from a requested URL. For example,
a request for I{http://example.com/foo/bar} will result in calls to
this method with C{b"foo"} and C{b"bar"} as values for this
argument.
@type name: C{bytes}
@param request: A representation of all of the information about the
request that is being made for this child.
@type request: L{twisted.web.server.Request}
"""
def putChild(path, child):
"""
Put a child IResource implementor at the given path.
@param path: A single path component, to be interpreted relative to the
path this resource is found at, at which to put the given child.
For example, if resource A can be found at I{http://example.com/foo}
then a call like C{A.putChild(b"bar", B)} will make resource B
available at I{http://example.com/foo/bar}.
@type path: C{bytes}
"""
def render(request):
"""
Render a request. This is called on the leaf resource for a request.
@return: Either C{server.NOT_DONE_YET} to indicate an asynchronous or a
C{bytes} instance to write as the response to the request. If
C{NOT_DONE_YET} is returned, at some point later (for example, in a
Deferred callback) call C{request.write(b"<html>")} to write data to
the request, and C{request.finish()} to send the data to the
browser.
@raise twisted.web.error.UnsupportedMethod: If the HTTP verb
requested is not supported by this resource.
"""
def getChildForRequest(resource, request):
"""
Traverse resource tree to find who will handle the request.
"""
while request.postpath and not resource.isLeaf:
pathElement = request.postpath.pop(0)
request.prepath.append(pathElement)
resource = resource.getChildWithDefault(pathElement, request)
return resource
@implementer(IResource)
class Resource:
"""
Define a web-accessible resource.
This serves 2 main purposes; one is to provide a standard representation
for what HTTP specification calls an 'entity', and the other is to provide
an abstract directory structure for URL retrieval.
"""
entityType = IResource
server = None
def __init__(self):
"""
Initialize.
"""
self.children = {}
isLeaf = 0
### Abstract Collection Interface
def listStaticNames(self):
return list(self.children.keys())
def listStaticEntities(self):
return list(self.children.items())
def listNames(self):
return list(self.listStaticNames()) + self.listDynamicNames()
def listEntities(self):
return list(self.listStaticEntities()) + self.listDynamicEntities()
def listDynamicNames(self):
return []
def listDynamicEntities(self, request=None):
return []
def getStaticEntity(self, name):
return self.children.get(name)
def getDynamicEntity(self, name, request):
if name not in self.children:
return self.getChild(name, request)
else:
return None
def delEntity(self, name):
del self.children[name]
def reallyPutEntity(self, name, entity):
self.children[name] = entity
# Concrete HTTP interface
def getChild(self, path, request):
"""
Retrieve a 'child' resource from me.
Implement this to create dynamic resource generation -- resources which
are always available may be registered with self.putChild().
This will not be called if the class-level variable 'isLeaf' is set in
your subclass; instead, the 'postpath' attribute of the request will be
left as a list of the remaining path elements.
For example, the URL /foo/bar/baz will normally be::
| site.resource.getChild('foo').getChild('bar').getChild('baz').
However, if the resource returned by 'bar' has isLeaf set to true, then
the getChild call will never be made on it.
Parameters and return value have the same meaning and requirements as
those defined by L{IResource.getChildWithDefault}.
"""
return NoResource("No such child resource.")
def getChildWithDefault(self, path, request):
"""
Retrieve a static or dynamically generated child resource from me.
First checks if a resource was added manually by putChild, and then
call getChild to check for dynamic resources. Only override if you want
to affect behaviour of all child lookups, rather than just dynamic
ones.
This will check to see if I have a pre-registered child resource of the
given name, and call getChild if I do not.
@see: L{IResource.getChildWithDefault}
"""
if path in self.children:
return self.children[path]
return self.getChild(path, request)
def getChildForRequest(self, request):
warnings.warn("Please use module level getChildForRequest.", DeprecationWarning, 2)
return getChildForRequest(self, request)
def putChild(self, path, child):
"""
Register a static child.
You almost certainly don't want '/' in your path. If you
intended to have the root of a folder, e.g. /foo/, you want
path to be ''.
@param path: A single path component.
@type path: L{bytes}
@param child: The child resource to register.
@type child: L{IResource}
@see: L{IResource.putChild}
"""
if not isinstance(path, bytes):
warnings.warn(
'Path segment must be bytes; '
'passing {0} has never worked, and '
'will raise an exception in the future.'
.format(type(path)),
category=DeprecationWarning,
stacklevel=2)
self.children[path] = child
child.server = self.server
def render(self, request):
"""
Render a given resource. See L{IResource}'s render method.
I delegate to methods of self with the form 'render_METHOD'
where METHOD is the HTTP that was used to make the
request. Examples: render_GET, render_HEAD, render_POST, and
so on. Generally you should implement those methods instead of
overriding this one.
render_METHOD methods are expected to return a byte string which will be
the rendered page, unless the return value is C{server.NOT_DONE_YET}, in
which case it is this class's responsibility to write the results using
C{request.write(data)} and then call C{request.finish()}.
Old code that overrides render() directly is likewise expected
to return a byte string or NOT_DONE_YET.
@see: L{IResource.render}
"""
m = getattr(self, 'render_' + nativeString(request.method), None)
if not m:
try:
allowedMethods = self.allowedMethods
except AttributeError:
allowedMethods = _computeAllowedMethods(self)
raise UnsupportedMethod(allowedMethods)
return m(request)
def render_HEAD(self, request):
"""
Default handling of HEAD method.
I just return self.render_GET(request). When method is HEAD,
the framework will handle this correctly.
"""
return self.render_GET(request)
def _computeAllowedMethods(resource):
"""
Compute the allowed methods on a C{Resource} based on defined render_FOO
methods. Used when raising C{UnsupportedMethod} but C{Resource} does
not define C{allowedMethods} attribute.
"""
allowedMethods = []
for name in prefixedMethodNames(resource.__class__, "render_"):
# Potentially there should be an API for encode('ascii') in this
# situation - an API for taking a Python native string (bytes on Python
# 2, text on Python 3) and returning a socket-compatible string type.
allowedMethods.append(name.encode('ascii'))
return allowedMethods
class ErrorPage(Resource):
"""
L{ErrorPage} is a resource which responds with a particular
(parameterized) status and a body consisting of HTML containing some
descriptive text. This is useful for rendering simple error pages.
@ivar template: A native string which will have a dictionary interpolated
into it to generate the response body. The dictionary has the following
keys:
- C{"code"}: The status code passed to L{ErrorPage.__init__}.
- C{"brief"}: The brief description passed to L{ErrorPage.__init__}.
- C{"detail"}: The detailed description passed to
L{ErrorPage.__init__}.
@ivar code: An integer status code which will be used for the response.
@type code: C{int}
@ivar brief: A short string which will be included in the response body as
the page title.
@type brief: C{str}
@ivar detail: A longer string which will be included in the response body.
@type detail: C{str}
"""
template = """
<html>
<head><title>%(code)s - %(brief)s</title></head>
<body>
<h1>%(brief)s</h1>
<p>%(detail)s</p>
</body>
</html>
"""
def __init__(self, status, brief, detail):
Resource.__init__(self)
self.code = status
self.brief = brief
self.detail = detail
def render(self, request):
request.setResponseCode(self.code)
request.setHeader(b"content-type", b"text/html; charset=utf-8")
interpolated = self.template % dict(
code=self.code, brief=self.brief, detail=self.detail)
if isinstance(interpolated, unicode):
return interpolated.encode('utf-8')
return interpolated
def getChild(self, chnam, request):
return self
class NoResource(ErrorPage):
"""
L{NoResource} is a specialization of L{ErrorPage} which returns the HTTP
response code I{NOT FOUND}.
"""
def __init__(self, message="Sorry. No luck finding that resource."):
ErrorPage.__init__(self, NOT_FOUND, "No Such Resource", message)
class ForbiddenResource(ErrorPage):
"""
L{ForbiddenResource} is a specialization of L{ErrorPage} which returns the
I{FORBIDDEN} HTTP response code.
"""
def __init__(self, message="Sorry, resource is forbidden."):
ErrorPage.__init__(self, FORBIDDEN, "Forbidden Resource", message)
class _IEncodingResource(Interface):
"""
A resource which knows about L{_IRequestEncoderFactory}.
@since: 12.3
"""
def getEncoder(request):
"""
Parse the request and return an encoder if applicable, using
L{_IRequestEncoderFactory.encoderForRequest}.
@return: A L{_IRequestEncoder}, or L{None}.
"""
@implementer(_IEncodingResource)
class EncodingResourceWrapper(proxyForInterface(IResource)):
"""
Wrap a L{IResource}, potentially applying an encoding to the response body
generated.
Note that the returned children resources won't be wrapped, so you have to
explicitly wrap them if you want the encoding to be applied.
@ivar encoders: A list of
L{_IRequestEncoderFactory<twisted.web.iweb._IRequestEncoderFactory>}
returning L{_IRequestEncoder<twisted.web.iweb._IRequestEncoder>} that
may transform the data passed to C{Request.write}. The list must be
sorted in order of priority: the first encoder factory handling the
request will prevent the others from doing the same.
@type encoders: C{list}.
@since: 12.3
"""
def __init__(self, original, encoders):
super(EncodingResourceWrapper, self).__init__(original)
self._encoders = encoders
def getEncoder(self, request):
"""
Browser the list of encoders looking for one applicable encoder.
"""
for encoderFactory in self._encoders:
encoder = encoderFactory.encoderForRequest(request)
if encoder is not None:
return encoder

View file

@ -0,0 +1,52 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
from twisted.web import resource
class RewriterResource(resource.Resource):
def __init__(self, orig, *rewriteRules):
resource.Resource.__init__(self)
self.resource = orig
self.rewriteRules = list(rewriteRules)
def _rewrite(self, request):
for rewriteRule in self.rewriteRules:
rewriteRule(request)
def getChild(self, path, request):
request.postpath.insert(0, path)
request.prepath.pop()
self._rewrite(request)
path = request.postpath.pop(0)
request.prepath.append(path)
return self.resource.getChildWithDefault(path, request)
def render(self, request):
self._rewrite(request)
return self.resource.render(request)
def tildeToUsers(request):
if request.postpath and request.postpath[0][:1]=='~':
request.postpath[:1] = ['users', request.postpath[0][1:]]
request.path = '/'+'/'.join(request.prepath+request.postpath)
def alias(aliasPath, sourcePath):
"""
I am not a very good aliaser. But I'm the best I can be. If I'm
aliasing to a Resource that generates links, and it uses any parts
of request.prepath to do so, the links will not be relative to the
aliased path, but rather to the aliased-to path. That I can't
alias static.File directory listings that nicely. However, I can
still be useful, as many resources will play nice.
"""
sourcePath = sourcePath.split('/')
aliasPath = aliasPath.split('/')
def rewriter(request):
if request.postpath[:len(aliasPath)] == aliasPath:
after = request.postpath[len(aliasPath):]
request.postpath = sourcePath + after
request.path = '/'+'/'.join(request.prepath+request.postpath)
return rewriter

View file

@ -0,0 +1,182 @@
# -*- test-case-name: twisted.web.test.test_script -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I contain PythonScript, which is a very simple python script resource.
"""
from __future__ import division, absolute_import
import os, traceback
from twisted import copyright
from twisted.python.filepath import _coerceToFilesystemEncoding
from twisted.python.compat import execfile, networkString, NativeStringIO, _PY3
from twisted.web import http, server, static, resource, util
rpyNoResource = """<p>You forgot to assign to the variable "resource" in your script. For example:</p>
<pre>
# MyCoolWebApp.rpy
import mygreatresource
resource = mygreatresource.MyGreatResource()
</pre>
"""
class AlreadyCached(Exception):
"""
This exception is raised when a path has already been cached.
"""
class CacheScanner:
def __init__(self, path, registry):
self.path = path
self.registry = registry
self.doCache = 0
def cache(self):
c = self.registry.getCachedPath(self.path)
if c is not None:
raise AlreadyCached(c)
self.recache()
def recache(self):
self.doCache = 1
noRsrc = resource.ErrorPage(500, "Whoops! Internal Error", rpyNoResource)
def ResourceScript(path, registry):
"""
I am a normal py file which must define a 'resource' global, which should
be an instance of (a subclass of) web.resource.Resource; it will be
renderred.
"""
cs = CacheScanner(path, registry)
glob = {'__file__': _coerceToFilesystemEncoding("", path),
'resource': noRsrc,
'registry': registry,
'cache': cs.cache,
'recache': cs.recache}
try:
execfile(path, glob, glob)
except AlreadyCached as ac:
return ac.args[0]
rsrc = glob['resource']
if cs.doCache and rsrc is not noRsrc:
registry.cachePath(path, rsrc)
return rsrc
def ResourceTemplate(path, registry):
from quixote import ptl_compile
glob = {'__file__': _coerceToFilesystemEncoding("", path),
'resource': resource.ErrorPage(500, "Whoops! Internal Error",
rpyNoResource),
'registry': registry}
with open(path) as f: # Not closed by quixote as of 2.9.1
e = ptl_compile.compile_template(f, path)
code = compile(e, "<source>", "exec")
eval(code, glob, glob)
return glob['resource']
class ResourceScriptWrapper(resource.Resource):
def __init__(self, path, registry=None):
resource.Resource.__init__(self)
self.path = path
self.registry = registry or static.Registry()
def render(self, request):
res = ResourceScript(self.path, self.registry)
return res.render(request)
def getChildWithDefault(self, path, request):
res = ResourceScript(self.path, self.registry)
return res.getChildWithDefault(path, request)
class ResourceScriptDirectory(resource.Resource):
"""
L{ResourceScriptDirectory} is a resource which serves scripts from a
filesystem directory. File children of a L{ResourceScriptDirectory} will
be served using L{ResourceScript}. Directory children will be served using
another L{ResourceScriptDirectory}.
@ivar path: A C{str} giving the filesystem path in which children will be
looked up.
@ivar registry: A L{static.Registry} instance which will be used to decide
how to interpret scripts found as children of this resource.
"""
def __init__(self, pathname, registry=None):
resource.Resource.__init__(self)
self.path = pathname
self.registry = registry or static.Registry()
def getChild(self, path, request):
fn = os.path.join(self.path, path)
if os.path.isdir(fn):
return ResourceScriptDirectory(fn, self.registry)
if os.path.exists(fn):
return ResourceScript(fn, self.registry)
return resource.NoResource()
def render(self, request):
return resource.NoResource().render(request)
class PythonScript(resource.Resource):
"""
I am an extremely simple dynamic resource; an embedded python script.
This will execute a file (usually of the extension '.epy') as Python code,
internal to the webserver.
"""
isLeaf = True
def __init__(self, filename, registry):
"""
Initialize me with a script name.
"""
self.filename = filename
self.registry = registry
def render(self, request):
"""
Render me to a web client.
Load my file, execute it in a special namespace (with 'request' and
'__file__' global vars) and finish the request. Output to the web-page
will NOT be handled with print - standard output goes to the log - but
with request.write.
"""
request.setHeader(b"x-powered-by", networkString("Twisted/%s" % copyright.version))
namespace = {'request': request,
'__file__': _coerceToFilesystemEncoding("", self.filename),
'registry': self.registry}
try:
execfile(self.filename, namespace, namespace)
except IOError as e:
if e.errno == 2: #file not found
request.setResponseCode(http.NOT_FOUND)
request.write(resource.NoResource("File not found.").render(request))
except:
io = NativeStringIO()
traceback.print_exc(file=io)
output = util._PRE(io.getvalue())
if _PY3:
output = output.encode("utf8")
request.write(output)
request.finish()
return server.NOT_DONE_YET

View file

@ -0,0 +1,911 @@
# -*- test-case-name: twisted.web.test.test_web -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This is a web server which integrates with the twisted.internet infrastructure.
@var NOT_DONE_YET: A token value which L{twisted.web.resource.IResource.render}
implementations can return to indicate that the application will later call
C{.write} and C{.finish} to complete the request, and that the HTTP
connection should be left open.
@type NOT_DONE_YET: Opaque; do not depend on any particular type for this
value.
"""
from __future__ import division, absolute_import
import copy
import os
import re
try:
from urllib import quote
except ImportError:
from urllib.parse import quote as _quote
def quote(string, *args, **kwargs):
return _quote(
string.decode('charmap'), *args, **kwargs).encode('charmap')
import zlib
from binascii import hexlify
from zope.interface import implementer
from twisted.python.compat import networkString, nativeString, intToBytes
from twisted.spread.pb import Copyable, ViewPoint
from twisted.internet import address, interfaces
from twisted.internet.error import AlreadyCalled, AlreadyCancelled
from twisted.web import iweb, http, util
from twisted.web.http import unquote
from twisted.python import reflect, failure, components
from twisted import copyright
from twisted.web import resource
from twisted.web.error import UnsupportedMethod
from incremental import Version
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.compat import escape
from twisted.logger import Logger
NOT_DONE_YET = 1
__all__ = [
'supportedMethods',
'Request',
'Session',
'Site',
'version',
'NOT_DONE_YET',
'GzipEncoderFactory'
]
# backwards compatibility
deprecatedModuleAttribute(
Version("Twisted", 12, 1, 0),
"Please use twisted.web.http.datetimeToString instead",
"twisted.web.server",
"date_time_string")
deprecatedModuleAttribute(
Version("Twisted", 12, 1, 0),
"Please use twisted.web.http.stringToDatetime instead",
"twisted.web.server",
"string_date_time")
date_time_string = http.datetimeToString
string_date_time = http.stringToDatetime
# Support for other methods may be implemented on a per-resource basis.
supportedMethods = (b'GET', b'HEAD', b'POST')
def _addressToTuple(addr):
if isinstance(addr, address.IPv4Address):
return ('INET', addr.host, addr.port)
elif isinstance(addr, address.UNIXAddress):
return ('UNIX', addr.name)
else:
return tuple(addr)
@implementer(iweb.IRequest)
class Request(Copyable, http.Request, components.Componentized):
"""
An HTTP request.
@ivar defaultContentType: A L{bytes} giving the default I{Content-Type}
value to send in responses if no other value is set. L{None} disables
the default.
@ivar _insecureSession: The L{Session} object representing state that will
be transmitted over plain-text HTTP.
@ivar _secureSession: The L{Session} object representing the state that
will be transmitted only over HTTPS.
"""
defaultContentType = b"text/html"
site = None
appRootURL = None
prepath = postpath = None
__pychecker__ = 'unusednames=issuer'
_inFakeHead = False
_encoder = None
_log = Logger()
def __init__(self, *args, **kw):
http.Request.__init__(self, *args, **kw)
components.Componentized.__init__(self)
def getStateToCopyFor(self, issuer):
x = self.__dict__.copy()
del x['transport']
# XXX refactor this attribute out; it's from protocol
# del x['server']
del x['channel']
del x['content']
del x['site']
self.content.seek(0, 0)
x['content_data'] = self.content.read()
x['remote'] = ViewPoint(issuer, self)
# Address objects aren't jellyable
x['host'] = _addressToTuple(x['host'])
x['client'] = _addressToTuple(x['client'])
# Header objects also aren't jellyable.
x['requestHeaders'] = list(x['requestHeaders'].getAllRawHeaders())
return x
# HTML generation helpers
def sibLink(self, name):
"""
Return the text that links to a sibling of the requested resource.
@param name: The sibling resource
@type name: C{bytes}
@return: A relative URL.
@rtype: C{bytes}
"""
if self.postpath:
return (len(self.postpath)*b"../") + name
else:
return name
def childLink(self, name):
"""
Return the text that links to a child of the requested resource.
@param name: The child resource
@type name: C{bytes}
@return: A relative URL.
@rtype: C{bytes}
"""
lpp = len(self.postpath)
if lpp > 1:
return ((lpp-1)*b"../") + name
elif lpp == 1:
return name
else: # lpp == 0
if len(self.prepath) and self.prepath[-1]:
return self.prepath[-1] + b'/' + name
else:
return name
def gotLength(self, length):
"""
Called when HTTP channel got length of content in this request.
This method is not intended for users.
@param length: The length of the request body, as indicated by the
request headers. L{None} if the request headers do not indicate a
length.
"""
try:
getContentFile = self.channel.site.getContentFile
except AttributeError:
http.Request.gotLength(self, length)
else:
self.content = getContentFile(length)
def process(self):
"""
Process a request.
Find the addressed resource in this request's L{Site},
and call L{self.render()<Request.render()>} with it.
@see: L{Site.getResourceFor()}
"""
# get site from channel
self.site = self.channel.site
# set various default headers
self.setHeader(b'server', version)
self.setHeader(b'date', http.datetimeToString())
# Resource Identification
self.prepath = []
self.postpath = list(map(unquote, self.path[1:].split(b'/')))
# Short-circuit for requests whose path is '*'.
if self.path == b'*':
self._handleStar()
return
try:
resrc = self.site.getResourceFor(self)
if resource._IEncodingResource.providedBy(resrc):
encoder = resrc.getEncoder(self)
if encoder is not None:
self._encoder = encoder
self.render(resrc)
except:
self.processingFailed(failure.Failure())
def write(self, data):
"""
Write data to the transport (if not responding to a HEAD request).
@param data: A string to write to the response.
@type data: L{bytes}
"""
if not self.startedWriting:
# Before doing the first write, check to see if a default
# Content-Type header should be supplied. We omit it on
# NOT_MODIFIED and NO_CONTENT responses. We also omit it if there
# is a Content-Length header set to 0, as empty bodies don't need
# a content-type.
needsCT = self.code not in (http.NOT_MODIFIED, http.NO_CONTENT)
contentType = self.responseHeaders.getRawHeaders(b'content-type')
contentLength = self.responseHeaders.getRawHeaders(
b'content-length'
)
contentLengthZero = contentLength and (contentLength[0] == b'0')
if (needsCT and contentType is None and
self.defaultContentType is not None and
not contentLengthZero
):
self.responseHeaders.setRawHeaders(
b'content-type', [self.defaultContentType])
# Only let the write happen if we're not generating a HEAD response by
# faking out the request method. Note, if we are doing that,
# startedWriting will never be true, and the above logic may run
# multiple times. It will only actually change the responseHeaders
# once though, so it's still okay.
if not self._inFakeHead:
if self._encoder:
data = self._encoder.encode(data)
http.Request.write(self, data)
def finish(self):
"""
Override C{http.Request.finish} for possible encoding.
"""
if self._encoder:
data = self._encoder.finish()
if data:
http.Request.write(self, data)
return http.Request.finish(self)
def render(self, resrc):
"""
Ask a resource to render itself.
If the resource does not support the requested method,
generate a C{NOT IMPLEMENTED} or C{NOT ALLOWED} response.
@param resrc: The resource to render.
@type resrc: L{twisted.web.resource.IResource}
@see: L{IResource.render()<twisted.web.resource.IResource.render()>}
"""
try:
body = resrc.render(self)
except UnsupportedMethod as e:
allowedMethods = e.allowedMethods
if (self.method == b"HEAD") and (b"GET" in allowedMethods):
# We must support HEAD (RFC 2616, 5.1.1). If the
# resource doesn't, fake it by giving the resource
# a 'GET' request and then return only the headers,
# not the body.
self._log.info(
"Using GET to fake a HEAD request for {resrc}",
resrc=resrc
)
self.method = b"GET"
self._inFakeHead = True
body = resrc.render(self)
if body is NOT_DONE_YET:
self._log.info(
"Tried to fake a HEAD request for {resrc}, but "
"it got away from me.", resrc=resrc
)
# Oh well, I guess we won't include the content length.
else:
self.setHeader(b'content-length', intToBytes(len(body)))
self._inFakeHead = False
self.method = b"HEAD"
self.write(b'')
self.finish()
return
if self.method in (supportedMethods):
# We MUST include an Allow header
# (RFC 2616, 10.4.6 and 14.7)
self.setHeader(b'Allow', b', '.join(allowedMethods))
s = ('''Your browser approached me (at %(URI)s) with'''
''' the method "%(method)s". I only allow'''
''' the method%(plural)s %(allowed)s here.''' % {
'URI': escape(nativeString(self.uri)),
'method': nativeString(self.method),
'plural': ((len(allowedMethods) > 1) and 's') or '',
'allowed': ', '.join(
[nativeString(x) for x in allowedMethods])
})
epage = resource.ErrorPage(http.NOT_ALLOWED,
"Method Not Allowed", s)
body = epage.render(self)
else:
epage = resource.ErrorPage(
http.NOT_IMPLEMENTED, "Huh?",
"I don't know how to treat a %s request." %
(escape(self.method.decode("charmap")),))
body = epage.render(self)
# end except UnsupportedMethod
if body is NOT_DONE_YET:
return
if not isinstance(body, bytes):
body = resource.ErrorPage(
http.INTERNAL_SERVER_ERROR,
"Request did not return bytes",
"Request: " + util._PRE(reflect.safe_repr(self)) + "<br />" +
"Resource: " + util._PRE(reflect.safe_repr(resrc)) + "<br />" +
"Value: " + util._PRE(reflect.safe_repr(body))).render(self)
if self.method == b"HEAD":
if len(body) > 0:
# This is a Bad Thing (RFC 2616, 9.4)
self._log.info(
"Warning: HEAD request {slf} for resource {resrc} is"
" returning a message body. I think I'll eat it.",
slf=self,
resrc=resrc
)
self.setHeader(b'content-length',
intToBytes(len(body)))
self.write(b'')
else:
self.setHeader(b'content-length',
intToBytes(len(body)))
self.write(body)
self.finish()
def processingFailed(self, reason):
"""
Finish this request with an indication that processing failed and
possibly display a traceback.
@param reason: Reason this request has failed.
@type reason: L{twisted.python.failure.Failure}
@return: The reason passed to this method.
@rtype: L{twisted.python.failure.Failure}
"""
self._log.failure('', failure=reason)
if self.site.displayTracebacks:
body = (b"<html><head><title>web.Server Traceback"
b" (most recent call last)</title></head>"
b"<body><b>web.Server Traceback"
b" (most recent call last):</b>\n\n" +
util.formatFailure(reason) +
b"\n\n</body></html>\n")
else:
body = (b"<html><head><title>Processing Failed"
b"</title></head><body>"
b"<b>Processing Failed</b></body></html>")
self.setResponseCode(http.INTERNAL_SERVER_ERROR)
self.setHeader(b'content-type', b"text/html")
self.setHeader(b'content-length', intToBytes(len(body)))
self.write(body)
self.finish()
return reason
def view_write(self, issuer, data):
"""Remote version of write; same interface.
"""
self.write(data)
def view_finish(self, issuer):
"""Remote version of finish; same interface.
"""
self.finish()
def view_addCookie(self, issuer, k, v, **kwargs):
"""Remote version of addCookie; same interface.
"""
self.addCookie(k, v, **kwargs)
def view_setHeader(self, issuer, k, v):
"""Remote version of setHeader; same interface.
"""
self.setHeader(k, v)
def view_setLastModified(self, issuer, when):
"""Remote version of setLastModified; same interface.
"""
self.setLastModified(when)
def view_setETag(self, issuer, tag):
"""Remote version of setETag; same interface.
"""
self.setETag(tag)
def view_setResponseCode(self, issuer, code, message=None):
"""
Remote version of setResponseCode; same interface.
"""
self.setResponseCode(code, message)
def view_registerProducer(self, issuer, producer, streaming):
"""Remote version of registerProducer; same interface.
(requires a remote producer.)
"""
self.registerProducer(_RemoteProducerWrapper(producer), streaming)
def view_unregisterProducer(self, issuer):
self.unregisterProducer()
### these calls remain local
_secureSession = None
_insecureSession = None
@property
def session(self):
"""
If a session has already been created or looked up with
L{Request.getSession}, this will return that object. (This will always
be the session that matches the security of the request; so if
C{forceNotSecure} is used on a secure request, this will not return
that session.)
@return: the session attribute
@rtype: L{Session} or L{None}
"""
if self.isSecure():
return self._secureSession
else:
return self._insecureSession
def getSession(self, sessionInterface=None, forceNotSecure=False):
"""
Check if there is a session cookie, and if not, create it.
By default, the cookie with be secure for HTTPS requests and not secure
for HTTP requests. If for some reason you need access to the insecure
cookie from a secure request you can set C{forceNotSecure = True}.
@param forceNotSecure: Should we retrieve a session that will be
transmitted over HTTP, even if this L{Request} was delivered over
HTTPS?
@type forceNotSecure: L{bool}
"""
# Make sure we aren't creating a secure session on a non-secure page
secure = self.isSecure() and not forceNotSecure
if not secure:
cookieString = b"TWISTED_SESSION"
sessionAttribute = "_insecureSession"
else:
cookieString = b"TWISTED_SECURE_SESSION"
sessionAttribute = "_secureSession"
session = getattr(self, sessionAttribute)
if session is not None:
# We have a previously created session.
try:
# Refresh the session, to keep it alive.
session.touch()
except (AlreadyCalled, AlreadyCancelled):
# Session has already expired.
session = None
if session is None:
# No session was created yet for this request.
cookiename = b"_".join([cookieString] + self.sitepath)
sessionCookie = self.getCookie(cookiename)
if sessionCookie:
try:
session = self.site.getSession(sessionCookie)
except KeyError:
pass
# if it still hasn't been set, fix it up.
if not session:
session = self.site.makeSession()
self.addCookie(cookiename, session.uid, path=b"/",
secure=secure)
setattr(self, sessionAttribute, session)
if sessionInterface:
return session.getComponent(sessionInterface)
return session
def _prePathURL(self, prepath):
port = self.getHost().port
if self.isSecure():
default = 443
else:
default = 80
if port == default:
hostport = ''
else:
hostport = ':%d' % port
prefix = networkString('http%s://%s%s/' % (
self.isSecure() and 's' or '',
nativeString(self.getRequestHostname()),
hostport))
path = b'/'.join([quote(segment, safe=b'') for segment in prepath])
return prefix + path
def prePathURL(self):
return self._prePathURL(self.prepath)
def URLPath(self):
from twisted.python import urlpath
return urlpath.URLPath.fromRequest(self)
def rememberRootURL(self):
"""
Remember the currently-processed part of the URL for later
recalling.
"""
url = self._prePathURL(self.prepath[:-1])
self.appRootURL = url
def getRootURL(self):
"""
Get a previously-remembered URL.
@return: An absolute URL.
@rtype: L{bytes}
"""
return self.appRootURL
def _handleStar(self):
"""
Handle receiving a request whose path is '*'.
RFC 7231 defines an OPTIONS * request as being something that a client
can send as a low-effort way to probe server capabilities or readiness.
Rather than bother the user with this, we simply fast-path it back to
an empty 200 OK. Any non-OPTIONS verb gets a 405 Method Not Allowed
telling the client they can only use OPTIONS.
"""
if self.method == b'OPTIONS':
self.setResponseCode(http.OK)
else:
self.setResponseCode(http.NOT_ALLOWED)
self.setHeader(b'Allow', b'OPTIONS')
# RFC 7231 says we MUST set content-length 0 when responding to this
# with no body.
self.setHeader(b'Content-Length', b'0')
self.finish()
@implementer(iweb._IRequestEncoderFactory)
class GzipEncoderFactory(object):
"""
@cvar compressLevel: The compression level used by the compressor, default
to 9 (highest).
@since: 12.3
"""
_gzipCheckRegex = re.compile(br'(:?^|[\s,])gzip(:?$|[\s,])')
compressLevel = 9
def encoderForRequest(self, request):
"""
Check the headers if the client accepts gzip encoding, and encodes the
request if so.
"""
acceptHeaders = b','.join(
request.requestHeaders.getRawHeaders(b'accept-encoding', []))
if self._gzipCheckRegex.search(acceptHeaders):
encoding = request.responseHeaders.getRawHeaders(
b'content-encoding')
if encoding:
encoding = b','.join(encoding + [b'gzip'])
else:
encoding = b'gzip'
request.responseHeaders.setRawHeaders(b'content-encoding',
[encoding])
return _GzipEncoder(self.compressLevel, request)
@implementer(iweb._IRequestEncoder)
class _GzipEncoder(object):
"""
An encoder which supports gzip.
@ivar _zlibCompressor: The zlib compressor instance used to compress the
stream.
@ivar _request: A reference to the originating request.
@since: 12.3
"""
_zlibCompressor = None
def __init__(self, compressLevel, request):
self._zlibCompressor = zlib.compressobj(
compressLevel, zlib.DEFLATED, 16 + zlib.MAX_WBITS)
self._request = request
def encode(self, data):
"""
Write to the request, automatically compressing data on the fly.
"""
if not self._request.startedWriting:
# Remove the content-length header, we can't honor it
# because we compress on the fly.
self._request.responseHeaders.removeHeader(b'content-length')
return self._zlibCompressor.compress(data)
def finish(self):
"""
Finish handling the request request, flushing any data from the zlib
buffer.
"""
remain = self._zlibCompressor.flush()
self._zlibCompressor = None
return remain
class _RemoteProducerWrapper:
def __init__(self, remote):
self.resumeProducing = remote.remoteMethod("resumeProducing")
self.pauseProducing = remote.remoteMethod("pauseProducing")
self.stopProducing = remote.remoteMethod("stopProducing")
class Session(components.Componentized):
"""
A user's session with a system.
This utility class contains no functionality, but is used to
represent a session.
@ivar uid: A unique identifier for the session.
@type uid: L{bytes}
@ivar _reactor: An object providing L{IReactorTime} to use for scheduling
expiration.
@ivar sessionTimeout: timeout of a session, in seconds.
"""
sessionTimeout = 900
_expireCall = None
def __init__(self, site, uid, reactor=None):
"""
Initialize a session with a unique ID for that session.
"""
components.Componentized.__init__(self)
if reactor is None:
from twisted.internet import reactor
self._reactor = reactor
self.site = site
self.uid = uid
self.expireCallbacks = []
self.touch()
self.sessionNamespaces = {}
def startCheckingExpiration(self):
"""
Start expiration tracking.
@return: L{None}
"""
self._expireCall = self._reactor.callLater(
self.sessionTimeout, self.expire)
def notifyOnExpire(self, callback):
"""
Call this callback when the session expires or logs out.
"""
self.expireCallbacks.append(callback)
def expire(self):
"""
Expire/logout of the session.
"""
del self.site.sessions[self.uid]
for c in self.expireCallbacks:
c()
self.expireCallbacks = []
if self._expireCall and self._expireCall.active():
self._expireCall.cancel()
# Break reference cycle.
self._expireCall = None
def touch(self):
"""
Notify session modification.
"""
self.lastModified = self._reactor.seconds()
if self._expireCall is not None:
self._expireCall.reset(self.sessionTimeout)
version = networkString("TwistedWeb/%s" % (copyright.version,))
@implementer(interfaces.IProtocolNegotiationFactory)
class Site(http.HTTPFactory):
"""
A web site: manage log, sessions, and resources.
@ivar counter: increment value used for generating unique sessions ID.
@ivar requestFactory: A factory which is called with (channel)
and creates L{Request} instances. Default to L{Request}.
@ivar displayTracebacks: If set, unhandled exceptions raised during
rendering are returned to the client as HTML. Default to C{False}.
@ivar sessionFactory: factory for sessions objects. Default to L{Session}.
@ivar sessionCheckTime: Deprecated. See L{Session.sessionTimeout} instead.
"""
counter = 0
requestFactory = Request
displayTracebacks = False
sessionFactory = Session
sessionCheckTime = 1800
_entropy = os.urandom
def __init__(self, resource, requestFactory=None, *args, **kwargs):
"""
@param resource: The root of the resource hierarchy. All request
traversal for requests received by this factory will begin at this
resource.
@type resource: L{IResource} provider
@param requestFactory: Overwrite for default requestFactory.
@type requestFactory: C{callable} or C{class}.
@see: L{twisted.web.http.HTTPFactory.__init__}
"""
http.HTTPFactory.__init__(self, *args, **kwargs)
self.sessions = {}
self.resource = resource
if requestFactory is not None:
self.requestFactory = requestFactory
def _openLogFile(self, path):
from twisted.python import logfile
return logfile.LogFile(os.path.basename(path), os.path.dirname(path))
def __getstate__(self):
d = self.__dict__.copy()
d['sessions'] = {}
return d
def _mkuid(self):
"""
(internal) Generate an opaque, unique ID for a user's session.
"""
self.counter = self.counter + 1
return hexlify(self._entropy(32))
def makeSession(self):
"""
Generate a new Session instance, and store it for future reference.
"""
uid = self._mkuid()
session = self.sessions[uid] = self.sessionFactory(self, uid)
session.startCheckingExpiration()
return session
def getSession(self, uid):
"""
Get a previously generated session.
@param uid: Unique ID of the session.
@type uid: L{bytes}.
@raise: L{KeyError} if the session is not found.
"""
return self.sessions[uid]
def buildProtocol(self, addr):
"""
Generate a channel attached to this site.
"""
channel = http.HTTPFactory.buildProtocol(self, addr)
channel.requestFactory = self.requestFactory
channel.site = self
return channel
isLeaf = 0
def render(self, request):
"""
Redirect because a Site is always a directory.
"""
request.redirect(request.prePathURL() + b'/')
request.finish()
def getChildWithDefault(self, pathEl, request):
"""
Emulate a resource's getChild method.
"""
request.site = self
return self.resource.getChildWithDefault(pathEl, request)
def getResourceFor(self, request):
"""
Get a resource for a request.
This iterates through the resource hierarchy, calling
getChildWithDefault on each resource it finds for a path element,
stopping when it hits an element where isLeaf is true.
"""
request.site = self
# Sitepath is used to determine cookie names between distributed
# servers and disconnected sites.
request.sitepath = copy.copy(request.prepath)
return resource.getChildForRequest(self.resource, request)
# IProtocolNegotiationFactory
def acceptableProtocols(self):
"""
Protocols this server can speak.
"""
baseProtocols = [b'http/1.1']
if http.H2_ENABLED:
baseProtocols.insert(0, b'h2')
return baseProtocols

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,637 @@
# -*- test-case-name: twisted.web.test.test_xml -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
*S*mall, *U*ncomplicated *X*ML.
This is a very simple implementation of XML/HTML as a network
protocol. It is not at all clever. Its main features are that it
does not:
- support namespaces
- mung mnemonic entity references
- validate
- perform *any* external actions (such as fetching URLs or writing files)
under *any* circumstances
- has lots and lots of horrible hacks for supporting broken HTML (as an
option, they're not on by default).
"""
from __future__ import print_function
from twisted.internet.protocol import Protocol
from twisted.python.compat import unicode
from twisted.python.reflect import prefixedMethodNames
# Elements of the three-tuples in the state table.
BEGIN_HANDLER = 0
DO_HANDLER = 1
END_HANDLER = 2
identChars = '.-_:'
lenientIdentChars = identChars + ';+#/%~'
def nop(*args, **kw):
"Do nothing."
def unionlist(*args):
l = []
for x in args:
l.extend(x)
d = dict([(x, 1) for x in l])
return d.keys()
def zipfndict(*args, **kw):
default = kw.get('default', nop)
d = {}
for key in unionlist(*[fndict.keys() for fndict in args]):
d[key] = tuple([x.get(key, default) for x in args])
return d
def prefixedMethodClassDict(clazz, prefix):
return dict([(name, getattr(clazz, prefix + name)) for name in prefixedMethodNames(clazz, prefix)])
def prefixedMethodObjDict(obj, prefix):
return dict([(name, getattr(obj, prefix + name)) for name in prefixedMethodNames(obj.__class__, prefix)])
class ParseError(Exception):
def __init__(self, filename, line, col, message):
self.filename = filename
self.line = line
self.col = col
self.message = message
def __str__(self):
return "%s:%s:%s: %s" % (self.filename, self.line, self.col,
self.message)
class XMLParser(Protocol):
state = None
encodings = None
filename = "<xml />"
beExtremelyLenient = 0
_prepend = None
# _leadingBodyData will sometimes be set before switching to the
# 'bodydata' state, when we "accidentally" read a byte of bodydata
# in a different state.
_leadingBodyData = None
def connectionMade(self):
self.lineno = 1
self.colno = 0
self.encodings = []
def saveMark(self):
'''Get the line number and column of the last character parsed'''
# This gets replaced during dataReceived, restored afterwards
return (self.lineno, self.colno)
def _parseError(self, message):
raise ParseError(*((self.filename,)+self.saveMark()+(message,)))
def _buildStateTable(self):
'''Return a dictionary of begin, do, end state function tuples'''
# _buildStateTable leaves something to be desired but it does what it
# does.. probably slowly, so I'm doing some evil caching so it doesn't
# get called more than once per class.
stateTable = getattr(self.__class__, '__stateTable', None)
if stateTable is None:
stateTable = self.__class__.__stateTable = zipfndict(
*[prefixedMethodObjDict(self, prefix)
for prefix in ('begin_', 'do_', 'end_')])
return stateTable
def _decode(self, data):
if 'UTF-16' in self.encodings or 'UCS-2' in self.encodings:
assert not len(data) & 1, 'UTF-16 must come in pairs for now'
if self._prepend:
data = self._prepend + data
for encoding in self.encodings:
data = unicode(data, encoding)
return data
def maybeBodyData(self):
if self.endtag:
return 'bodydata'
# Get ready for fun! We're going to allow
# <script>if (foo < bar)</script> to work!
# We do this by making everything between <script> and
# </script> a Text
# BUT <script src="foo"> will be special-cased to do regular,
# lenient behavior, because those may not have </script>
# -radix
if (self.tagName == 'script' and 'src' not in self.tagAttributes):
# we do this ourselves rather than having begin_waitforendscript
# because that can get called multiple times and we don't want
# bodydata to get reset other than the first time.
self.begin_bodydata(None)
return 'waitforendscript'
return 'bodydata'
def dataReceived(self, data):
stateTable = self._buildStateTable()
if not self.state:
# all UTF-16 starts with this string
if data.startswith((b'\xff\xfe', b'\xfe\xff')):
self._prepend = data[0:2]
self.encodings.append('UTF-16')
data = data[2:]
self.state = 'begin'
if self.encodings:
data = self._decode(data)
else:
data = data.decode("utf-8")
# bring state, lineno, colno into local scope
lineno, colno = self.lineno, self.colno
curState = self.state
# replace saveMark with a nested scope function
_saveMark = self.saveMark
def saveMark():
return (lineno, colno)
self.saveMark = saveMark
# fetch functions from the stateTable
beginFn, doFn, endFn = stateTable[curState]
try:
for byte in data:
# do newline stuff
if byte == u'\n':
lineno += 1
colno = 0
else:
colno += 1
newState = doFn(byte)
if newState is not None and newState != curState:
# this is the endFn from the previous state
endFn()
curState = newState
beginFn, doFn, endFn = stateTable[curState]
beginFn(byte)
finally:
self.saveMark = _saveMark
self.lineno, self.colno = lineno, colno
# state doesn't make sense if there's an exception..
self.state = curState
def connectionLost(self, reason):
"""
End the last state we were in.
"""
stateTable = self._buildStateTable()
stateTable[self.state][END_HANDLER]()
# state methods
def do_begin(self, byte):
if byte.isspace():
return
if byte != '<':
if self.beExtremelyLenient:
self._leadingBodyData = byte
return 'bodydata'
self._parseError("First char of document [%r] wasn't <" % (byte,))
return 'tagstart'
def begin_comment(self, byte):
self.commentbuf = ''
def do_comment(self, byte):
self.commentbuf += byte
if self.commentbuf.endswith('-->'):
self.gotComment(self.commentbuf[:-3])
return 'bodydata'
def begin_tagstart(self, byte):
self.tagName = '' # name of the tag
self.tagAttributes = {} # attributes of the tag
self.termtag = 0 # is the tag self-terminating
self.endtag = 0
def do_tagstart(self, byte):
if byte.isalnum() or byte in identChars:
self.tagName += byte
if self.tagName == '!--':
return 'comment'
elif byte.isspace():
if self.tagName:
if self.endtag:
# properly strict thing to do here is probably to only
# accept whitespace
return 'waitforgt'
return 'attrs'
else:
self._parseError("Whitespace before tag-name")
elif byte == '>':
if self.endtag:
self.gotTagEnd(self.tagName)
return 'bodydata'
else:
self.gotTagStart(self.tagName, {})
return (not self.beExtremelyLenient) and 'bodydata' or self.maybeBodyData()
elif byte == '/':
if self.tagName:
return 'afterslash'
else:
self.endtag = 1
elif byte in '!?':
if self.tagName:
if not self.beExtremelyLenient:
self._parseError("Invalid character in tag-name")
else:
self.tagName += byte
self.termtag = 1
elif byte == '[':
if self.tagName == '!':
return 'expectcdata'
else:
self._parseError("Invalid '[' in tag-name")
else:
if self.beExtremelyLenient:
self.bodydata = '<'
return 'unentity'
self._parseError('Invalid tag character: %r'% byte)
def begin_unentity(self, byte):
self.bodydata += byte
def do_unentity(self, byte):
self.bodydata += byte
return 'bodydata'
def end_unentity(self):
self.gotText(self.bodydata)
def begin_expectcdata(self, byte):
self.cdatabuf = byte
def do_expectcdata(self, byte):
self.cdatabuf += byte
cdb = self.cdatabuf
cd = '[CDATA['
if len(cd) > len(cdb):
if cd.startswith(cdb):
return
elif self.beExtremelyLenient:
## WHAT THE CRAP!? MSWord9 generates HTML that includes these
## bizarre <![if !foo]> <![endif]> chunks, so I've gotta ignore
## 'em as best I can. this should really be a separate parse
## state but I don't even have any idea what these _are_.
return 'waitforgt'
else:
self._parseError("Mal-formed CDATA header")
if cd == cdb:
self.cdatabuf = ''
return 'cdata'
self._parseError("Mal-formed CDATA header")
def do_cdata(self, byte):
self.cdatabuf += byte
if self.cdatabuf.endswith("]]>"):
self.cdatabuf = self.cdatabuf[:-3]
return 'bodydata'
def end_cdata(self):
self.gotCData(self.cdatabuf)
self.cdatabuf = ''
def do_attrs(self, byte):
if byte.isalnum() or byte in identChars:
# XXX FIXME really handle !DOCTYPE at some point
if self.tagName == '!DOCTYPE':
return 'doctype'
if self.tagName[0] in '!?':
return 'waitforgt'
return 'attrname'
elif byte.isspace():
return
elif byte == '>':
self.gotTagStart(self.tagName, self.tagAttributes)
return (not self.beExtremelyLenient) and 'bodydata' or self.maybeBodyData()
elif byte == '/':
return 'afterslash'
elif self.beExtremelyLenient:
# discard and move on? Only case I've seen of this so far was:
# <foo bar="baz"">
return
self._parseError("Unexpected character: %r" % byte)
def begin_doctype(self, byte):
self.doctype = byte
def do_doctype(self, byte):
if byte == '>':
return 'bodydata'
self.doctype += byte
def end_doctype(self):
self.gotDoctype(self.doctype)
self.doctype = None
def do_waitforgt(self, byte):
if byte == '>':
if self.endtag or not self.beExtremelyLenient:
return 'bodydata'
return self.maybeBodyData()
def begin_attrname(self, byte):
self.attrname = byte
self._attrname_termtag = 0
def do_attrname(self, byte):
if byte.isalnum() or byte in identChars:
self.attrname += byte
return
elif byte == '=':
return 'beforeattrval'
elif byte.isspace():
return 'beforeeq'
elif self.beExtremelyLenient:
if byte in '"\'':
return 'attrval'
if byte in lenientIdentChars or byte.isalnum():
self.attrname += byte
return
if byte == '/':
self._attrname_termtag = 1
return
if byte == '>':
self.attrval = 'True'
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
if self._attrname_termtag:
self.gotTagEnd(self.tagName)
return 'bodydata'
return self.maybeBodyData()
# something is really broken. let's leave this attribute where it
# is and move on to the next thing
return
self._parseError("Invalid attribute name: %r %r" % (self.attrname, byte))
def do_beforeattrval(self, byte):
if byte in '"\'':
return 'attrval'
elif byte.isspace():
return
elif self.beExtremelyLenient:
if byte in lenientIdentChars or byte.isalnum():
return 'messyattr'
if byte == '>':
self.attrval = 'True'
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
return self.maybeBodyData()
if byte == '\\':
# I saw this in actual HTML once:
# <font size=\"3\"><sup>SM</sup></font>
return
self._parseError("Invalid initial attribute value: %r; Attribute values must be quoted." % byte)
attrname = ''
attrval = ''
def begin_beforeeq(self,byte):
self._beforeeq_termtag = 0
def do_beforeeq(self, byte):
if byte == '=':
return 'beforeattrval'
elif byte.isspace():
return
elif self.beExtremelyLenient:
if byte.isalnum() or byte in identChars:
self.attrval = 'True'
self.tagAttributes[self.attrname] = self.attrval
return 'attrname'
elif byte == '>':
self.attrval = 'True'
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
if self._beforeeq_termtag:
self.gotTagEnd(self.tagName)
return 'bodydata'
return self.maybeBodyData()
elif byte == '/':
self._beforeeq_termtag = 1
return
self._parseError("Invalid attribute")
def begin_attrval(self, byte):
self.quotetype = byte
self.attrval = ''
def do_attrval(self, byte):
if byte == self.quotetype:
return 'attrs'
self.attrval += byte
def end_attrval(self):
self.tagAttributes[self.attrname] = self.attrval
self.attrname = self.attrval = ''
def begin_messyattr(self, byte):
self.attrval = byte
def do_messyattr(self, byte):
if byte.isspace():
return 'attrs'
elif byte == '>':
endTag = 0
if self.attrval.endswith('/'):
endTag = 1
self.attrval = self.attrval[:-1]
self.tagAttributes[self.attrname] = self.attrval
self.gotTagStart(self.tagName, self.tagAttributes)
if endTag:
self.gotTagEnd(self.tagName)
return 'bodydata'
return self.maybeBodyData()
else:
self.attrval += byte
def end_messyattr(self):
if self.attrval:
self.tagAttributes[self.attrname] = self.attrval
def begin_afterslash(self, byte):
self._after_slash_closed = 0
def do_afterslash(self, byte):
# this state is only after a self-terminating slash, e.g. <foo/>
if self._after_slash_closed:
self._parseError("Mal-formed")#XXX When does this happen??
if byte != '>':
if self.beExtremelyLenient:
return
else:
self._parseError("No data allowed after '/'")
self._after_slash_closed = 1
self.gotTagStart(self.tagName, self.tagAttributes)
self.gotTagEnd(self.tagName)
# don't need maybeBodyData here because there better not be
# any javascript code after a <script/>... we'll see :(
return 'bodydata'
def begin_bodydata(self, byte):
if self._leadingBodyData:
self.bodydata = self._leadingBodyData
del self._leadingBodyData
else:
self.bodydata = ''
def do_bodydata(self, byte):
if byte == '<':
return 'tagstart'
if byte == '&':
return 'entityref'
self.bodydata += byte
def end_bodydata(self):
self.gotText(self.bodydata)
self.bodydata = ''
def do_waitforendscript(self, byte):
if byte == '<':
return 'waitscriptendtag'
self.bodydata += byte
def begin_waitscriptendtag(self, byte):
self.temptagdata = ''
self.tagName = ''
self.endtag = 0
def do_waitscriptendtag(self, byte):
# 1 enforce / as first byte read
# 2 enforce following bytes to be subset of "script" until
# tagName == "script"
# 2a when that happens, gotText(self.bodydata) and gotTagEnd(self.tagName)
# 3 spaces can happen anywhere, they're ignored
# e.g. < / script >
# 4 anything else causes all data I've read to be moved to the
# bodydata, and switch back to waitforendscript state
# If it turns out this _isn't_ a </script>, we need to
# remember all the data we've been through so we can append it
# to bodydata
self.temptagdata += byte
# 1
if byte == '/':
self.endtag = True
elif not self.endtag:
self.bodydata += "<" + self.temptagdata
return 'waitforendscript'
# 2
elif byte.isalnum() or byte in identChars:
self.tagName += byte
if not 'script'.startswith(self.tagName):
self.bodydata += "<" + self.temptagdata
return 'waitforendscript'
elif self.tagName == 'script':
self.gotText(self.bodydata)
self.gotTagEnd(self.tagName)
return 'waitforgt'
# 3
elif byte.isspace():
return 'waitscriptendtag'
# 4
else:
self.bodydata += "<" + self.temptagdata
return 'waitforendscript'
def begin_entityref(self, byte):
self.erefbuf = ''
self.erefextra = '' # extra bit for lenient mode
def do_entityref(self, byte):
if byte.isspace() or byte == "<":
if self.beExtremelyLenient:
# '&foo' probably was '&amp;foo'
if self.erefbuf and self.erefbuf != "amp":
self.erefextra = self.erefbuf
self.erefbuf = "amp"
if byte == "<":
return "tagstart"
else:
self.erefextra += byte
return 'spacebodydata'
self._parseError("Bad entity reference")
elif byte != ';':
self.erefbuf += byte
else:
return 'bodydata'
def end_entityref(self):
self.gotEntityReference(self.erefbuf)
# hacky support for space after & in entityref in beExtremelyLenient
# state should only happen in that case
def begin_spacebodydata(self, byte):
self.bodydata = self.erefextra
self.erefextra = None
do_spacebodydata = do_bodydata
end_spacebodydata = end_bodydata
# Sorta SAX-ish API
def gotTagStart(self, name, attributes):
'''Encountered an opening tag.
Default behaviour is to print.'''
print('begin', name, attributes)
def gotText(self, data):
'''Encountered text
Default behaviour is to print.'''
print('text:', repr(data))
def gotEntityReference(self, entityRef):
'''Encountered mnemonic entity reference
Default behaviour is to print.'''
print('entityRef: &%s;' % entityRef)
def gotComment(self, comment):
'''Encountered comment.
Default behaviour is to ignore.'''
pass
def gotCData(self, cdata):
'''Encountered CDATA
Default behaviour is to call the gotText method'''
self.gotText(cdata)
def gotDoctype(self, doctype):
"""Encountered DOCTYPE
This is really grotty: it basically just gives you everything between
'<!DOCTYPE' and '>' as an argument.
"""
print('!DOCTYPE', repr(doctype))
def gotTagEnd(self, name):
'''Encountered closing tag
Default behaviour is to print.'''
print('end', name)

View file

@ -0,0 +1,316 @@
# -*- test-case-name: twisted.web.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Support for creating a service which runs a web server.
"""
from __future__ import absolute_import, division
import os
import warnings
import incremental
from twisted.application import service, strports
from twisted.internet import interfaces, reactor
from twisted.python import usage, reflect, threadpool, deprecate
from twisted.spread import pb
from twisted.web import distrib
from twisted.web import resource, server, static, script, demo, wsgi
from twisted.web import twcgi
class Options(usage.Options):
"""
Define the options accepted by the I{twistd web} plugin.
"""
synopsis = "[web options]"
optParameters = [["logfile", "l", None,
"Path to web CLF (Combined Log Format) log file."],
["certificate", "c", "server.pem",
"(DEPRECATED: use --listen) "
"SSL certificate to use for HTTPS. "],
["privkey", "k", "server.pem",
"(DEPRECATED: use --listen) "
"SSL certificate to use for HTTPS."],
]
optFlags = [
["notracebacks", "n", (
"(DEPRECATED: Tracebacks are disabled by default. "
"See --enable-tracebacks to turn them on.")],
["display-tracebacks", "", (
"Show uncaught exceptions during rendering tracebacks to "
"the client. WARNING: This may be a security risk and "
"expose private data!")],
]
optFlags.append([
"personal", "",
"Instead of generating a webserver, generate a "
"ResourcePublisher which listens on the port given by "
"--listen, or ~/%s " % (distrib.UserDirectory.userSocketName,) +
"if --listen is not specified."])
compData = usage.Completions(
optActions={"logfile" : usage.CompleteFiles("*.log"),
"certificate" : usage.CompleteFiles("*.pem"),
"privkey" : usage.CompleteFiles("*.pem")}
)
longdesc = """\
This starts a webserver. If you specify no arguments, it will be a
demo webserver that has the Test class from twisted.web.demo in it."""
def __init__(self):
usage.Options.__init__(self)
self['indexes'] = []
self['root'] = None
self['extraHeaders'] = []
self['ports'] = []
self['port'] = self['https'] = None
def opt_port(self, port):
"""
(DEPRECATED: use --listen)
Strports description of port to start the server on
"""
msg = deprecate.getDeprecationWarningString(
self.opt_port, incremental.Version('Twisted', 18, 4, 0))
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
self['port'] = port
opt_p = opt_port
def opt_https(self, port):
"""
(DEPRECATED: use --listen)
Port to listen on for Secure HTTP.
"""
msg = deprecate.getDeprecationWarningString(
self.opt_https, incremental.Version('Twisted', 18, 4, 0))
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
self['https'] = port
def opt_listen(self, port):
"""
Add an strports description of port to start the server on.
[default: tcp:8080]
"""
self['ports'].append(port)
def opt_index(self, indexName):
"""
Add the name of a file used to check for directory indexes.
[default: index, index.html]
"""
self['indexes'].append(indexName)
opt_i = opt_index
def opt_user(self):
"""
Makes a server with ~/public_html and ~/.twistd-web-pb support for
users.
"""
self['root'] = distrib.UserDirectory()
opt_u = opt_user
def opt_path(self, path):
"""
<path> is either a specific file or a directory to be set as the root
of the web server. Use this if you have a directory full of HTML, cgi,
epy, or rpy files or any other files that you want to be served up raw.
"""
self['root'] = static.File(os.path.abspath(path))
self['root'].processors = {
'.epy': script.PythonScript,
'.rpy': script.ResourceScript,
}
self['root'].processors['.cgi'] = twcgi.CGIScript
def opt_processor(self, proc):
"""
`ext=class' where `class' is added as a Processor for files ending
with `ext'.
"""
if not isinstance(self['root'], static.File):
raise usage.UsageError(
"You can only use --processor after --path.")
ext, klass = proc.split('=', 1)
self['root'].processors[ext] = reflect.namedClass(klass)
def opt_class(self, className):
"""
Create a Resource subclass with a zero-argument constructor.
"""
classObj = reflect.namedClass(className)
self['root'] = classObj()
def opt_resource_script(self, name):
"""
An .rpy file to be used as the root resource of the webserver.
"""
self['root'] = script.ResourceScriptWrapper(name)
def opt_wsgi(self, name):
"""
The FQPN of a WSGI application object to serve as the root resource of
the webserver.
"""
try:
application = reflect.namedAny(name)
except (AttributeError, ValueError):
raise usage.UsageError("No such WSGI application: %r" % (name,))
pool = threadpool.ThreadPool()
reactor.callWhenRunning(pool.start)
reactor.addSystemEventTrigger('after', 'shutdown', pool.stop)
self['root'] = wsgi.WSGIResource(reactor, pool, application)
def opt_mime_type(self, defaultType):
"""
Specify the default mime-type for static files.
"""
if not isinstance(self['root'], static.File):
raise usage.UsageError(
"You can only use --mime_type after --path.")
self['root'].defaultType = defaultType
opt_m = opt_mime_type
def opt_allow_ignore_ext(self):
"""
Specify whether or not a request for 'foo' should return 'foo.ext'
"""
if not isinstance(self['root'], static.File):
raise usage.UsageError("You can only use --allow_ignore_ext "
"after --path.")
self['root'].ignoreExt('*')
def opt_ignore_ext(self, ext):
"""
Specify an extension to ignore. These will be processed in order.
"""
if not isinstance(self['root'], static.File):
raise usage.UsageError("You can only use --ignore_ext "
"after --path.")
self['root'].ignoreExt(ext)
def opt_add_header(self, header):
"""
Specify an additional header to be included in all responses. Specified
as "HeaderName: HeaderValue".
"""
name, value = header.split(':', 1)
self['extraHeaders'].append((name.strip(), value.strip()))
def postOptions(self):
"""
Set up conditional defaults and check for dependencies.
If SSL is not available but an HTTPS server was configured, raise a
L{UsageError} indicating that this is not possible.
If no server port was supplied, select a default appropriate for the
other options supplied.
"""
if self['port'] is not None:
self['ports'].append(self['port'])
if self['https'] is not None:
try:
reflect.namedModule('OpenSSL.SSL')
except ImportError:
raise usage.UsageError("SSL support not installed")
sslStrport = 'ssl:port={}:privateKey={}:certKey={}'.format(
self['https'],
self['privkey'],
self['certificate'],
)
self['ports'].append(sslStrport)
if len(self['ports']) == 0:
if self['personal']:
path = os.path.expanduser(
os.path.join('~', distrib.UserDirectory.userSocketName))
self['ports'].append('unix:' + path)
else:
self['ports'].append('tcp:8080')
def makePersonalServerFactory(site):
"""
Create and return a factory which will respond to I{distrib} requests
against the given site.
@type site: L{twisted.web.server.Site}
@rtype: L{twisted.internet.protocol.Factory}
"""
return pb.PBServerFactory(distrib.ResourcePublisher(site))
class _AddHeadersResource(resource.Resource):
def __init__(self, originalResource, headers):
self._originalResource = originalResource
self._headers = headers
def getChildWithDefault(self, name, request):
for k, v in self._headers:
request.responseHeaders.addRawHeader(k, v)
return self._originalResource.getChildWithDefault(name, request)
def makeService(config):
s = service.MultiService()
if config['root']:
root = config['root']
if config['indexes']:
config['root'].indexNames = config['indexes']
else:
# This really ought to be web.Admin or something
root = demo.Test()
if isinstance(root, static.File):
root.registry.setComponent(interfaces.IServiceCollection, s)
if config['extraHeaders']:
root = _AddHeadersResource(root, config['extraHeaders'])
if config['logfile']:
site = server.Site(root, logPath=config['logfile'])
else:
site = server.Site(root)
if config["display-tracebacks"]:
site.displayTracebacks = True
# Deprecate --notracebacks/-n
if config["notracebacks"]:
msg = deprecate._getDeprecationWarningString(
"--notracebacks", incremental.Version('Twisted', 19, 7, 0))
warnings.warn(msg, category=DeprecationWarning, stacklevel=2)
if config['personal']:
site = makePersonalServerFactory(site)
for port in config['ports']:
svc = strports.service(port, site)
svc.setServiceParent(s)
return s

View file

@ -0,0 +1,575 @@
# -*- test-case-name: twisted.web.test.test_template -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTML rendering for twisted.web.
@var VALID_HTML_TAG_NAMES: A list of recognized HTML tag names, used by the
L{tag} object.
@var TEMPLATE_NAMESPACE: The XML namespace used to identify attributes and
elements used by the templating system, which should be removed from the
final output document.
@var tags: A convenience object which can produce L{Tag} objects on demand via
attribute access. For example: C{tags.div} is equivalent to C{Tag("div")}.
Tags not specified in L{VALID_HTML_TAG_NAMES} will result in an
L{AttributeError}.
"""
from __future__ import division, absolute_import
__all__ = [
'TEMPLATE_NAMESPACE', 'VALID_HTML_TAG_NAMES', 'Element', 'TagLoader',
'XMLString', 'XMLFile', 'renderer', 'flatten', 'flattenString', 'tags',
'Comment', 'CDATA', 'Tag', 'slot', 'CharRef', 'renderElement'
]
import warnings
from collections import OrderedDict
from zope.interface import implementer
from xml.sax import make_parser, handler
from twisted.python.compat import NativeStringIO, items
from twisted.python.filepath import FilePath
from twisted.web._stan import Tag, slot, Comment, CDATA, CharRef
from twisted.web.iweb import ITemplateLoader
from twisted.logger import Logger
TEMPLATE_NAMESPACE = 'http://twistedmatrix.com/ns/twisted.web.template/0.1'
# Go read the definition of NOT_DONE_YET. For lulz. This is totally
# equivalent. And this turns out to be necessary, because trying to import
# NOT_DONE_YET in this module causes a circular import which we cannot escape
# from. From which we cannot escape. Etc. glyph is okay with this solution for
# now, and so am I, as long as this comment stays to explain to future
# maintainers what it means. ~ C.
#
# See http://twistedmatrix.com/trac/ticket/5557 for progress on fixing this.
NOT_DONE_YET = 1
_moduleLog = Logger()
class _NSContext(object):
"""
A mapping from XML namespaces onto their prefixes in the document.
"""
def __init__(self, parent=None):
"""
Pull out the parent's namespaces, if there's no parent then default to
XML.
"""
self.parent = parent
if parent is not None:
self.nss = OrderedDict(parent.nss)
else:
self.nss = {'http://www.w3.org/XML/1998/namespace':'xml'}
def get(self, k, d=None):
"""
Get a prefix for a namespace.
@param d: The default prefix value.
"""
return self.nss.get(k, d)
def __setitem__(self, k, v):
"""
Proxy through to setting the prefix for the namespace.
"""
self.nss.__setitem__(k, v)
def __getitem__(self, k):
"""
Proxy through to getting the prefix for the namespace.
"""
return self.nss.__getitem__(k)
class _ToStan(handler.ContentHandler, handler.EntityResolver):
"""
A SAX parser which converts an XML document to the Twisted STAN
Document Object Model.
"""
def __init__(self, sourceFilename):
"""
@param sourceFilename: the filename to load the XML out of.
"""
self.sourceFilename = sourceFilename
self.prefixMap = _NSContext()
self.inCDATA = False
def setDocumentLocator(self, locator):
"""
Set the document locator, which knows about line and character numbers.
"""
self.locator = locator
def startDocument(self):
"""
Initialise the document.
"""
self.document = []
self.current = self.document
self.stack = []
self.xmlnsAttrs = []
def endDocument(self):
"""
Document ended.
"""
def processingInstruction(self, target, data):
"""
Processing instructions are ignored.
"""
def startPrefixMapping(self, prefix, uri):
"""
Set up the prefix mapping, which maps fully qualified namespace URIs
onto namespace prefixes.
This gets called before startElementNS whenever an C{xmlns} attribute
is seen.
"""
self.prefixMap = _NSContext(self.prefixMap)
self.prefixMap[uri] = prefix
# Ignore the template namespace; we'll replace those during parsing.
if uri == TEMPLATE_NAMESPACE:
return
# Add to a list that will be applied once we have the element.
if prefix is None:
self.xmlnsAttrs.append(('xmlns',uri))
else:
self.xmlnsAttrs.append(('xmlns:%s'%prefix,uri))
def endPrefixMapping(self, prefix):
"""
"Pops the stack" on the prefix mapping.
Gets called after endElementNS.
"""
self.prefixMap = self.prefixMap.parent
def startElementNS(self, namespaceAndName, qname, attrs):
"""
Gets called when we encounter a new xmlns attribute.
@param namespaceAndName: a (namespace, name) tuple, where name
determines which type of action to take, if the namespace matches
L{TEMPLATE_NAMESPACE}.
@param qname: ignored.
@param attrs: attributes on the element being started.
"""
filename = self.sourceFilename
lineNumber = self.locator.getLineNumber()
columnNumber = self.locator.getColumnNumber()
ns, name = namespaceAndName
if ns == TEMPLATE_NAMESPACE:
if name == 'transparent':
name = ''
elif name == 'slot':
try:
# Try to get the default value for the slot
default = attrs[(None, 'default')]
except KeyError:
# If there wasn't one, then use None to indicate no
# default.
default = None
el = slot(
attrs[(None, 'name')], default=default,
filename=filename, lineNumber=lineNumber,
columnNumber=columnNumber)
self.stack.append(el)
self.current.append(el)
self.current = el.children
return
render = None
attrs = OrderedDict(attrs)
for k, v in items(attrs):
attrNS, justTheName = k
if attrNS != TEMPLATE_NAMESPACE:
continue
if justTheName == 'render':
render = v
del attrs[k]
# nonTemplateAttrs is a dictionary mapping attributes that are *not* in
# TEMPLATE_NAMESPACE to their values. Those in TEMPLATE_NAMESPACE were
# just removed from 'attrs' in the loop immediately above. The key in
# nonTemplateAttrs is either simply the attribute name (if it was not
# specified as having a namespace in the template) or prefix:name,
# preserving the xml namespace prefix given in the document.
nonTemplateAttrs = OrderedDict()
for (attrNs, attrName), v in items(attrs):
nsPrefix = self.prefixMap.get(attrNs)
if nsPrefix is None:
attrKey = attrName
else:
attrKey = '%s:%s' % (nsPrefix, attrName)
nonTemplateAttrs[attrKey] = v
if ns == TEMPLATE_NAMESPACE and name == 'attr':
if not self.stack:
# TODO: define a better exception for this?
raise AssertionError(
'<{%s}attr> as top-level element' % (TEMPLATE_NAMESPACE,))
if 'name' not in nonTemplateAttrs:
# TODO: same here
raise AssertionError(
'<{%s}attr> requires a name attribute' % (TEMPLATE_NAMESPACE,))
el = Tag('', render=render, filename=filename,
lineNumber=lineNumber, columnNumber=columnNumber)
self.stack[-1].attributes[nonTemplateAttrs['name']] = el
self.stack.append(el)
self.current = el.children
return
# Apply any xmlns attributes
if self.xmlnsAttrs:
nonTemplateAttrs.update(OrderedDict(self.xmlnsAttrs))
self.xmlnsAttrs = []
# Add the prefix that was used in the parsed template for non-template
# namespaces (which will not be consumed anyway).
if ns != TEMPLATE_NAMESPACE and ns is not None:
prefix = self.prefixMap[ns]
if prefix is not None:
name = '%s:%s' % (self.prefixMap[ns],name)
el = Tag(
name, attributes=OrderedDict(nonTemplateAttrs), render=render,
filename=filename, lineNumber=lineNumber,
columnNumber=columnNumber)
self.stack.append(el)
self.current.append(el)
self.current = el.children
def characters(self, ch):
"""
Called when we receive some characters. CDATA characters get passed
through as is.
@type ch: C{string}
"""
if self.inCDATA:
self.stack[-1].append(ch)
return
self.current.append(ch)
def endElementNS(self, name, qname):
"""
A namespace tag is closed. Pop the stack, if there's anything left in
it, otherwise return to the document's namespace.
"""
self.stack.pop()
if self.stack:
self.current = self.stack[-1].children
else:
self.current = self.document
def startDTD(self, name, publicId, systemId):
"""
DTDs are ignored.
"""
def endDTD(self, *args):
"""
DTDs are ignored.
"""
def startCDATA(self):
"""
We're starting to be in a CDATA element, make a note of this.
"""
self.inCDATA = True
self.stack.append([])
def endCDATA(self):
"""
We're no longer in a CDATA element. Collect up the characters we've
parsed and put them in a new CDATA object.
"""
self.inCDATA = False
comment = ''.join(self.stack.pop())
self.current.append(CDATA(comment))
def comment(self, content):
"""
Add an XML comment which we've encountered.
"""
self.current.append(Comment(content))
def _flatsaxParse(fl):
"""
Perform a SAX parse of an XML document with the _ToStan class.
@param fl: The XML document to be parsed.
@type fl: A file object or filename.
@return: a C{list} of Stan objects.
"""
parser = make_parser()
parser.setFeature(handler.feature_validation, 0)
parser.setFeature(handler.feature_namespaces, 1)
parser.setFeature(handler.feature_external_ges, 0)
parser.setFeature(handler.feature_external_pes, 0)
s = _ToStan(getattr(fl, "name", None))
parser.setContentHandler(s)
parser.setEntityResolver(s)
parser.setProperty(handler.property_lexical_handler, s)
parser.parse(fl)
return s.document
@implementer(ITemplateLoader)
class TagLoader(object):
"""
An L{ITemplateLoader} that loads existing L{IRenderable} providers.
@ivar tag: The object which will be loaded.
@type tag: An L{IRenderable} provider.
"""
def __init__(self, tag):
"""
@param tag: The object which will be loaded.
@type tag: An L{IRenderable} provider.
"""
self.tag = tag
def load(self):
return [self.tag]
@implementer(ITemplateLoader)
class XMLString(object):
"""
An L{ITemplateLoader} that loads and parses XML from a string.
@ivar _loadedTemplate: The loaded document.
@type _loadedTemplate: a C{list} of Stan objects.
"""
def __init__(self, s):
"""
Run the parser on a L{NativeStringIO} copy of the string.
@param s: The string from which to load the XML.
@type s: C{str}, or a UTF-8 encoded L{bytes}.
"""
if not isinstance(s, str):
s = s.decode('utf8')
self._loadedTemplate = _flatsaxParse(NativeStringIO(s))
def load(self):
"""
Return the document.
@return: the loaded document.
@rtype: a C{list} of Stan objects.
"""
return self._loadedTemplate
@implementer(ITemplateLoader)
class XMLFile(object):
"""
An L{ITemplateLoader} that loads and parses XML from a file.
@ivar _loadedTemplate: The loaded document, or L{None}, if not loaded.
@type _loadedTemplate: a C{list} of Stan objects, or L{None}.
@ivar _path: The L{FilePath}, file object, or filename that is being
loaded from.
"""
def __init__(self, path):
"""
Run the parser on a file.
@param path: The file from which to load the XML.
@type path: L{FilePath}
"""
if not isinstance(path, FilePath):
warnings.warn(
"Passing filenames or file objects to XMLFile is deprecated "
"since Twisted 12.1. Pass a FilePath instead.",
category=DeprecationWarning, stacklevel=2)
self._loadedTemplate = None
self._path = path
def _loadDoc(self):
"""
Read and parse the XML.
@return: the loaded document.
@rtype: a C{list} of Stan objects.
"""
if not isinstance(self._path, FilePath):
return _flatsaxParse(self._path)
else:
with self._path.open('r') as f:
return _flatsaxParse(f)
def __repr__(self):
return '<XMLFile of %r>' % (self._path,)
def load(self):
"""
Return the document, first loading it if necessary.
@return: the loaded document.
@rtype: a C{list} of Stan objects.
"""
if self._loadedTemplate is None:
self._loadedTemplate = self._loadDoc()
return self._loadedTemplate
# Last updated October 2011, using W3Schools as a reference. Link:
# http://www.w3schools.com/html5/html5_reference.asp
# Note that <xmp> is explicitly omitted; its semantics do not work with
# t.w.template and it is officially deprecated.
VALID_HTML_TAG_NAMES = set([
'a', 'abbr', 'acronym', 'address', 'applet', 'area', 'article', 'aside',
'audio', 'b', 'base', 'basefont', 'bdi', 'bdo', 'big', 'blockquote',
'body', 'br', 'button', 'canvas', 'caption', 'center', 'cite', 'code',
'col', 'colgroup', 'command', 'datalist', 'dd', 'del', 'details', 'dfn',
'dir', 'div', 'dl', 'dt', 'em', 'embed', 'fieldset', 'figcaption',
'figure', 'font', 'footer', 'form', 'frame', 'frameset', 'h1', 'h2', 'h3',
'h4', 'h5', 'h6', 'head', 'header', 'hgroup', 'hr', 'html', 'i', 'iframe',
'img', 'input', 'ins', 'isindex', 'keygen', 'kbd', 'label', 'legend',
'li', 'link', 'map', 'mark', 'menu', 'meta', 'meter', 'nav', 'noframes',
'noscript', 'object', 'ol', 'optgroup', 'option', 'output', 'p', 'param',
'pre', 'progress', 'q', 'rp', 'rt', 'ruby', 's', 'samp', 'script',
'section', 'select', 'small', 'source', 'span', 'strike', 'strong',
'style', 'sub', 'summary', 'sup', 'table', 'tbody', 'td', 'textarea',
'tfoot', 'th', 'thead', 'time', 'title', 'tr', 'tt', 'u', 'ul', 'var',
'video', 'wbr',
])
class _TagFactory(object):
"""
A factory for L{Tag} objects; the implementation of the L{tags} object.
This allows for the syntactic convenience of C{from twisted.web.html import
tags; tags.a(href="linked-page.html")}, where 'a' can be basically any HTML
tag.
The class is not exposed publicly because you only ever need one of these,
and we already made it for you.
@see: L{tags}
"""
def __getattr__(self, tagName):
if tagName == 'transparent':
return Tag('')
# allow for E.del as E.del_
tagName = tagName.rstrip('_')
if tagName not in VALID_HTML_TAG_NAMES:
raise AttributeError('unknown tag %r' % (tagName,))
return Tag(tagName)
tags = _TagFactory()
def renderElement(request, element,
doctype=b'<!DOCTYPE html>', _failElement=None):
"""
Render an element or other C{IRenderable}.
@param request: The C{Request} being rendered to.
@param element: An C{IRenderable} which will be rendered.
@param doctype: A C{bytes} which will be written as the first line of
the request, or L{None} to disable writing of a doctype. The C{string}
should not include a trailing newline and will default to the HTML5
doctype C{'<!DOCTYPE html>'}.
@returns: NOT_DONE_YET
@since: 12.1
"""
if doctype is not None:
request.write(doctype)
request.write(b'\n')
if _failElement is None:
_failElement = twisted.web.util.FailureElement
d = flatten(request, element, request.write)
def eb(failure):
_moduleLog.failure(
"An error occurred while rendering the response.",
failure=failure
)
if request.site.displayTracebacks:
return flatten(request, _failElement(failure),
request.write).encode('utf8')
else:
request.write(
(b'<div style="font-size:800%;'
b'background-color:#FFF;'
b'color:#F00'
b'">An error occurred while rendering the response.</div>'))
d.addErrback(eb)
d.addBoth(lambda _: request.finish())
return NOT_DONE_YET
from twisted.web._element import Element, renderer
from twisted.web._flatten import flatten, flattenString
import twisted.web.util

View file

@ -0,0 +1,7 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web}.
"""

View file

@ -0,0 +1,103 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
General helpers for L{twisted.web} unit tests.
"""
from __future__ import division, absolute_import
from twisted.internet.defer import succeed
from twisted.web import server
from twisted.trial.unittest import TestCase
from twisted.python.failure import Failure
from twisted.web._flatten import flattenString
from twisted.web.error import FlattenerError
def _render(resource, request):
result = resource.render(request)
if isinstance(result, bytes):
request.write(result)
request.finish()
return succeed(None)
elif result is server.NOT_DONE_YET:
if request.finished:
return succeed(None)
else:
return request.notifyFinish()
else:
raise ValueError("Unexpected return value: %r" % (result,))
class FlattenTestCase(TestCase):
"""
A test case that assists with testing L{twisted.web._flatten}.
"""
def assertFlattensTo(self, root, target):
"""
Assert that a root element, when flattened, is equal to a string.
"""
d = flattenString(None, root)
d.addCallback(lambda s: self.assertEqual(s, target))
return d
def assertFlattensImmediately(self, root, target):
"""
Assert that a root element, when flattened, is equal to a string, and
performs no asynchronus Deferred anything.
This version is more convenient in tests which wish to make multiple
assertions about flattening, since it can be called multiple times
without having to add multiple callbacks.
@return: the result of rendering L{root}, which should be equivalent to
L{target}.
@rtype: L{bytes}
"""
results = []
it = self.assertFlattensTo(root, target)
it.addBoth(results.append)
# Do our best to clean it up if something goes wrong.
self.addCleanup(it.cancel)
if not results:
self.fail("Rendering did not complete immediately.")
result = results[0]
if isinstance(result, Failure):
result.raiseException()
return results[0]
def assertFlatteningRaises(self, root, exn):
"""
Assert flattening a root element raises a particular exception.
"""
d = self.assertFailure(self.assertFlattensTo(root, b''), FlattenerError)
d.addCallback(lambda exc: self.assertIsInstance(exc._exception, exn))
return d
def assertIsFilesystemTemporary(case, fileObj):
"""
Assert that C{fileObj} is a temporary file on the filesystem.
@param case: A C{TestCase} instance to use to make the assertion.
@raise: C{case.failureException} if C{fileObj} is not a temporary file on
the filesystem.
"""
# The tempfile API used to create content returns an instance of a
# different type depending on what platform we're running on. The point
# here is to verify that the request body is in a file that's on the
# filesystem. Having a fileno method that returns an int is a somewhat
# close approximation of this. -exarkun
case.assertIsInstance(fileObj.fileno(), int)
__all__ = ["_render", "FlattenTestCase", "assertIsFilesystemTemporary"]

View file

@ -0,0 +1,168 @@
"""
Helpers for URI and method injection tests.
@see: U{CVE-2019-12387}
"""
import string
UNPRINTABLE_ASCII = (
frozenset(range(0, 128)) -
frozenset(bytearray(string.printable, 'ascii'))
)
NONASCII = frozenset(range(128, 256))
class MethodInjectionTestsMixin(object):
"""
A mixin that runs HTTP method injection tests. Define
L{MethodInjectionTestsMixin.attemptRequestWithMaliciousMethod} in
a L{twisted.trial.unittest.SynchronousTestCase} subclass to test
how HTTP client code behaves when presented with malicious HTTP
methods.
@see: U{CVE-2019-12387}
"""
def attemptRequestWithMaliciousMethod(self, method):
"""
Attempt to send a request with the given method. This should
synchronously raise a L{ValueError} if either is invalid.
@param method: the method (e.g. C{GET\x00})
@param uri: the URI
@type method:
"""
raise NotImplementedError()
def test_methodWithCLRFRejected(self):
"""
Issuing a request with a method that contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
method = b"GET\r\nX-Injected-Header: value"
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
def test_methodWithUnprintableASCIIRejected(self):
"""
Issuing a request with a method that contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
method = b"GET%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
def test_methodWithNonASCIIRejected(self):
"""
Issuing a request with a method that contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
method = b"GET%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousMethod(method)
self.assertRegex(str(cm.exception), "^Invalid method")
class URIInjectionTestsMixin(object):
"""
A mixin that runs HTTP URI injection tests. Define
L{MethodInjectionTestsMixin.attemptRequestWithMaliciousURI} in a
L{twisted.trial.unittest.SynchronousTestCase} subclass to test how
HTTP client code behaves when presented with malicious HTTP
URIs.
"""
def attemptRequestWithMaliciousURI(self, method):
"""
Attempt to send a request with the given URI. This should
synchronously raise a L{ValueError} if either is invalid.
@param uri: the URI.
@type method:
"""
raise NotImplementedError()
def test_hostWithCRLFRejected(self):
"""
Issuing a request with a URI whose host contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
uri = b"http://twisted\r\n.invalid/path"
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_hostWithWithUnprintableASCIIRejected(self):
"""
Issuing a request with a URI whose host contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
uri = b"http://twisted%s.invalid/OK" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_hostWithNonASCIIRejected(self):
"""
Issuing a request with a URI whose host contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
uri = b"http://twisted%s.invalid/OK" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithCRLFRejected(self):
"""
Issuing a request with a URI whose path contains a carriage
return and line feed fails with a L{ValueError}.
"""
with self.assertRaises(ValueError) as cm:
uri = b"http://twisted.invalid/\r\npath"
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithWithUnprintableASCIIRejected(self):
"""
Issuing a request with a URI whose path contains unprintable
ASCII characters fails with a L{ValueError}.
"""
for c in UNPRINTABLE_ASCII:
uri = b"http://twisted.invalid/OK%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")
def test_pathWithNonASCIIRejected(self):
"""
Issuing a request with a URI whose path contains non-ASCII
characters fails with a L{ValueError}.
"""
for c in NONASCII:
uri = b"http://twisted.invalid/OK%s" % (bytearray([c]),)
with self.assertRaises(ValueError) as cm:
self.attemptRequestWithMaliciousURI(uri)
self.assertRegex(str(cm.exception), "^Invalid URI")

View file

@ -0,0 +1,486 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Helpers related to HTTP requests, used by tests.
"""
from __future__ import division, absolute_import
__all__ = ['DummyChannel', 'DummyRequest']
from io import BytesIO
from zope.interface import implementer, verify
from twisted.python.compat import intToBytes
from twisted.python.deprecate import deprecated
from incremental import Version
from twisted.internet.defer import Deferred
from twisted.internet.address import IPv4Address, IPv6Address
from twisted.internet.interfaces import ISSLTransport, IAddress
from twisted.trial import unittest
from twisted.web.http_headers import Headers
from twisted.web.resource import Resource
from twisted.web.server import NOT_DONE_YET, Session, Site
from twisted.web._responses import FOUND
textLinearWhitespaceComponents = [
u"Foo%sbar" % (lw,) for lw in
[u'\r', u'\n', u'\r\n']
]
sanitizedText = "Foo bar"
bytesLinearWhitespaceComponents = [
component.encode('ascii') for component in
textLinearWhitespaceComponents
]
sanitizedBytes = sanitizedText.encode('ascii')
@implementer(IAddress)
class NullAddress(object):
"""
A null implementation of L{IAddress}.
"""
class DummyChannel:
class TCP:
port = 80
disconnected = False
def __init__(self, peer=None):
if peer is None:
peer = IPv4Address("TCP", '192.168.1.1', 12344)
self._peer = peer
self.written = BytesIO()
self.producers = []
def getPeer(self):
return self._peer
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("Can only write bytes to a transport, not %r" % (data,))
self.written.write(data)
def writeSequence(self, iovec):
for data in iovec:
self.write(data)
def getHost(self):
return IPv4Address("TCP", '10.0.0.1', self.port)
def registerProducer(self, producer, streaming):
self.producers.append((producer, streaming))
def unregisterProducer(self):
pass
def loseConnection(self):
self.disconnected = True
@implementer(ISSLTransport)
class SSL(TCP):
pass
site = Site(Resource())
def __init__(self, peer=None):
self.transport = self.TCP(peer)
def requestDone(self, request):
pass
def writeHeaders(self, version, code, reason, headers):
response_line = version + b" " + code + b" " + reason + b"\r\n"
headerSequence = [response_line]
headerSequence.extend(
name + b': ' + value + b"\r\n" for name, value in headers
)
headerSequence.append(b"\r\n")
self.transport.writeSequence(headerSequence)
def getPeer(self):
return self.transport.getPeer()
def getHost(self):
return self.transport.getHost()
def registerProducer(self, producer, streaming):
self.transport.registerProducer(producer, streaming)
def unregisterProducer(self):
self.transport.unregisterProducer()
def write(self, data):
self.transport.write(data)
def writeSequence(self, iovec):
self.transport.writeSequence(iovec)
def loseConnection(self):
self.transport.loseConnection()
def endRequest(self):
pass
def isSecure(self):
return isinstance(self.transport, self.SSL)
class DummyRequest(object):
"""
Represents a dummy or fake request. See L{twisted.web.server.Request}.
@ivar _finishedDeferreds: L{None} or a C{list} of L{Deferreds} which will
be called back with L{None} when C{finish} is called or which will be
errbacked if C{processingFailed} is called.
@type requestheaders: C{Headers}
@ivar requestheaders: A Headers instance that stores values for all request
headers.
@type responseHeaders: C{Headers}
@ivar responseHeaders: A Headers instance that stores values for all
response headers.
@type responseCode: C{int}
@ivar responseCode: The response code which was passed to
C{setResponseCode}.
@type written: C{list} of C{bytes}
@ivar written: The bytes which have been written to the request.
"""
uri = b'http://dummy/'
method = b'GET'
client = None
def registerProducer(self, prod, s):
"""
Call an L{IPullProducer}'s C{resumeProducing} method in a
loop until it unregisters itself.
@param prod: The producer.
@type prod: L{IPullProducer}
@param s: Whether or not the producer is streaming.
"""
# XXX: Handle IPushProducers
self.go = 1
while self.go:
prod.resumeProducing()
def unregisterProducer(self):
self.go = 0
def __init__(self, postpath, session=None, client=None):
self.sitepath = []
self.written = []
self.finished = 0
self.postpath = postpath
self.prepath = []
self.session = None
self.protoSession = session or Session(0, self)
self.args = {}
self.requestHeaders = Headers()
self.responseHeaders = Headers()
self.responseCode = None
self._finishedDeferreds = []
self._serverName = b"dummy"
self.clientproto = b"HTTP/1.0"
def getAllHeaders(self):
"""
Return dictionary mapping the names of all received headers to the last
value received for each.
Since this method does not return all header information,
C{self.requestHeaders.getAllRawHeaders()} may be preferred.
NOTE: This function is a direct copy of
C{twisted.web.http.Request.getAllRawHeaders}.
"""
headers = {}
for k, v in self.requestHeaders.getAllRawHeaders():
headers[k.lower()] = v[-1]
return headers
def getHeader(self, name):
"""
Retrieve the value of a request header.
@type name: C{bytes}
@param name: The name of the request header for which to retrieve the
value. Header names are compared case-insensitively.
@rtype: C{bytes} or L{None}
@return: The value of the specified request header.
"""
return self.requestHeaders.getRawHeaders(name.lower(), [None])[0]
def setHeader(self, name, value):
"""TODO: make this assert on write() if the header is content-length
"""
self.responseHeaders.addRawHeader(name, value)
def getSession(self):
if self.session:
return self.session
assert not self.written, "Session cannot be requested after data has been written."
self.session = self.protoSession
return self.session
def render(self, resource):
"""
Render the given resource as a response to this request.
This implementation only handles a few of the most common behaviors of
resources. It can handle a render method that returns a string or
C{NOT_DONE_YET}. It doesn't know anything about the semantics of
request methods (eg HEAD) nor how to set any particular headers.
Basically, it's largely broken, but sufficient for some tests at least.
It should B{not} be expanded to do all the same stuff L{Request} does.
Instead, L{DummyRequest} should be phased out and L{Request} (or some
other real code factored in a different way) used.
"""
result = resource.render(self)
if result is NOT_DONE_YET:
return
self.write(result)
self.finish()
def write(self, data):
if not isinstance(data, bytes):
raise TypeError("write() only accepts bytes")
self.written.append(data)
def notifyFinish(self):
"""
Return a L{Deferred} which is called back with L{None} when the request
is finished. This will probably only work if you haven't called
C{finish} yet.
"""
finished = Deferred()
self._finishedDeferreds.append(finished)
return finished
def finish(self):
"""
Record that the request is finished and callback and L{Deferred}s
waiting for notification of this.
"""
self.finished = self.finished + 1
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.callback(None)
def processingFailed(self, reason):
"""
Errback and L{Deferreds} waiting for finish notification.
"""
if self._finishedDeferreds is not None:
observers = self._finishedDeferreds
self._finishedDeferreds = None
for obs in observers:
obs.errback(reason)
def addArg(self, name, value):
self.args[name] = [value]
def setResponseCode(self, code, message=None):
"""
Set the HTTP status response code, but takes care that this is called
before any data is written.
"""
assert not self.written, "Response code cannot be set after data has been written: %s." % "@@@@".join(self.written)
self.responseCode = code
self.responseMessage = message
def setLastModified(self, when):
assert not self.written, "Last-Modified cannot be set after data has been written: %s." % "@@@@".join(self.written)
def setETag(self, tag):
assert not self.written, "ETag cannot be set after data has been written: %s." % "@@@@".join(self.written)
def getClientIP(self):
"""
Return the IPv4 address of the client which made this request, if there
is one, otherwise L{None}.
"""
if isinstance(self.client, (IPv4Address, IPv6Address)):
return self.client.host
return None
def getClientAddress(self):
"""
Return the L{IAddress} of the client that made this request.
@return: an address.
@rtype: an L{IAddress} provider.
"""
if self.client is None:
return NullAddress()
return self.client
def getRequestHostname(self):
"""
Get a dummy hostname associated to the HTTP request.
@rtype: C{bytes}
@returns: a dummy hostname
"""
return self._serverName
def getHost(self):
"""
Get a dummy transport's host.
@rtype: C{IPv4Address}
@returns: a dummy transport's host
"""
return IPv4Address('TCP', '127.0.0.1', 80)
def setHost(self, host, port, ssl=0):
"""
Change the host and port the request thinks it's using.
@type host: C{bytes}
@param host: The value to which to change the host header.
@type ssl: C{bool}
@param ssl: A flag which, if C{True}, indicates that the request is
considered secure (if C{True}, L{isSecure} will return C{True}).
"""
self._forceSSL = ssl # set first so isSecure will work
if self.isSecure():
default = 443
else:
default = 80
if port == default:
hostHeader = host
else:
hostHeader = host + b":" + intToBytes(port)
self.requestHeaders.addRawHeader(b"host", hostHeader)
def redirect(self, url):
"""
Utility function that does a redirect.
The request should have finish() called after this.
"""
self.setResponseCode(FOUND)
self.setHeader(b"location", url)
DummyRequest.getClientIP = deprecated(
Version('Twisted', 18, 4, 0),
replacement="getClientAddress",
)(DummyRequest.getClientIP)
class DummyRequestTests(unittest.SynchronousTestCase):
"""
Tests for L{DummyRequest}.
"""
def test_getClientIPDeprecated(self):
"""
L{DummyRequest.getClientIP} is deprecated in favor of
L{DummyRequest.getClientAddress}
"""
request = DummyRequest([])
request.getClientIP()
warnings = self.flushWarnings(
offendingFunctions=[self.test_getClientIPDeprecated])
self.assertEqual(1, len(warnings))
[warning] = warnings
self.assertEqual(warning.get("category"), DeprecationWarning)
self.assertEqual(
warning.get("message"),
("twisted.web.test.requesthelper.DummyRequest.getClientIP "
"was deprecated in Twisted 18.4.0; "
"please use getClientAddress instead"),
)
def test_getClientIPSupportsIPv6(self):
"""
L{DummyRequest.getClientIP} supports IPv6 addresses, just like
L{twisted.web.http.Request.getClientIP}.
"""
request = DummyRequest([])
client = IPv6Address("TCP", "::1", 12345)
request.client = client
self.assertEqual("::1", request.getClientIP())
def test_getClientAddressWithoutClient(self):
"""
L{DummyRequest.getClientAddress} returns an L{IAddress}
provider no C{client} has been set.
"""
request = DummyRequest([])
null = request.getClientAddress()
verify.verifyObject(IAddress, null)
def test_getClientAddress(self):
"""
L{DummyRequest.getClientAddress} returns the C{client}.
"""
request = DummyRequest([])
client = IPv4Address("TCP", "127.0.0.1", 12345)
request.client = client
address = request.getClientAddress()
self.assertIs(address, client)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,462 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.twcgi}.
"""
import sys
import os
import json
from io import BytesIO
from twisted.trial import unittest
from twisted.internet import address, reactor, interfaces, error
from twisted.python import util, failure, log
from twisted.web.http import NOT_FOUND, INTERNAL_SERVER_ERROR
from twisted.web import client, twcgi, server, resource, http_headers
from twisted.web.test._util import _render
from twisted.web.test.test_web import DummyRequest
DUMMY_CGI = '''\
print("Header: OK")
print("")
print("cgi output")
'''
DUAL_HEADER_CGI = '''\
print("Header: spam")
print("Header: eggs")
print("")
print("cgi output")
'''
BROKEN_HEADER_CGI = '''\
print("XYZ")
print("")
print("cgi output")
'''
SPECIAL_HEADER_CGI = '''\
print("Server: monkeys")
print("Date: last year")
print("")
print("cgi output")
'''
READINPUT_CGI = '''\
# This is an example of a correctly-written CGI script which reads a body
# from stdin, which only reads env['CONTENT_LENGTH'] bytes.
import os, sys
body_length = int(os.environ.get('CONTENT_LENGTH',0))
indata = sys.stdin.read(body_length)
print("Header: OK")
print("")
print("readinput ok")
'''
READALLINPUT_CGI = '''\
# This is an example of the typical (incorrect) CGI script which expects
# the server to close stdin when the body of the request is complete.
# A correct CGI should only read env['CONTENT_LENGTH'] bytes.
import sys
indata = sys.stdin.read()
print("Header: OK")
print("")
print("readallinput ok")
'''
NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI = '''\
print("content-type: text/cgi-duplicate-test")
print("")
print("cgi output")
'''
HEADER_OUTPUT_CGI = '''\
import json
import os
print("")
print("")
vals = {x:y for x,y in os.environ.items() if x.startswith("HTTP_")}
print(json.dumps(vals))
'''
class PythonScript(twcgi.FilteredScript):
filter = sys.executable
class CGITests(unittest.TestCase):
"""
Tests for L{twcgi.FilteredScript}.
"""
if not interfaces.IReactorProcess.providedBy(reactor):
skip = "CGI tests require a functional reactor.spawnProcess()"
def startServer(self, cgi):
root = resource.Resource()
cgipath = util.sibpath(__file__, cgi)
root.putChild(b"cgi", PythonScript(cgipath))
site = server.Site(root)
self.p = reactor.listenTCP(0, site)
return self.p.getHost().port
def tearDown(self):
if getattr(self, 'p', None):
return self.p.stopListening()
def writeCGI(self, source):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(source)
return cgiFilename
def test_CGI(self):
cgiFilename = self.writeCGI(DUMMY_CGI)
portnum = self.startServer(cgiFilename)
url = 'http://localhost:%d/cgi' % (portnum,)
url = url.encode("ascii")
d = client.Agent(reactor).request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self._testCGI_1)
return d
def _testCGI_1(self, res):
self.assertEqual(res, b"cgi output" + os.linesep.encode("ascii"))
def test_protectedServerAndDate(self):
"""
If the CGI script emits a I{Server} or I{Date} header, these are
ignored.
"""
cgiFilename = self.writeCGI(SPECIAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertNotIn('monkeys',
response.headers.getRawHeaders('server'))
self.assertNotIn('last year',
response.headers.getRawHeaders('date'))
d.addCallback(checkResponse)
return d
def test_noDuplicateContentTypeHeaders(self):
"""
If the CGI script emits a I{content-type} header, make sure that the
server doesn't add an additional (duplicate) one, as per ticket 4786.
"""
cgiFilename = self.writeCGI(NO_DUPLICATE_CONTENT_TYPE_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertEqual(
response.headers.getRawHeaders('content-type'),
['text/cgi-duplicate-test'])
return response
d.addCallback(checkResponse)
return d
def test_noProxyPassthrough(self):
"""
The CGI script is never called with the Proxy header passed through.
"""
cgiFilename = self.writeCGI(HEADER_OUTPUT_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
headers = http_headers.Headers({b"Proxy": [b"foo"],
b"X-Innocent-Header": [b"bar"]})
d = agent.request(b"GET", url, headers=headers)
def checkResponse(response):
headers = json.loads(response.decode("ascii"))
self.assertEqual(
set(headers.keys()),
{"HTTP_HOST", "HTTP_CONNECTION", "HTTP_X_INNOCENT_HEADER"})
d.addCallback(client.readBody)
d.addCallback(checkResponse)
return d
def test_duplicateHeaderCGI(self):
"""
If a CGI script emits two instances of the same header, both are sent
in the response.
"""
cgiFilename = self.writeCGI(DUAL_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
def checkResponse(response):
self.assertEqual(
response.headers.getRawHeaders('header'), ['spam', 'eggs'])
d.addCallback(checkResponse)
return d
def test_malformedHeaderCGI(self):
"""
Check for the error message in the duplicated header
"""
cgiFilename = self.writeCGI(BROKEN_HEADER_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
agent = client.Agent(reactor)
d = agent.request(b"GET", url)
d.addCallback(discardBody)
loggedMessages = []
def addMessage(eventDict):
loggedMessages.append(log.textFromEventDict(eventDict))
log.addObserver(addMessage)
self.addCleanup(log.removeObserver, addMessage)
def checkResponse(ignored):
self.assertIn("ignoring malformed CGI header: " + repr(b'XYZ'),
loggedMessages)
d.addCallback(checkResponse)
return d
def test_ReadEmptyInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(READINPUT_CGI)
portnum = self.startServer(cgiFilename)
agent = client.Agent(reactor)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = agent.request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadEmptyInput_1)
return d
test_ReadEmptyInput.timeout = 5
def _test_ReadEmptyInput_1(self, res):
expected = "readinput ok{}".format(os.linesep)
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_ReadInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(READINPUT_CGI)
portnum = self.startServer(cgiFilename)
agent = client.Agent(reactor)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = agent.request(
uri=url,
method=b"POST",
bodyProducer=client.FileBodyProducer(
BytesIO(b"Here is your stdin")),
)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadInput_1)
return d
test_ReadInput.timeout = 5
def _test_ReadInput_1(self, res):
expected = "readinput ok{}".format(os.linesep)
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_ReadAllInput(self):
cgiFilename = os.path.abspath(self.mktemp())
with open(cgiFilename, 'wt') as cgiFile:
cgiFile.write(READALLINPUT_CGI)
portnum = self.startServer(cgiFilename)
url = "http://localhost:%d/cgi" % (portnum,)
url = url.encode("ascii")
d = client.Agent(reactor).request(
uri=url,
method=b"POST",
bodyProducer=client.FileBodyProducer(
BytesIO(b"Here is your stdin")),
)
d.addCallback(client.readBody)
d.addCallback(self._test_ReadAllInput_1)
return d
test_ReadAllInput.timeout = 5
def _test_ReadAllInput_1(self, res):
expected = "readallinput ok{}".format(os.linesep)
expected = expected.encode("ascii")
self.assertEqual(res, expected)
def test_useReactorArgument(self):
"""
L{twcgi.FilteredScript.runProcess} uses the reactor passed as an
argument to the constructor.
"""
class FakeReactor:
"""
A fake reactor recording whether spawnProcess is called.
"""
called = False
def spawnProcess(self, *args, **kwargs):
"""
Set the C{called} flag to C{True} if C{spawnProcess} is called.
@param args: Positional arguments.
@param kwargs: Keyword arguments.
"""
self.called = True
fakeReactor = FakeReactor()
request = DummyRequest(['a', 'b'])
request.client = address.IPv4Address('TCP', '127.0.0.1', 12345)
resource = twcgi.FilteredScript("dummy-file", reactor=fakeReactor)
_render(resource, request)
self.assertTrue(fakeReactor.called)
class CGIScriptTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIScript}.
"""
def test_pathInfo(self):
"""
L{twcgi.CGIScript.render} sets the process environment
I{PATH_INFO} from the request path.
"""
class FakeReactor:
"""
A fake reactor recording the environment passed to spawnProcess.
"""
def spawnProcess(self, process, filename, args, env, wdir):
"""
Store the C{env} L{dict} to an instance attribute.
@param process: Ignored
@param filename: Ignored
@param args: Ignored
@param env: The environment L{dict} which will be stored
@param wdir: Ignored
"""
self.process_env = env
_reactor = FakeReactor()
resource = twcgi.CGIScript(self.mktemp(), reactor=_reactor)
request = DummyRequest(['a', 'b'])
request.client = address.IPv4Address('TCP', '127.0.0.1', 12345)
_render(resource, request)
self.assertEqual(_reactor.process_env["PATH_INFO"],
"/a/b")
class CGIDirectoryTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIDirectory}.
"""
def test_render(self):
"""
L{twcgi.CGIDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = twcgi.CGIDirectory(self.mktemp())
request = DummyRequest([''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_notFoundChild(self):
"""
L{twcgi.CGIDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{twcgi.CGIDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = twcgi.CGIDirectory(path)
request = DummyRequest(['foo'])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
class CGIProcessProtocolTests(unittest.TestCase):
"""
Tests for L{twcgi.CGIProcessProtocol}.
"""
def test_prematureEndOfHeaders(self):
"""
If the process communicating with L{CGIProcessProtocol} ends before
finishing writing out headers, the response has I{INTERNAL SERVER
ERROR} as its status code.
"""
request = DummyRequest([''])
protocol = twcgi.CGIProcessProtocol(request)
protocol.processEnded(failure.Failure(error.ProcessTerminated()))
self.assertEqual(request.responseCode, INTERNAL_SERVER_ERROR)
def discardBody(response):
"""
Discard the body of a HTTP response.
@param response: The response.
@return: The response.
"""
return client.readBody(response).addCallback(lambda _: response)

View file

@ -0,0 +1,45 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for various parts of L{twisted.web}.
"""
from zope.interface import implementer, verify
from twisted.internet import defer, interfaces
from twisted.trial import unittest
from twisted.web import client
@implementer(interfaces.IStreamClientEndpoint)
class DummyEndPoint(object):
"""An endpoint that does not connect anywhere"""
def __init__(self, someString):
self.someString = someString
def __repr__(self):
return 'DummyEndPoint({})'.format(self.someString)
def connect(self, factory):
return defer.succeed(dict(factory=factory))
class HTTPConnectionPoolTests(unittest.TestCase):
"""
Unit tests for L{client.HTTPConnectionPoolTest}.
"""
def test_implements(self):
"""L{DummyEndPoint}s implements L{interfaces.IStreamClientEndpoint}"""
ep = DummyEndPoint("something")
verify.verifyObject(interfaces.IStreamClientEndpoint, ep)
def test_repr(self):
"""connection L{repr()} includes endpoint's L{repr()}"""
pool = client.HTTPConnectionPool(reactor=None)
ep = DummyEndPoint("this_is_probably_unique")
d = pool.getConnection('someplace', ep)
result = self.successResultOf(d)
representation = repr(result)
self.assertIn(repr(ep), representation)

View file

@ -0,0 +1,527 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.distrib}.
"""
from os.path import abspath
from xml.dom.minidom import parseString
try:
import pwd
except ImportError:
pwd = None
from zope.interface.verify import verifyObject
from twisted.python import filepath, failure
from twisted.internet import reactor, defer
from twisted.trial import unittest
from twisted.spread import pb
from twisted.spread.banana import SIZE_LIMIT
from twisted.web import distrib, client, resource, static, server
from twisted.web.test.test_web import DummyRequest, DummyChannel
from twisted.web.test._util import _render
from twisted.test import proto_helpers
from twisted.web.http_headers import Headers
from twisted.logger import globalLogPublisher
class MySite(server.Site):
pass
class PBServerFactory(pb.PBServerFactory):
"""
A PB server factory which keeps track of the most recent protocol it
created.
@ivar proto: L{None} or the L{Broker} instance most recently returned
from C{buildProtocol}.
"""
proto = None
def buildProtocol(self, addr):
self.proto = pb.PBServerFactory.buildProtocol(self, addr)
return self.proto
class ArbitraryError(Exception):
"""
An exception for this test.
"""
class DistribTests(unittest.TestCase):
port1 = None
port2 = None
sub = None
f1 = None
def tearDown(self):
"""
Clean up all the event sources left behind by either directly by
test methods or indirectly via some distrib API.
"""
dl = [defer.Deferred(), defer.Deferred()]
if self.f1 is not None and self.f1.proto is not None:
self.f1.proto.notifyOnDisconnect(lambda: dl[0].callback(None))
else:
dl[0].callback(None)
if self.sub is not None and self.sub.publisher is not None:
self.sub.publisher.broker.notifyOnDisconnect(
lambda: dl[1].callback(None))
self.sub.publisher.broker.transport.loseConnection()
else:
dl[1].callback(None)
if self.port1 is not None:
dl.append(self.port1.stopListening())
if self.port2 is not None:
dl.append(self.port2.stopListening())
return defer.gatherResults(dl)
def testDistrib(self):
# site1 is the publisher
r1 = resource.Resource()
r1.putChild(b"there", static.Data(b"root", "text/plain"))
site1 = server.Site(r1)
self.f1 = PBServerFactory(distrib.ResourcePublisher(site1))
self.port1 = reactor.listenTCP(0, self.f1)
self.sub = distrib.ResourceSubscription("127.0.0.1",
self.port1.getHost().port)
r2 = resource.Resource()
r2.putChild(b"here", self.sub)
f2 = MySite(r2)
self.port2 = reactor.listenTCP(0, f2)
agent = client.Agent(reactor)
url = "http://127.0.0.1:{}/here/there".format(
self.port2.getHost().port)
url = url.encode("ascii")
d = agent.request(b"GET", url)
d.addCallback(client.readBody)
d.addCallback(self.assertEqual, b'root')
return d
def _setupDistribServer(self, child):
"""
Set up a resource on a distrib site using L{ResourcePublisher}.
@param child: The resource to publish using distrib.
@return: A tuple consisting of the host and port on which to contact
the created site.
"""
distribRoot = resource.Resource()
distribRoot.putChild(b"child", child)
distribSite = server.Site(distribRoot)
self.f1 = distribFactory = PBServerFactory(
distrib.ResourcePublisher(distribSite))
distribPort = reactor.listenTCP(
0, distribFactory, interface="127.0.0.1")
self.addCleanup(distribPort.stopListening)
addr = distribPort.getHost()
self.sub = mainRoot = distrib.ResourceSubscription(
addr.host, addr.port)
mainSite = server.Site(mainRoot)
mainPort = reactor.listenTCP(0, mainSite, interface="127.0.0.1")
self.addCleanup(mainPort.stopListening)
mainAddr = mainPort.getHost()
return mainPort, mainAddr
def _requestTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{Agent.request} when
requesting the resource.
@return: A L{Deferred} which fires with the result of the request.
"""
mainPort, mainAddr = self._setupDistribServer(child)
agent = client.Agent(reactor)
url = "http://%s:%s/child" % (mainAddr.host, mainAddr.port)
url = url.encode("ascii")
d = agent.request(b"GET", url, **kwargs)
d.addCallback(client.readBody)
return d
def _requestAgentTest(self, child, **kwargs):
"""
Set up a resource on a distrib site using L{ResourcePublisher} and
then retrieve it from a L{ResourceSubscription} via an HTTP client.
@param child: The resource to publish using distrib.
@param **kwargs: Extra keyword arguments to pass to L{Agent.request} when
requesting the resource.
@return: A L{Deferred} which fires with a tuple consisting of a
L{twisted.test.proto_helpers.AccumulatingProtocol} containing the
body of the response and an L{IResponse} with the response itself.
"""
mainPort, mainAddr = self._setupDistribServer(child)
url = "http://{}:{}/child".format(mainAddr.host, mainAddr.port)
url = url.encode("ascii")
d = client.Agent(reactor).request(b"GET", url, **kwargs)
def cbCollectBody(response):
protocol = proto_helpers.AccumulatingProtocol()
response.deliverBody(protocol)
d = protocol.closedDeferred = defer.Deferred()
d.addCallback(lambda _: (protocol, response))
return d
d.addCallback(cbCollectBody)
return d
def test_requestHeaders(self):
"""
The request headers are available on the request object passed to a
distributed resource's C{render} method.
"""
requestHeaders = {}
logObserver = proto_helpers.EventLoggingObserver()
globalLogPublisher.addObserver(logObserver)
req = [None]
class ReportRequestHeaders(resource.Resource):
def render(self, request):
req[0] = request
requestHeaders.update(dict(
request.requestHeaders.getAllRawHeaders()))
return b""
def check_logs():
msgs = [e["log_format"] for e in logObserver]
self.assertIn('connected to publisher', msgs)
self.assertIn(
"could not connect to distributed web service: {msg}",
msgs
)
self.assertIn(req[0], msgs)
globalLogPublisher.removeObserver(logObserver)
request = self._requestTest(
ReportRequestHeaders(), headers=Headers({'foo': ['bar']}))
def cbRequested(result):
self.f1.proto.notifyOnDisconnect(check_logs)
self.assertEqual(requestHeaders[b'Foo'], [b'bar'])
request.addCallback(cbRequested)
return request
def test_requestResponseCode(self):
"""
The response code can be set by the request object passed to a
distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200)
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, b"")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, b"OK")
request.addCallback(cbRequested)
return request
def test_requestResponseCodeMessage(self):
"""
The response code and message can be set by the request object passed to
a distributed resource's C{render} method.
"""
class SetResponseCode(resource.Resource):
def render(self, request):
request.setResponseCode(200, b"some-message")
return ""
request = self._requestAgentTest(SetResponseCode())
def cbRequested(result):
self.assertEqual(result[0].data, b"")
self.assertEqual(result[1].code, 200)
self.assertEqual(result[1].phrase, b"some-message")
request.addCallback(cbRequested)
return request
def test_largeWrite(self):
"""
If a string longer than the Banana size limit is passed to the
L{distrib.Request} passed to the remote resource, it is broken into
smaller strings to be transported over the PB connection.
"""
class LargeWrite(resource.Resource):
def render(self, request):
request.write(b'x' * SIZE_LIMIT + b'y')
request.finish()
return server.NOT_DONE_YET
request = self._requestTest(LargeWrite())
request.addCallback(self.assertEqual, b'x' * SIZE_LIMIT + b'y')
return request
def test_largeReturn(self):
"""
Like L{test_largeWrite}, but for the case where C{render} returns a
long string rather than explicitly passing it to L{Request.write}.
"""
class LargeReturn(resource.Resource):
def render(self, request):
return b'x' * SIZE_LIMIT + b'y'
request = self._requestTest(LargeReturn())
request.addCallback(self.assertEqual, b'x' * SIZE_LIMIT + b'y')
return request
def test_connectionLost(self):
"""
If there is an error issuing the request to the remote publisher, an
error response is returned.
"""
# Using pb.Root as a publisher will cause request calls to fail with an
# error every time. Just what we want to test.
self.f1 = serverFactory = PBServerFactory(pb.Root())
self.port1 = serverPort = reactor.listenTCP(0, serverFactory)
self.sub = subscription = distrib.ResourceSubscription(
"127.0.0.1", serverPort.getHost().port)
request = DummyRequest([b''])
d = _render(subscription, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, 500)
# This is the error we caused the request to fail with. It should
# have been logged.
errors = self.flushLoggedErrors(pb.NoSuchMethod)
self.assertEqual(len(errors), 1)
# The error page is rendered as HTML.
expected = [
b'',
b'<html>',
b' <head><title>500 - Server Connection Lost</title></head>',
b' <body>',
b' <h1>Server Connection Lost</h1>',
b' <p>Connection to distributed server lost:'
b'<pre>'
b'[Failure instance: Traceback from remote host -- '
b'twisted.spread.flavors.NoSuchMethod: '
b'No such method: remote_request',
b']</pre></p>',
b' </body>',
b'</html>',
b''
]
self.assertEqual([b'\n'.join(expected)], request.written)
d.addCallback(cbRendered)
return d
def test_logFailed(self):
"""
When a request fails, the string form of the failure is logged.
"""
logObserver = proto_helpers.EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
f = failure.Failure(ArbitraryError())
request = DummyRequest([b''])
issue = distrib.Issue(request)
issue.failed(f)
self.assertEquals(1, len(logObserver))
self.assertIn(
"Failure instance",
logObserver[0]["log_format"]
)
def test_requestFail(self):
"""
When L{twisted.web.distrib.Request}'s fail is called, the failure
is logged.
"""
logObserver = proto_helpers.EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
err = ArbitraryError()
f = failure.Failure(err)
req = distrib.Request(DummyChannel())
req.fail(f)
self.flushLoggedErrors(ArbitraryError)
self.assertEquals(1, len(logObserver))
self.assertIs(logObserver[0]["log_failure"], f)
class _PasswordDatabase:
def __init__(self, users):
self._users = users
def getpwall(self):
return iter(self._users)
def getpwnam(self, username):
for user in self._users:
if user[0] == username:
return user
raise KeyError()
class UserDirectoryTests(unittest.TestCase):
"""
Tests for L{UserDirectory}, a resource for listing all user resources
available on a system.
"""
def setUp(self):
self.alice = ('alice', 'x', 123, 456, 'Alice,,,', self.mktemp(), '/bin/sh')
self.bob = ('bob', 'x', 234, 567, 'Bob,,,', self.mktemp(), '/bin/sh')
self.database = _PasswordDatabase([self.alice, self.bob])
self.directory = distrib.UserDirectory(self.database)
def test_interface(self):
"""
L{UserDirectory} instances provide L{resource.IResource}.
"""
self.assertTrue(verifyObject(resource.IResource, self.directory))
def _404Test(self, name):
"""
Verify that requesting the C{name} child of C{self.directory} results
in a 404 response.
"""
request = DummyRequest([name])
result = self.directory.getChild(name, request)
d = _render(result, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, 404)
d.addCallback(cbRendered)
return d
def test_getInvalidUser(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which does not correspond to any known
user.
"""
return self._404Test('carol')
def test_getUserWithoutResource(self):
"""
L{UserDirectory.getChild} returns a resource which renders a 404
response when passed a string which corresponds to a known user who has
neither a user directory nor a user distrib socket.
"""
return self._404Test('alice')
def test_getPublicHTMLChild(self):
"""
L{UserDirectory.getChild} returns a L{static.File} instance when passed
the name of a user with a home directory containing a I{public_html}
directory.
"""
home = filepath.FilePath(self.bob[-2])
public_html = home.child('public_html')
public_html.makedirs()
request = DummyRequest(['bob'])
result = self.directory.getChild('bob', request)
self.assertIsInstance(result, static.File)
self.assertEqual(result.path, public_html.path)
def test_getDistribChild(self):
"""
L{UserDirectory.getChild} returns a L{ResourceSubscription} instance
when passed the name of a user suffixed with C{".twistd"} who has a
home directory containing a I{.twistd-web-pb} socket.
"""
home = filepath.FilePath(self.bob[-2])
home.makedirs()
web = home.child('.twistd-web-pb')
request = DummyRequest(['bob'])
result = self.directory.getChild('bob.twistd', request)
self.assertIsInstance(result, distrib.ResourceSubscription)
self.assertEqual(result.host, 'unix')
self.assertEqual(abspath(result.port), web.path)
def test_invalidMethod(self):
"""
L{UserDirectory.render} raises L{UnsupportedMethod} in response to a
non-I{GET} request.
"""
request = DummyRequest([''])
request.method = 'POST'
self.assertRaises(
server.UnsupportedMethod, self.directory.render, request)
def test_render(self):
"""
L{UserDirectory} renders a list of links to available user content
in response to a I{GET} request.
"""
public_html = filepath.FilePath(self.alice[-2]).child('public_html')
public_html.makedirs()
web = filepath.FilePath(self.bob[-2])
web.makedirs()
# This really only works if it's a unix socket, but the implementation
# doesn't currently check for that. It probably should someday, and
# then skip users with non-sockets.
web.child('.twistd-web-pb').setContent(b"")
request = DummyRequest([''])
result = _render(self.directory, request)
def cbRendered(ignored):
document = parseString(b''.join(request.written))
# Each user should have an li with a link to their page.
[alice, bob] = document.getElementsByTagName('li')
self.assertEqual(alice.firstChild.tagName, 'a')
self.assertEqual(alice.firstChild.getAttribute('href'), 'alice/')
self.assertEqual(alice.firstChild.firstChild.data, 'Alice (file)')
self.assertEqual(bob.firstChild.tagName, 'a')
self.assertEqual(bob.firstChild.getAttribute('href'), 'bob.twistd/')
self.assertEqual(bob.firstChild.firstChild.data, 'Bob (twistd)')
result.addCallback(cbRendered)
return result
def test_passwordDatabase(self):
"""
If L{UserDirectory} is instantiated with no arguments, it uses the
L{pwd} module as its password database.
"""
directory = distrib.UserDirectory()
self.assertIdentical(directory._pwd, pwd)
if pwd is None:
test_passwordDatabase.skip = "pwd module required"

View file

@ -0,0 +1,305 @@
# -*- test-case-name: twisted.web.test.test_domhelpers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Specific tests for (some of) the methods in L{twisted.web.domhelpers}.
"""
from xml.dom import minidom
from twisted.python.compat import unicode
from twisted.trial.unittest import TestCase
from twisted.web import domhelpers, microdom
class DOMHelpersTestsMixin:
"""
A mixin for L{TestCase} subclasses which defines test methods for
domhelpers functionality based on a DOM creation function provided by a
subclass.
"""
dom = None
def test_getElementsByTagName(self):
doc1 = self.dom.parseString('<foo/>')
actual = domhelpers.getElementsByTagName(doc1, 'foo')[0].nodeName
expected = 'foo'
self.assertEqual(actual, expected)
el1 = doc1.documentElement
actual = domhelpers.getElementsByTagName(el1, 'foo')[0].nodeName
self.assertEqual(actual, expected)
doc2_xml = '<a><foo in="a"/><b><foo in="b"/></b><c><foo in="c"/></c><foo in="d"/><foo in="ef"/><g><foo in="g"/><h><foo in="h"/></h></g></a>'
doc2 = self.dom.parseString(doc2_xml)
tag_list = domhelpers.getElementsByTagName(doc2, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
expected = 'abcdefgh'
self.assertEqual(actual, expected)
el2 = doc2.documentElement
tag_list = domhelpers.getElementsByTagName(el2, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
self.assertEqual(actual, expected)
doc3_xml = '''
<a><foo in="a"/>
<b><foo in="b"/>
<d><foo in="d"/>
<g><foo in="g"/></g>
<h><foo in="h"/></h>
</d>
<e><foo in="e"/>
<i><foo in="i"/></i>
</e>
</b>
<c><foo in="c"/>
<f><foo in="f"/>
<j><foo in="j"/></j>
</f>
</c>
</a>'''
doc3 = self.dom.parseString(doc3_xml)
tag_list = domhelpers.getElementsByTagName(doc3, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
expected = 'abdgheicfj'
self.assertEqual(actual, expected)
el3 = doc3.documentElement
tag_list = domhelpers.getElementsByTagName(el3, 'foo')
actual = ''.join([node.getAttribute('in') for node in tag_list])
self.assertEqual(actual, expected)
doc4_xml = '<foo><bar></bar><baz><foo/></baz></foo>'
doc4 = self.dom.parseString(doc4_xml)
actual = domhelpers.getElementsByTagName(doc4, 'foo')
root = doc4.documentElement
expected = [root, root.childNodes[-1].childNodes[0]]
self.assertEqual(actual, expected)
actual = domhelpers.getElementsByTagName(root, 'foo')
self.assertEqual(actual, expected)
def test_gatherTextNodes(self):
doc1 = self.dom.parseString('<a>foo</a>')
actual = domhelpers.gatherTextNodes(doc1)
expected = 'foo'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc1.documentElement)
self.assertEqual(actual, expected)
doc2_xml = '<a>a<b>b</b><c>c</c>def<g>g<h>h</h></g></a>'
doc2 = self.dom.parseString(doc2_xml)
actual = domhelpers.gatherTextNodes(doc2)
expected = 'abcdefgh'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc2.documentElement)
self.assertEqual(actual, expected)
doc3_xml = ('<a>a<b>b<d>d<g>g</g><h>h</h></d><e>e<i>i</i></e></b>' +
'<c>c<f>f<j>j</j></f></c></a>')
doc3 = self.dom.parseString(doc3_xml)
actual = domhelpers.gatherTextNodes(doc3)
expected = 'abdgheicfj'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc3.documentElement)
self.assertEqual(actual, expected)
def test_clearNode(self):
doc1 = self.dom.parseString('<a><b><c><d/></c></b></a>')
a_node = doc1.documentElement
domhelpers.clearNode(a_node)
self.assertEqual(
a_node.toxml(),
self.dom.Element('a').toxml())
doc2 = self.dom.parseString('<a><b><c><d/></c></b></a>')
b_node = doc2.documentElement.childNodes[0]
domhelpers.clearNode(b_node)
actual = doc2.documentElement.toxml()
expected = self.dom.Element('a')
expected.appendChild(self.dom.Element('b'))
self.assertEqual(actual, expected.toxml())
def test_get(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
doc = self.dom.Document()
node = domhelpers.get(doc1, "foo")
actual = node.toxml()
expected = doc.createElement('c')
expected.setAttribute('class', 'foo')
self.assertEqual(actual, expected.toxml())
node = domhelpers.get(doc1, "bar")
actual = node.toxml()
expected = doc.createElement('b')
expected.setAttribute('id', 'bar')
self.assertEqual(actual, expected.toxml())
self.assertRaises(domhelpers.NodeLookupError,
domhelpers.get,
doc1,
"pzork")
def test_getIfExists(self):
doc1 = self.dom.parseString('<a><b id="bar"/><c class="foo"/></a>')
doc = self.dom.Document()
node = domhelpers.getIfExists(doc1, "foo")
actual = node.toxml()
expected = doc.createElement('c')
expected.setAttribute('class', 'foo')
self.assertEqual(actual, expected.toxml())
node = domhelpers.getIfExists(doc1, "pzork")
self.assertIdentical(node, None)
def test_getAndClear(self):
doc1 = self.dom.parseString('<a><b id="foo"><c></c></b></a>')
doc = self.dom.Document()
node = domhelpers.getAndClear(doc1, "foo")
actual = node.toxml()
expected = doc.createElement('b')
expected.setAttribute('id', 'foo')
self.assertEqual(actual, expected.toxml())
def test_locateNodes(self):
doc1 = self.dom.parseString('<a><b foo="olive"><c foo="olive"/></b><d foo="poopy"/></a>')
doc = self.dom.Document()
node_list = domhelpers.locateNodes(
doc1.childNodes, 'foo', 'olive', noNesting=1)
actual = ''.join([node.toxml() for node in node_list])
expected = doc.createElement('b')
expected.setAttribute('foo', 'olive')
c = doc.createElement('c')
c.setAttribute('foo', 'olive')
expected.appendChild(c)
self.assertEqual(actual, expected.toxml())
node_list = domhelpers.locateNodes(
doc1.childNodes, 'foo', 'olive', noNesting=0)
actual = ''.join([node.toxml() for node in node_list])
self.assertEqual(actual, expected.toxml() + c.toxml())
def test_getParents(self):
doc1 = self.dom.parseString('<a><b><c><d/></c><e/></b><f/></a>')
node_list = domhelpers.getParents(
doc1.childNodes[0].childNodes[0].childNodes[0])
actual = ''.join([node.tagName for node in node_list
if hasattr(node, 'tagName')])
self.assertEqual(actual, 'cba')
def test_findElementsWithAttribute(self):
doc1 = self.dom.parseString('<a foo="1"><b foo="2"/><c foo="1"/><d/></a>')
node_list = domhelpers.findElementsWithAttribute(doc1, 'foo')
actual = ''.join([node.tagName for node in node_list])
self.assertEqual(actual, 'abc')
node_list = domhelpers.findElementsWithAttribute(doc1, 'foo', '1')
actual = ''.join([node.tagName for node in node_list])
self.assertEqual(actual, 'ac')
def test_findNodesNamed(self):
doc1 = self.dom.parseString('<doc><foo/><bar/><foo>a</foo></doc>')
node_list = domhelpers.findNodesNamed(doc1, 'foo')
actual = len(node_list)
self.assertEqual(actual, 2)
def test_escape(self):
j = 'this string " contains many & characters> xml< won\'t like'
expected = 'this string &quot; contains many &amp; characters&gt; xml&lt; won\'t like'
self.assertEqual(domhelpers.escape(j), expected)
def test_unescape(self):
j = 'this string &quot; has &&amp; entities &gt; &lt; and some characters xml won\'t like<'
expected = 'this string " has && entities > < and some characters xml won\'t like<'
self.assertEqual(domhelpers.unescape(j), expected)
def test_getNodeText(self):
"""
L{getNodeText} returns the concatenation of all the text data at or
beneath the node passed to it.
"""
node = self.dom.parseString('<foo><bar>baz</bar><bar>quux</bar></foo>')
self.assertEqual(domhelpers.getNodeText(node), "bazquux")
class MicroDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = microdom
def test_gatherTextNodesDropsWhitespace(self):
"""
Microdom discards whitespace-only text nodes, so L{gatherTextNodes}
returns only the text from nodes which had non-whitespace characters.
"""
doc4_xml = '''<html>
<head>
</head>
<body>
stuff
</body>
</html>
'''
doc4 = self.dom.parseString(doc4_xml)
actual = domhelpers.gatherTextNodes(doc4)
expected = '\n stuff\n '
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc4.documentElement)
self.assertEqual(actual, expected)
def test_textEntitiesNotDecoded(self):
"""
Microdom does not decode entities in text nodes.
"""
doc5_xml = '<x>Souffl&amp;</x>'
doc5 = self.dom.parseString(doc5_xml)
actual = domhelpers.gatherTextNodes(doc5)
expected = 'Souffl&amp;'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
class MiniDOMHelpersTests(DOMHelpersTestsMixin, TestCase):
dom = minidom
def test_textEntitiesDecoded(self):
"""
Minidom does decode entities in text nodes.
"""
doc5_xml = '<x>Souffl&amp;</x>'
doc5 = self.dom.parseString(doc5_xml)
actual = domhelpers.gatherTextNodes(doc5)
expected = 'Souffl&'
self.assertEqual(actual, expected)
actual = domhelpers.gatherTextNodes(doc5.documentElement)
self.assertEqual(actual, expected)
def test_getNodeUnicodeText(self):
"""
L{domhelpers.getNodeText} returns a C{unicode} string when text
nodes are represented in the DOM with unicode, whether or not there
are non-ASCII characters present.
"""
node = self.dom.parseString("<foo>bar</foo>")
text = domhelpers.getNodeText(node)
self.assertEqual(text, u"bar")
self.assertIsInstance(text, unicode)
node = self.dom.parseString(u"<foo>\N{SNOWMAN}</foo>".encode('utf-8'))
text = domhelpers.getNodeText(node)
self.assertEqual(text, u"\N{SNOWMAN}")
self.assertIsInstance(text, unicode)

View file

@ -0,0 +1,481 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
HTTP errors.
"""
from __future__ import division, absolute_import
import re
import sys
import traceback
from twisted.trial import unittest
from twisted.python.compat import nativeString, _PY3
from twisted.web import error
from twisted.web.template import Tag
class CodeToMessageTests(unittest.TestCase):
"""
L{_codeToMessages} inverts L{_responses.RESPONSES}
"""
def test_validCode(self):
m = error._codeToMessage(b"302")
self.assertEqual(m, b"Found")
def test_invalidCode(self):
m = error._codeToMessage(b"987")
self.assertEqual(m, None)
def test_nonintegerCode(self):
m = error._codeToMessage(b"InvalidCode")
self.assertEqual(m, None)
class ErrorTests(unittest.TestCase):
"""
Tests for how L{Error} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{Error} constructor and the
C{code} argument is a valid HTTP status code, C{code} is mapped to a
descriptive string to which C{message} is assigned.
"""
e = error.Error(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatus(self):
"""
If no C{message} argument is passed to the L{Error} constructor and
C{code} isn't a valid HTTP status code, C{message} stays L{None}.
"""
e = error.Error(b"InvalidCode")
self.assertEqual(e.message, None)
def test_messageExists(self):
"""
If a C{message} argument is passed to the L{Error} constructor, the
C{message} isn't affected by the value of C{status}.
"""
e = error.Error(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
def test_str(self):
"""
C{str()} on an L{Error} returns the code and message it was
instantiated with.
"""
# Bytestring status
e = error.Error(b"200", b"OK")
self.assertEqual(str(e), "200 OK")
# int status
e = error.Error(200, b"OK")
self.assertEqual(str(e), "200 OK")
class PageRedirectTests(unittest.TestCase):
"""
Tests for how L{PageRedirect} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and the C{code} argument is a valid HTTP status code, C{code} is mapped
to a descriptive string to which C{message} is assigned.
"""
e = error.PageRedirect(b"200", location=b"/foo")
self.assertEqual(e.message, b"OK to /foo")
def test_noMessageValidStatusNoLocation(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{location} is also empty and the C{code} argument is a valid HTTP
status code, C{code} is mapped to a descriptive string to which
C{message} is assigned without trying to include an empty location.
"""
e = error.PageRedirect(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatusLocationExists(self):
"""
If no C{message} argument is passed to the L{PageRedirect} constructor
and C{code} isn't a valid HTTP status code, C{message} stays L{None}.
"""
e = error.PageRedirect(b"InvalidCode", location=b"/foo")
self.assertEqual(e.message, None)
def test_messageExistsLocationExists(self):
"""
If a C{message} argument is passed to the L{PageRedirect} constructor,
the C{message} isn't affected by the value of C{status}.
"""
e = error.PageRedirect(b"200", b"My own message", location=b"/foo")
self.assertEqual(e.message, b"My own message to /foo")
def test_messageExistsNoLocation(self):
"""
If a C{message} argument is passed to the L{PageRedirect} constructor
and no location is provided, C{message} doesn't try to include the
empty location.
"""
e = error.PageRedirect(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
class InfiniteRedirectionTests(unittest.TestCase):
"""
Tests for how L{InfiniteRedirection} attributes are initialized.
"""
def test_noMessageValidStatus(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and the C{code} argument is a valid HTTP status code,
C{code} is mapped to a descriptive string to which C{message} is
assigned.
"""
e = error.InfiniteRedirection(b"200", location=b"/foo")
self.assertEqual(e.message, b"OK to /foo")
def test_noMessageValidStatusNoLocation(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{location} is also empty and the C{code} argument is a
valid HTTP status code, C{code} is mapped to a descriptive string to
which C{message} is assigned without trying to include an empty
location.
"""
e = error.InfiniteRedirection(b"200")
self.assertEqual(e.message, b"OK")
def test_noMessageInvalidStatusLocationExists(self):
"""
If no C{message} argument is passed to the L{InfiniteRedirection}
constructor and C{code} isn't a valid HTTP status code, C{message} stays
L{None}.
"""
e = error.InfiniteRedirection(b"InvalidCode", location=b"/foo")
self.assertEqual(e.message, None)
def test_messageExistsLocationExists(self):
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor, the C{message} isn't affected by the value of C{status}.
"""
e = error.InfiniteRedirection(b"200", b"My own message",
location=b"/foo")
self.assertEqual(e.message, b"My own message to /foo")
def test_messageExistsNoLocation(self):
"""
If a C{message} argument is passed to the L{InfiniteRedirection}
constructor and no location is provided, C{message} doesn't try to
include the empty location.
"""
e = error.InfiniteRedirection(b"200", b"My own message")
self.assertEqual(e.message, b"My own message")
class RedirectWithNoLocationTests(unittest.TestCase):
"""
L{RedirectWithNoLocation} is a subclass of L{Error} which sets
a custom message in the constructor.
"""
def test_validMessage(self):
"""
When C{code}, C{message}, and C{uri} are passed to the
L{RedirectWithNoLocation} constructor, the C{message} and C{uri}
attributes are set, respectively.
"""
e = error.RedirectWithNoLocation(b"302", b"REDIRECT",
b"https://example.com")
self.assertEqual(e.message, b"REDIRECT to https://example.com")
self.assertEqual(e.uri, b"https://example.com")
class MissingRenderMethodTests(unittest.TestCase):
"""
Tests for how L{MissingRenderMethod} exceptions are initialized and
displayed.
"""
def test_constructor(self):
"""
Given C{element} and C{renderName} arguments, the
L{MissingRenderMethod} constructor assigns the values to the
corresponding attributes.
"""
elt = object()
e = error.MissingRenderMethod(elt, 'renderThing')
self.assertIs(e.element, elt)
self.assertIs(e.renderName, 'renderThing')
def test_repr(self):
"""
A L{MissingRenderMethod} is represented using a custom string
containing the element's representation and the method name.
"""
elt = object()
e = error.MissingRenderMethod(elt, 'renderThing')
self.assertEqual(
repr(e),
("'MissingRenderMethod': "
"%r had no render method named 'renderThing'") % elt)
class MissingTemplateLoaderTests(unittest.TestCase):
"""
Tests for how L{MissingTemplateLoader} exceptions are initialized and
displayed.
"""
def test_constructor(self):
"""
Given an C{element} argument, the L{MissingTemplateLoader} constructor
assigns the value to the corresponding attribute.
"""
elt = object()
e = error.MissingTemplateLoader(elt)
self.assertIs(e.element, elt)
def test_repr(self):
"""
A L{MissingTemplateLoader} is represented using a custom string
containing the element's representation and the method name.
"""
elt = object()
e = error.MissingTemplateLoader(elt)
self.assertEqual(
repr(e),
"'MissingTemplateLoader': %r had no loader" % elt)
class FlattenerErrorTests(unittest.TestCase):
"""
Tests for L{FlattenerError}.
"""
def makeFlattenerError(self, roots=[]):
try:
raise RuntimeError("oh noes")
except Exception as e:
tb = traceback.extract_tb(sys.exc_info()[2])
return error.FlattenerError(e, roots, tb)
def fakeFormatRoot(self, obj):
return 'R(%s)' % obj
def test_constructor(self):
"""
Given C{exception}, C{roots}, and C{traceback} arguments, the
L{FlattenerError} constructor assigns the roots to the C{_roots}
attribute.
"""
e = self.makeFlattenerError(roots=['a', 'b'])
self.assertEqual(e._roots, ['a', 'b'])
def test_str(self):
"""
The string form of a L{FlattenerError} is identical to its
representation.
"""
e = self.makeFlattenerError()
self.assertEqual(str(e), repr(e))
def test_reprWithRootsAndWithTraceback(self):
"""
The representation of a L{FlattenerError} initialized with roots and a
traceback contains a formatted representation of those roots (using
C{_formatRoot}) and a formatted traceback.
"""
e = self.makeFlattenerError(['a', 'b'])
e._formatRoot = self.fakeFormatRoot
self.assertTrue(
re.match('Exception while flattening:\n'
' R\(a\)\n'
' R\(b\)\n'
' File "[^"]*", line [0-9]*, in makeFlattenerError\n'
' raise RuntimeError\("oh noes"\)\n'
'RuntimeError: oh noes\n$',
repr(e), re.M | re.S),
repr(e))
def test_reprWithoutRootsAndWithTraceback(self):
"""
The representation of a L{FlattenerError} initialized without roots but
with a traceback contains a formatted traceback but no roots.
"""
e = self.makeFlattenerError([])
self.assertTrue(
re.match('Exception while flattening:\n'
' File "[^"]*", line [0-9]*, in makeFlattenerError\n'
' raise RuntimeError\("oh noes"\)\n'
'RuntimeError: oh noes\n$',
repr(e), re.M | re.S),
repr(e))
def test_reprWithoutRootsAndWithoutTraceback(self):
"""
The representation of a L{FlattenerError} initialized without roots but
with a traceback contains a formatted traceback but no roots.
"""
e = error.FlattenerError(RuntimeError("oh noes"), [], None)
self.assertTrue(
re.match('Exception while flattening:\n'
'RuntimeError: oh noes\n$',
repr(e), re.M | re.S),
repr(e))
def test_formatRootShortUnicodeString(self):
"""
The C{_formatRoot} method formats a short unicode string using the
built-in repr.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(nativeString('abcd')), repr('abcd'))
def test_formatRootLongUnicodeString(self):
"""
The C{_formatRoot} method formats a long unicode string using the
built-in repr with an ellipsis.
"""
e = self.makeFlattenerError()
longString = nativeString('abcde-' * 20)
self.assertEqual(e._formatRoot(longString),
repr('abcde-abcde-abcde-ab<...>e-abcde-abcde-abcde-'))
def test_formatRootShortByteString(self):
"""
The C{_formatRoot} method formats a short byte string using the
built-in repr.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(b'abcd'), repr(b'abcd'))
def test_formatRootLongByteString(self):
"""
The C{_formatRoot} method formats a long byte string using the
built-in repr with an ellipsis.
"""
e = self.makeFlattenerError()
longString = b'abcde-' * 20
self.assertEqual(e._formatRoot(longString),
repr(b'abcde-abcde-abcde-ab<...>e-abcde-abcde-abcde-'))
def test_formatRootTagNoFilename(self):
"""
The C{_formatRoot} method formats a C{Tag} with no filename information
as 'Tag <tagName>'.
"""
e = self.makeFlattenerError()
self.assertEqual(e._formatRoot(Tag('a-tag')), 'Tag <a-tag>')
def test_formatRootTagWithFilename(self):
"""
The C{_formatRoot} method formats a C{Tag} with filename information
using the filename, line, column, and tag information
"""
e = self.makeFlattenerError()
t = Tag('a-tag', filename='tpl.py', lineNumber=10, columnNumber=20)
self.assertEqual(e._formatRoot(t),
'File "tpl.py", line 10, column 20, in "a-tag"')
def test_string(self):
"""
If a L{FlattenerError} is created with a string root, up to around 40
bytes from that string are included in the string representation of the
exception.
"""
self.assertEqual(
str(error.FlattenerError(RuntimeError("reason"),
['abc123xyz'], [])),
"Exception while flattening:\n"
" 'abc123xyz'\n"
"RuntimeError: reason\n")
self.assertEqual(
str(error.FlattenerError(
RuntimeError("reason"), ['0123456789' * 10], [])),
"Exception while flattening:\n"
" '01234567890123456789"
"<...>01234567890123456789'\n" # TODO: re-add 0
"RuntimeError: reason\n")
def test_unicode(self):
"""
If a L{FlattenerError} is created with a unicode root, up to around 40
characters from that string are included in the string representation
of the exception.
"""
# the response includes the output of repr(), which differs between
# Python 2 and 3
u = {'u': ''} if _PY3 else {'u': 'u'}
self.assertEqual(
str(error.FlattenerError(
RuntimeError("reason"), [u'abc\N{SNOWMAN}xyz'], [])),
"Exception while flattening:\n"
" %(u)s'abc\\u2603xyz'\n" # Codepoint for SNOWMAN
"RuntimeError: reason\n" % u)
self.assertEqual(
str(error.FlattenerError(
RuntimeError("reason"), [u'01234567\N{SNOWMAN}9' * 10],
[])),
"Exception while flattening:\n"
" %(u)s'01234567\\u2603901234567\\u26039"
"<...>01234567\\u2603901234567"
"\\u26039'\n"
"RuntimeError: reason\n" % u)
class UnsupportedMethodTests(unittest.SynchronousTestCase):
"""
Tests for L{UnsupportedMethod}.
"""
def test_str(self):
"""
The C{__str__} for L{UnsupportedMethod} makes it clear that what it
shows is a list of the supported methods, not the method that was
unsupported.
"""
b = "b" if _PY3 else ""
e = error.UnsupportedMethod([b"HEAD", b"PATCH"])
self.assertEqual(
str(e), "Expected one of [{b}'HEAD', {b}'PATCH']".format(b=b),
)

View file

@ -0,0 +1,537 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for the flattening portion of L{twisted.web.template}, implemented in
L{twisted.web._flatten}.
"""
import sys
import traceback
from xml.etree.cElementTree import XML
from collections import OrderedDict
from zope.interface import implementer
from twisted.python.compat import _PY35PLUS
from twisted.trial.unittest import TestCase
from twisted.test.testutils import XMLAssertionMixin
from twisted.internet.defer import passthru, succeed, gatherResults
from twisted.web.iweb import IRenderable
from twisted.web.error import UnfilledSlot, UnsupportedType, FlattenerError
from twisted.web.template import tags, Tag, Comment, CDATA, CharRef, slot
from twisted.web.template import Element, renderer, TagLoader, flattenString
from twisted.web.test._util import FlattenTestCase
class SerializationTests(FlattenTestCase, XMLAssertionMixin):
"""
Tests for flattening various things.
"""
def test_nestedTags(self):
"""
Test that nested tags flatten correctly.
"""
return self.assertFlattensTo(
tags.html(tags.body('42'), hi='there'),
b'<html hi="there"><body>42</body></html>')
def test_serializeString(self):
"""
Test that strings will be flattened and escaped correctly.
"""
return gatherResults([
self.assertFlattensTo('one', b'one'),
self.assertFlattensTo('<abc&&>123', b'&lt;abc&amp;&amp;&gt;123'),
])
def test_serializeSelfClosingTags(self):
"""
The serialized form of a self-closing tag is C{'<tagName />'}.
"""
return self.assertFlattensTo(tags.img(), b'<img />')
def test_serializeAttribute(self):
"""
The serialized form of attribute I{a} with value I{b} is C{'a="b"'}.
"""
self.assertFlattensImmediately(tags.img(src='foo'),
b'<img src="foo" />')
def test_serializedMultipleAttributes(self):
"""
Multiple attributes are separated by a single space in their serialized
form.
"""
tag = tags.img()
tag.attributes = OrderedDict([("src", "foo"), ("name", "bar")])
self.assertFlattensImmediately(tag, b'<img src="foo" name="bar" />')
def checkAttributeSanitization(self, wrapData, wrapTag):
"""
Common implementation of L{test_serializedAttributeWithSanitization}
and L{test_serializedDeferredAttributeWithSanitization},
L{test_serializedAttributeWithTransparentTag}.
@param wrapData: A 1-argument callable that wraps around the
attribute's value so other tests can customize it.
@param wrapData: callable taking L{bytes} and returning something
flattenable
@param wrapTag: A 1-argument callable that wraps around the outer tag
so other tests can customize it.
@type wrapTag: callable taking L{Tag} and returning L{Tag}.
"""
self.assertFlattensImmediately(
wrapTag(tags.img(src=wrapData("<>&\""))),
b'<img src="&lt;&gt;&amp;&quot;" />')
def test_serializedAttributeWithSanitization(self):
"""
Attribute values containing C{"<"}, C{">"}, C{"&"}, or C{'"'} have
C{"&lt;"}, C{"&gt;"}, C{"&amp;"}, or C{"&quot;"} substituted for those
bytes in the serialized output.
"""
self.checkAttributeSanitization(passthru, passthru)
def test_serializedDeferredAttributeWithSanitization(self):
"""
Like L{test_serializedAttributeWithSanitization}, but when the contents
of the attribute are in a L{Deferred
<twisted.internet.defer.Deferred>}.
"""
self.checkAttributeSanitization(succeed, passthru)
def test_serializedAttributeWithSlotWithSanitization(self):
"""
Like L{test_serializedAttributeWithSanitization} but with a slot.
"""
toss = []
self.checkAttributeSanitization(
lambda value: toss.append(value) or slot("stuff"),
lambda tag: tag.fillSlots(stuff=toss.pop())
)
def test_serializedAttributeWithTransparentTag(self):
"""
Attribute values which are supplied via the value of a C{t:transparent}
tag have the same substitution rules to them as values supplied
directly.
"""
self.checkAttributeSanitization(tags.transparent, passthru)
def test_serializedAttributeWithTransparentTagWithRenderer(self):
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is rendered by a renderer on an element.
"""
class WithRenderer(Element):
def __init__(self, value, loader):
self.value = value
super(WithRenderer, self).__init__(loader)
@renderer
def stuff(self, request, tag):
return self.value
toss = []
self.checkAttributeSanitization(
lambda value: toss.append(value) or
tags.transparent(render="stuff"),
lambda tag: WithRenderer(toss.pop(), TagLoader(tag))
)
def test_serializedAttributeWithRenderable(self):
"""
Like L{test_serializedAttributeWithTransparentTag}, but when the
attribute is a provider of L{IRenderable} rather than a transparent
tag.
"""
@implementer(IRenderable)
class Arbitrary(object):
def __init__(self, value):
self.value = value
def render(self, request):
return self.value
self.checkAttributeSanitization(Arbitrary, passthru)
def checkTagAttributeSerialization(self, wrapTag):
"""
Common implementation of L{test_serializedAttributeWithTag} and
L{test_serializedAttributeWithDeferredTag}.
@param wrapTag: A 1-argument callable that wraps around the attribute's
value so other tests can customize it.
@param wrapTag: callable taking L{Tag} and returning something
flattenable
"""
innerTag = tags.a('<>&"')
outerTag = tags.img(src=wrapTag(innerTag))
outer = self.assertFlattensImmediately(
outerTag,
b'<img src="&lt;a&gt;&amp;lt;&amp;gt;&amp;amp;&quot;&lt;/a&gt;" />')
inner = self.assertFlattensImmediately(
innerTag, b'<a>&lt;&gt;&amp;"</a>')
# Since the above quoting is somewhat tricky, validate it by making sure
# that the main use-case for tag-within-attribute is supported here: if
# we serialize a tag, it is quoted *such that it can be parsed out again
# as a tag*.
self.assertXMLEqual(XML(outer).attrib['src'], inner)
def test_serializedAttributeWithTag(self):
"""
L{Tag} objects which are serialized within the context of an attribute
are serialized such that the text content of the attribute may be
parsed to retrieve the tag.
"""
self.checkTagAttributeSerialization(passthru)
def test_serializedAttributeWithDeferredTag(self):
"""
Like L{test_serializedAttributeWithTag}, but when the L{Tag} is in a
L{Deferred <twisted.internet.defer.Deferred>}.
"""
self.checkTagAttributeSerialization(succeed)
def test_serializedAttributeWithTagWithAttribute(self):
"""
Similar to L{test_serializedAttributeWithTag}, but for the additional
complexity where the tag which is the attribute value itself has an
attribute value which contains bytes which require substitution.
"""
flattened = self.assertFlattensImmediately(
tags.img(src=tags.a(href='<>&"')),
b'<img src="&lt;a href='
b'&quot;&amp;lt;&amp;gt;&amp;amp;&amp;quot;&quot;&gt;'
b'&lt;/a&gt;" />')
# As in checkTagAttributeSerialization, belt-and-suspenders:
self.assertXMLEqual(XML(flattened).attrib['src'],
b'<a href="&lt;&gt;&amp;&quot;"></a>')
def test_serializeComment(self):
"""
Test that comments are correctly flattened and escaped.
"""
return self.assertFlattensTo(Comment('foo bar'), b'<!--foo bar-->'),
def test_commentEscaping(self):
"""
The data in a L{Comment} is escaped and mangled in the flattened output
so that the result is a legal SGML and XML comment.
SGML comment syntax is complicated and hard to use. This rule is more
restrictive, and more compatible:
Comments start with <!-- and end with --> and never contain -- or >.
Also by XML syntax, a comment may not end with '-'.
@see: U{http://www.w3.org/TR/REC-xml/#sec-comments}
"""
def verifyComment(c):
self.assertTrue(
c.startswith(b'<!--'),
"%r does not start with the comment prefix" % (c,))
self.assertTrue(
c.endswith(b'-->'),
"%r does not end with the comment suffix" % (c,))
# If it is shorter than 7, then the prefix and suffix overlap
# illegally.
self.assertTrue(
len(c) >= 7,
"%r is too short to be a legal comment" % (c,))
content = c[4:-3]
self.assertNotIn(b'--', content)
self.assertNotIn(b'>', content)
if content:
self.assertNotEqual(content[-1], b'-')
results = []
for c in [
'',
'foo---bar',
'foo---bar-',
'foo>bar',
'foo-->bar',
'----------------',
]:
d = flattenString(None, Comment(c))
d.addCallback(verifyComment)
results.append(d)
return gatherResults(results)
def test_serializeCDATA(self):
"""
Test that CDATA is correctly flattened and escaped.
"""
return gatherResults([
self.assertFlattensTo(CDATA('foo bar'), b'<![CDATA[foo bar]]>'),
self.assertFlattensTo(
CDATA('foo ]]> bar'),
b'<![CDATA[foo ]]]]><![CDATA[> bar]]>'),
])
def test_serializeUnicode(self):
"""
Test that unicode is encoded correctly in the appropriate places, and
raises an error when it occurs in inappropriate place.
"""
snowman = u'\N{SNOWMAN}'
return gatherResults([
self.assertFlattensTo(snowman, b'\xe2\x98\x83'),
self.assertFlattensTo(tags.p(snowman), b'<p>\xe2\x98\x83</p>'),
self.assertFlattensTo(Comment(snowman), b'<!--\xe2\x98\x83-->'),
self.assertFlattensTo(CDATA(snowman), b'<![CDATA[\xe2\x98\x83]]>'),
self.assertFlatteningRaises(
Tag(snowman), UnicodeEncodeError),
self.assertFlatteningRaises(
Tag('p', attributes={snowman: ''}), UnicodeEncodeError),
])
def test_serializeCharRef(self):
"""
A character reference is flattened to a string using the I{&#NNNN;}
syntax.
"""
ref = CharRef(ord(u"\N{SNOWMAN}"))
return self.assertFlattensTo(ref, b"&#9731;")
def test_serializeDeferred(self):
"""
Test that a deferred is substituted with the current value in the
callback chain when flattened.
"""
return self.assertFlattensTo(succeed('two'), b'two')
def test_serializeSameDeferredTwice(self):
"""
Test that the same deferred can be flattened twice.
"""
d = succeed('three')
return gatherResults([
self.assertFlattensTo(d, b'three'),
self.assertFlattensTo(d, b'three'),
])
def test_serializeCoroutine(self):
"""
Test that a coroutine returning a value is substituted with the that
value when flattened.
"""
from textwrap import dedent
namespace = {}
exec(dedent(
"""
async def coro(x):
return x
"""
), namespace)
coro = namespace["coro"]
return self.assertFlattensTo(coro('four'), b'four')
if not _PY35PLUS:
test_serializeCoroutine.skip = (
"coroutines not available before Python 3.5"
)
def test_serializeCoroutineWithAwait(self):
"""
Test that a coroutine returning an awaited deferred value is
substituted with that value when flattened.
"""
from textwrap import dedent
namespace = dict(succeed=succeed)
exec(dedent(
"""
async def coro(x):
return await succeed(x)
"""
), namespace)
coro = namespace["coro"]
return self.assertFlattensTo(coro('four'), b'four')
if not _PY35PLUS:
test_serializeCoroutineWithAwait.skip = (
"coroutines not available before Python 3.5"
)
def test_serializeIRenderable(self):
"""
Test that flattening respects all of the IRenderable interface.
"""
@implementer(IRenderable)
class FakeElement(object):
def render(ign,ored):
return tags.p(
'hello, ',
tags.transparent(render='test'), ' - ',
tags.transparent(render='test'))
def lookupRenderMethod(ign, name):
self.assertEqual(name, 'test')
return lambda ign, node: node('world')
return gatherResults([
self.assertFlattensTo(FakeElement(), b'<p>hello, world - world</p>'),
])
def test_serializeSlots(self):
"""
Test that flattening a slot will use the slot value from the tag.
"""
t1 = tags.p(slot('test'))
t2 = t1.clone()
t2.fillSlots(test='hello, world')
return gatherResults([
self.assertFlatteningRaises(t1, UnfilledSlot),
self.assertFlattensTo(t2, b'<p>hello, world</p>'),
])
def test_serializeDeferredSlots(self):
"""
Test that a slot with a deferred as its value will be flattened using
the value from the deferred.
"""
t = tags.p(slot('test'))
t.fillSlots(test=succeed(tags.em('four>')))
return self.assertFlattensTo(t, b'<p><em>four&gt;</em></p>')
def test_unknownTypeRaises(self):
"""
Test that flattening an unknown type of thing raises an exception.
"""
return self.assertFlatteningRaises(None, UnsupportedType)
# Use the co_filename mechanism (instead of the __file__ mechanism) because
# it is the mechanism traceback formatting uses. The two do not necessarily
# agree with each other. This requires a code object compiled in this file.
# The easiest way to get a code object is with a new function. I'll use a
# lambda to avoid adding anything else to this namespace. The result will
# be a string which agrees with the one the traceback module will put into a
# traceback for frames associated with functions defined in this file.
HERE = (lambda: None).__code__.co_filename
class FlattenerErrorTests(TestCase):
"""
Tests for L{FlattenerError}.
"""
def test_renderable(self):
"""
If a L{FlattenerError} is created with an L{IRenderable} provider root,
the repr of that object is included in the string representation of the
exception.
"""
@implementer(IRenderable)
class Renderable(object):
def __repr__(self):
return "renderable repr"
self.assertEqual(
str(FlattenerError(
RuntimeError("reason"), [Renderable()], [])),
"Exception while flattening:\n"
" renderable repr\n"
"RuntimeError: reason\n")
def test_tag(self):
"""
If a L{FlattenerError} is created with a L{Tag} instance with source
location information, the source location is included in the string
representation of the exception.
"""
tag = Tag(
'div', filename='/foo/filename.xhtml', lineNumber=17, columnNumber=12)
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [tag], [])),
"Exception while flattening:\n"
" File \"/foo/filename.xhtml\", line 17, column 12, in \"div\"\n"
"RuntimeError: reason\n")
def test_tagWithoutLocation(self):
"""
If a L{FlattenerError} is created with a L{Tag} instance without source
location information, only the tagName is included in the string
representation of the exception.
"""
self.assertEqual(
str(FlattenerError(RuntimeError("reason"), [Tag('span')], [])),
"Exception while flattening:\n"
" Tag <span>\n"
"RuntimeError: reason\n")
def test_traceback(self):
"""
If a L{FlattenerError} is created with traceback frames, they are
included in the string representation of the exception.
"""
# Try to be realistic in creating the data passed in for the traceback
# frames.
def f():
g()
def g():
raise RuntimeError("reason")
try:
f()
except RuntimeError as e:
# Get the traceback, minus the info for *this* frame
tbinfo = traceback.extract_tb(sys.exc_info()[2])[1:]
exc = e
else:
self.fail("f() must raise RuntimeError")
self.assertEqual(
str(FlattenerError(exc, [], tbinfo)),
"Exception while flattening:\n"
" File \"%s\", line %d, in f\n"
" g()\n"
" File \"%s\", line %d, in g\n"
" raise RuntimeError(\"reason\")\n"
"RuntimeError: reason\n" % (
HERE, f.__code__.co_firstlineno + 1,
HERE, g.__code__.co_firstlineno + 1))

View file

@ -0,0 +1,43 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.web import html
class WebHtmlTests(unittest.TestCase):
"""
Unit tests for L{twisted.web.html}.
"""
def test_deprecation(self):
"""
Calls to L{twisted.web.html} members emit a deprecation warning.
"""
def assertDeprecationWarningOf(method):
"""
Check that a deprecation warning is present.
"""
warningsShown = self.flushWarnings([self.test_deprecation])
self.assertEqual(len(warningsShown), 1)
self.assertIdentical(
warningsShown[0]['category'], DeprecationWarning)
self.assertEqual(
warningsShown[0]['message'],
'twisted.web.html.%s was deprecated in Twisted 15.3.0; '
'please use twisted.web.template instead' % (
method,),
)
html.PRE('')
assertDeprecationWarningOf('PRE')
html.UL([])
assertDeprecationWarningOf('UL')
html.linkList([])
assertDeprecationWarningOf('linkList')
html.output(lambda: None)
assertDeprecationWarningOf('output')

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,660 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.http_headers}.
"""
from __future__ import division, absolute_import
from twisted.trial.unittest import TestCase
from twisted.python.compat import _PY3, unicode
from twisted.web.http_headers import Headers
from twisted.web.test.requesthelper import (
bytesLinearWhitespaceComponents,
sanitizedBytes,
textLinearWhitespaceComponents,
)
def assertSanitized(testCase, components, expected):
"""
Assert that the components are sanitized to the expected value as
both a header name and value, across all of L{Header}'s setters
and getters.
@param testCase: A test case.
@param components: A sequence of values that contain linear
whitespace to use as header names and values; see
C{textLinearWhitespaceComponents} and
C{bytesLinearWhitespaceComponents}
@param expected: The expected sanitized form of the component for
both headers names and their values.
"""
for component in components:
headers = []
headers.append(Headers({component: [component]}))
added = Headers()
added.addRawHeader(component, component)
headers.append(added)
setHeader = Headers()
setHeader.setRawHeaders(component, [component])
headers.append(setHeader)
for header in headers:
testCase.assertEqual(list(header.getAllRawHeaders()),
[(expected, [expected])])
testCase.assertEqual(header.getRawHeaders(expected), [expected])
class BytesHeadersTests(TestCase):
"""
Tests for L{Headers}, using L{bytes} arguments for methods.
"""
def test_sanitizeLinearWhitespace(self):
"""
Linear whitespace in header names or values is replaced with a
single space.
"""
assertSanitized(self, bytesLinearWhitespaceComponents, sanitizedBytes)
def test_initializer(self):
"""
The header values passed to L{Headers.__init__} can be retrieved via
L{Headers.getRawHeaders}.
"""
h = Headers({b'Foo': [b'bar']})
self.assertEqual(h.getRawHeaders(b'foo'), [b'bar'])
def test_setRawHeaders(self):
"""
L{Headers.setRawHeaders} sets the header values for the given
header name to the sequence of byte string values.
"""
rawValue = [b"value1", b"value2"]
h = Headers()
h.setRawHeaders(b"test", rawValue)
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
self.assertEqual(h.getRawHeaders(b"test"), rawValue)
def test_rawHeadersTypeChecking(self):
"""
L{Headers.setRawHeaders} requires values to be of type list.
"""
h = Headers()
self.assertRaises(TypeError, h.setRawHeaders, b'key', {b'Foo': b'bar'})
def test_addRawHeader(self):
"""
L{Headers.addRawHeader} adds a new value for a given header.
"""
h = Headers()
h.addRawHeader(b"test", b"lemur")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
h.addRawHeader(b"test", b"panda")
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur", b"panda"])
def test_getRawHeadersNoDefault(self):
"""
L{Headers.getRawHeaders} returns L{None} if the header is not found and
no default is specified.
"""
self.assertIsNone(Headers().getRawHeaders(b"test"))
def test_getRawHeadersDefaultValue(self):
"""
L{Headers.getRawHeaders} returns the specified default value when no
header is found.
"""
h = Headers()
default = object()
self.assertIdentical(h.getRawHeaders(b"test", default), default)
def test_getRawHeadersWithDefaultMatchingValue(self):
"""
If the object passed as the value list to L{Headers.setRawHeaders}
is later passed as a default to L{Headers.getRawHeaders}, the
result nevertheless contains encoded values.
"""
h = Headers()
default = [u"value"]
h.setRawHeaders(b"key", default)
self.assertIsInstance(h.getRawHeaders(b"key", default)[0], bytes)
self.assertEqual(h.getRawHeaders(b"key", default), [b"value"])
def test_getRawHeaders(self):
"""
L{Headers.getRawHeaders} returns the values which have been set for a
given header.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur"])
self.assertEqual(h.getRawHeaders(b"Test"), [b"lemur"])
def test_hasHeaderTrue(self):
"""
Check that L{Headers.hasHeader} returns C{True} when the given header
is found.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemur"])
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
def test_hasHeaderFalse(self):
"""
L{Headers.hasHeader} returns C{False} when the given header is not
found.
"""
self.assertFalse(Headers().hasHeader(b"test"))
def test_removeHeader(self):
"""
Check that L{Headers.removeHeader} removes the given header.
"""
h = Headers()
h.setRawHeaders(b"foo", [b"lemur"])
self.assertTrue(h.hasHeader(b"foo"))
h.removeHeader(b"foo")
self.assertFalse(h.hasHeader(b"foo"))
h.setRawHeaders(b"bar", [b"panda"])
self.assertTrue(h.hasHeader(b"bar"))
h.removeHeader(b"Bar")
self.assertFalse(h.hasHeader(b"bar"))
def test_removeHeaderDoesntExist(self):
"""
L{Headers.removeHeader} is a no-operation when the specified header is
not found.
"""
h = Headers()
h.removeHeader(b"test")
self.assertEqual(list(h.getAllRawHeaders()), [])
def test_canonicalNameCaps(self):
"""
L{Headers._canonicalNameCaps} returns the canonical capitalization for
the given header.
"""
h = Headers()
self.assertEqual(h._canonicalNameCaps(b"test"), b"Test")
self.assertEqual(h._canonicalNameCaps(b"test-stuff"), b"Test-Stuff")
self.assertEqual(h._canonicalNameCaps(b"content-md5"), b"Content-MD5")
self.assertEqual(h._canonicalNameCaps(b"dnt"), b"DNT")
self.assertEqual(h._canonicalNameCaps(b"etag"), b"ETag")
self.assertEqual(h._canonicalNameCaps(b"p3p"), b"P3P")
self.assertEqual(h._canonicalNameCaps(b"te"), b"TE")
self.assertEqual(h._canonicalNameCaps(b"www-authenticate"),
b"WWW-Authenticate")
self.assertEqual(h._canonicalNameCaps(b"x-xss-protection"),
b"X-XSS-Protection")
def test_getAllRawHeaders(self):
"""
L{Headers.getAllRawHeaders} returns an iterable of (k, v) pairs, where
C{k} is the canonicalized representation of the header name, and C{v}
is a sequence of values.
"""
h = Headers()
h.setRawHeaders(b"test", [b"lemurs"])
h.setRawHeaders(b"www-authenticate", [b"basic aksljdlk="])
allHeaders = set([(k, tuple(v)) for k, v in h.getAllRawHeaders()])
self.assertEqual(allHeaders,
set([(b"WWW-Authenticate", (b"basic aksljdlk=",)),
(b"Test", (b"lemurs",))]))
def test_headersComparison(self):
"""
A L{Headers} instance compares equal to itself and to another
L{Headers} instance with the same values.
"""
first = Headers()
first.setRawHeaders(b"foo", [b"panda"])
second = Headers()
second.setRawHeaders(b"foo", [b"panda"])
third = Headers()
third.setRawHeaders(b"foo", [b"lemur", b"panda"])
self.assertEqual(first, first)
self.assertEqual(first, second)
self.assertNotEqual(first, third)
def test_otherComparison(self):
"""
An instance of L{Headers} does not compare equal to other unrelated
objects.
"""
h = Headers()
self.assertNotEqual(h, ())
self.assertNotEqual(h, object())
self.assertNotEqual(h, b"foo")
def test_repr(self):
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({%r: [%r, %r]})" % (foo, bar, baz))
def test_reprWithRawBytes(self):
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains, not attempting to decode any raw bytes.
"""
# There's no such thing as undecodable latin-1, you'll just get
# some mojibake
foo = b"foo"
# But this is invalid UTF-8! So, any accidental decoding/encoding will
# throw an exception.
bar = b"bar\xe1"
baz = b"baz\xe1"
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({%r: [%r, %r]})" % (foo, bar, baz))
def test_subclassRepr(self):
"""
The L{repr} of an instance of a subclass of L{Headers} uses the name
of the subclass instead of the string C{"Headers"}.
"""
foo = b"foo"
bar = b"bar"
baz = b"baz"
class FunnyHeaders(Headers):
pass
self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
"FunnyHeaders({%r: [%r, %r]})" % (foo, bar, baz))
def test_copy(self):
"""
L{Headers.copy} creates a new independent copy of an existing
L{Headers} instance, allowing future modifications without impacts
between the copies.
"""
h = Headers()
h.setRawHeaders(b'test', [b'foo'])
i = h.copy()
self.assertEqual(i.getRawHeaders(b'test'), [b'foo'])
h.addRawHeader(b'test', b'bar')
self.assertEqual(i.getRawHeaders(b'test'), [b'foo'])
i.addRawHeader(b'test', b'baz')
self.assertEqual(h.getRawHeaders(b'test'), [b'foo', b'bar'])
class UnicodeHeadersTests(TestCase):
"""
Tests for L{Headers}, using L{unicode} arguments for methods.
"""
def test_sanitizeLinearWhitespace(self):
"""
Linear whitespace in header names or values is replaced with a
single space.
"""
assertSanitized(self, textLinearWhitespaceComponents, sanitizedBytes)
def test_initializer(self):
"""
The header values passed to L{Headers.__init__} can be retrieved via
L{Headers.getRawHeaders}. If a L{bytes} argument is given, it returns
L{bytes} values, and if a L{unicode} argument is given, it returns
L{unicode} values. Both are the same header value, just encoded or
decoded.
"""
h = Headers({u'Foo': [u'bar']})
self.assertEqual(h.getRawHeaders(b'foo'), [b'bar'])
self.assertEqual(h.getRawHeaders(u'foo'), [u'bar'])
def test_setRawHeaders(self):
"""
L{Headers.setRawHeaders} sets the header values for the given
header name to the sequence of strings, encoded.
"""
rawValue = [u"value1", u"value2"]
rawEncodedValue = [b"value1", b"value2"]
h = Headers()
h.setRawHeaders("test", rawValue)
self.assertTrue(h.hasHeader(b"test"))
self.assertTrue(h.hasHeader(b"Test"))
self.assertTrue(h.hasHeader("test"))
self.assertTrue(h.hasHeader("Test"))
self.assertEqual(h.getRawHeaders("test"), rawValue)
self.assertEqual(h.getRawHeaders(b"test"), rawEncodedValue)
def test_nameNotEncodable(self):
"""
Passing L{unicode} to any function that takes a header name will encode
said header name as ISO-8859-1, and if it cannot be encoded, it will
raise a L{UnicodeDecodeError}.
"""
h = Headers()
# Only these two functions take names
with self.assertRaises(UnicodeEncodeError):
h.setRawHeaders(u"\u2603", [u"val"])
with self.assertRaises(UnicodeEncodeError):
h.hasHeader(u"\u2603")
def test_nameEncoding(self):
"""
Passing L{unicode} to any function that takes a header name will encode
said header name as ISO-8859-1.
"""
h = Headers()
# We set it using a Unicode string.
h.setRawHeaders(u"\u00E1", [b"foo"])
# It's encoded to the ISO-8859-1 value, which we can use to access it
self.assertTrue(h.hasHeader(b"\xe1"))
self.assertEqual(h.getRawHeaders(b"\xe1"), [b'foo'])
# We can still access it using the Unicode string..
self.assertTrue(h.hasHeader(u"\u00E1"))
def test_rawHeadersValueEncoding(self):
"""
Passing L{unicode} to L{Headers.setRawHeaders} will encode the name as
ISO-8859-1 and values as UTF-8.
"""
h = Headers()
h.setRawHeaders(u"\u00E1", [u"\u2603", b"foo"])
self.assertTrue(h.hasHeader(b"\xe1"))
self.assertEqual(h.getRawHeaders(b"\xe1"), [b'\xe2\x98\x83', b'foo'])
def test_rawHeadersTypeChecking(self):
"""
L{Headers.setRawHeaders} requires values to be of type list.
"""
h = Headers()
self.assertRaises(TypeError, h.setRawHeaders, u'key', {u'Foo': u'bar'})
def test_addRawHeader(self):
"""
L{Headers.addRawHeader} adds a new value for a given header.
"""
h = Headers()
h.addRawHeader(u"test", u"lemur")
self.assertEqual(h.getRawHeaders(u"test"), [u"lemur"])
h.addRawHeader(u"test", u"panda")
self.assertEqual(h.getRawHeaders(u"test"), [u"lemur", u"panda"])
self.assertEqual(h.getRawHeaders(b"test"), [b"lemur", b"panda"])
def test_getRawHeadersNoDefault(self):
"""
L{Headers.getRawHeaders} returns L{None} if the header is not found and
no default is specified.
"""
self.assertIsNone(Headers().getRawHeaders(u"test"))
def test_getRawHeadersDefaultValue(self):
"""
L{Headers.getRawHeaders} returns the specified default value when no
header is found.
"""
h = Headers()
default = object()
self.assertIdentical(h.getRawHeaders(u"test", default), default)
self.assertIdentical(h.getRawHeaders(u"test", None), None)
self.assertEqual(h.getRawHeaders(u"test", [None]), [None])
self.assertEqual(
h.getRawHeaders(u"test", [u"\N{SNOWMAN}"]),
[u"\N{SNOWMAN}"],
)
def test_getRawHeadersWithDefaultMatchingValue(self):
"""
If the object passed as the value list to L{Headers.setRawHeaders}
is later passed as a default to L{Headers.getRawHeaders}, the
result nevertheless contains decoded values.
"""
h = Headers()
default = [b"value"]
h.setRawHeaders(b"key", default)
self.assertIsInstance(h.getRawHeaders(u"key", default)[0], unicode)
self.assertEqual(h.getRawHeaders(u"key", default), [u"value"])
def test_getRawHeaders(self):
"""
L{Headers.getRawHeaders} returns the values which have been set for a
given header.
"""
h = Headers()
h.setRawHeaders(u"test\u00E1", [u"lemur"])
self.assertEqual(h.getRawHeaders(u"test\u00E1"), [u"lemur"])
self.assertEqual(h.getRawHeaders(u"Test\u00E1"), [u"lemur"])
self.assertEqual(h.getRawHeaders(b"test\xe1"), [b"lemur"])
self.assertEqual(h.getRawHeaders(b"Test\xe1"), [b"lemur"])
def test_hasHeaderTrue(self):
"""
Check that L{Headers.hasHeader} returns C{True} when the given header
is found.
"""
h = Headers()
h.setRawHeaders(u"test\u00E1", [u"lemur"])
self.assertTrue(h.hasHeader(u"test\u00E1"))
self.assertTrue(h.hasHeader(u"Test\u00E1"))
self.assertTrue(h.hasHeader(b"test\xe1"))
self.assertTrue(h.hasHeader(b"Test\xe1"))
def test_hasHeaderFalse(self):
"""
L{Headers.hasHeader} returns C{False} when the given header is not
found.
"""
self.assertFalse(Headers().hasHeader(u"test\u00E1"))
def test_removeHeader(self):
"""
Check that L{Headers.removeHeader} removes the given header.
"""
h = Headers()
h.setRawHeaders(u"foo", [u"lemur"])
self.assertTrue(h.hasHeader(u"foo"))
h.removeHeader(u"foo")
self.assertFalse(h.hasHeader(u"foo"))
self.assertFalse(h.hasHeader(b"foo"))
h.setRawHeaders(u"bar", [u"panda"])
self.assertTrue(h.hasHeader(u"bar"))
h.removeHeader(u"Bar")
self.assertFalse(h.hasHeader(u"bar"))
self.assertFalse(h.hasHeader(b"bar"))
def test_removeHeaderDoesntExist(self):
"""
L{Headers.removeHeader} is a no-operation when the specified header is
not found.
"""
h = Headers()
h.removeHeader(u"test")
self.assertEqual(list(h.getAllRawHeaders()), [])
def test_getAllRawHeaders(self):
"""
L{Headers.getAllRawHeaders} returns an iterable of (k, v) pairs, where
C{k} is the canonicalized representation of the header name, and C{v}
is a sequence of values.
"""
h = Headers()
h.setRawHeaders(u"test\u00E1", [u"lemurs"])
h.setRawHeaders(u"www-authenticate", [u"basic aksljdlk="])
h.setRawHeaders(u"content-md5", [u"kjdfdfgdfgnsd"])
allHeaders = set([(k, tuple(v)) for k, v in h.getAllRawHeaders()])
self.assertEqual(allHeaders,
set([(b"WWW-Authenticate", (b"basic aksljdlk=",)),
(b"Content-MD5", (b"kjdfdfgdfgnsd",)),
(b"Test\xe1", (b"lemurs",))]))
def test_headersComparison(self):
"""
A L{Headers} instance compares equal to itself and to another
L{Headers} instance with the same values.
"""
first = Headers()
first.setRawHeaders(u"foo\u00E1", [u"panda"])
second = Headers()
second.setRawHeaders(u"foo\u00E1", [u"panda"])
third = Headers()
third.setRawHeaders(u"foo\u00E1", [u"lemur", u"panda"])
self.assertEqual(first, first)
self.assertEqual(first, second)
self.assertNotEqual(first, third)
# Headers instantiated with bytes equivs are also the same
firstBytes = Headers()
firstBytes.setRawHeaders(b"foo\xe1", [b"panda"])
secondBytes = Headers()
secondBytes.setRawHeaders(b"foo\xe1", [b"panda"])
thirdBytes = Headers()
thirdBytes.setRawHeaders(b"foo\xe1", [b"lemur", u"panda"])
self.assertEqual(first, firstBytes)
self.assertEqual(second, secondBytes)
self.assertEqual(third, thirdBytes)
def test_otherComparison(self):
"""
An instance of L{Headers} does not compare equal to other unrelated
objects.
"""
h = Headers()
self.assertNotEqual(h, ())
self.assertNotEqual(h, object())
self.assertNotEqual(h, u"foo")
def test_repr(self):
"""
The L{repr} of a L{Headers} instance shows the names and values of all
the headers it contains. This shows only reprs of bytes values, as
undecodable headers may cause an exception.
"""
foo = u"foo\u00E1"
bar = u"bar\u2603"
baz = u"baz"
fooEncoded = "'foo\\xe1'"
barEncoded = "'bar\\xe2\\x98\\x83'"
if _PY3:
fooEncoded = "b" + fooEncoded
barEncoded = "b" + barEncoded
self.assertEqual(
repr(Headers({foo: [bar, baz]})),
"Headers({%s: [%s, %r]})" % (fooEncoded,
barEncoded,
baz.encode('utf8')))
def test_subclassRepr(self):
"""
The L{repr} of an instance of a subclass of L{Headers} uses the name
of the subclass instead of the string C{"Headers"}.
"""
foo = u"foo\u00E1"
bar = u"bar\u2603"
baz = u"baz"
fooEncoded = "'foo\\xe1'"
barEncoded = "'bar\\xe2\\x98\\x83'"
if _PY3:
fooEncoded = "b" + fooEncoded
barEncoded = "b" + barEncoded
class FunnyHeaders(Headers):
pass
self.assertEqual(
repr(FunnyHeaders({foo: [bar, baz]})),
"FunnyHeaders({%s: [%s, %r]})" % (fooEncoded,
barEncoded,
baz.encode('utf8')))
def test_copy(self):
"""
L{Headers.copy} creates a new independent copy of an existing
L{Headers} instance, allowing future modifications without impacts
between the copies.
"""
h = Headers()
h.setRawHeaders(u'test\u00E1', [u'foo\u2603'])
i = h.copy()
# The copy contains the same value as the original
self.assertEqual(i.getRawHeaders(u'test\u00E1'), [u'foo\u2603'])
self.assertEqual(i.getRawHeaders(b'test\xe1'), [b'foo\xe2\x98\x83'])
# Add a header to the original
h.addRawHeader(u'test\u00E1', u'bar')
# Verify that the copy has not changed
self.assertEqual(i.getRawHeaders(u'test\u00E1'), [u'foo\u2603'])
self.assertEqual(i.getRawHeaders(b'test\xe1'), [b'foo\xe2\x98\x83'])
# Add a header to the copy
i.addRawHeader(u'test\u00E1', b'baz')
# Verify that the orignal does not have it
self.assertEqual(
h.getRawHeaders(u'test\u00E1'), [u'foo\u2603', u'bar'])
self.assertEqual(
h.getRawHeaders(b'test\xe1'), [b'foo\xe2\x98\x83', b'bar'])

View file

@ -0,0 +1,677 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._auth}.
"""
from __future__ import division, absolute_import
import base64
from zope.interface import implementer
from zope.interface.verify import verifyObject
from twisted.trial import unittest
from twisted.python.failure import Failure
from twisted.internet.error import ConnectionDone
from twisted.internet.address import IPv4Address
from twisted.cred import error, portal
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
from twisted.cred.checkers import ANONYMOUS, AllowAnonymousAccess
from twisted.cred.credentials import IUsernamePassword
from twisted.web.iweb import ICredentialFactory
from twisted.web.resource import IResource, Resource, getChildForRequest
from twisted.web._auth import basic, digest
from twisted.web._auth.wrapper import HTTPAuthSessionWrapper, UnauthorizedResource
from twisted.web._auth.basic import BasicCredentialFactory
from twisted.web.server import NOT_DONE_YET
from twisted.web.static import Data
from twisted.web.test.test_web import DummyRequest
from twisted.test.proto_helpers import EventLoggingObserver
from twisted.logger import globalLogPublisher
def b64encode(s):
return base64.b64encode(s).strip()
class BasicAuthTestsMixin:
"""
L{TestCase} mixin class which defines a number of tests for
L{basic.BasicCredentialFactory}. Because this mixin defines C{setUp}, it
must be inherited before L{TestCase}.
"""
def setUp(self):
self.request = self.makeRequest()
self.realm = b'foo'
self.username = b'dreid'
self.password = b'S3CuR1Ty'
self.credentialFactory = basic.BasicCredentialFactory(self.realm)
def makeRequest(self, method=b'GET', clientAddress=None):
"""
Create a request object to be passed to
L{basic.BasicCredentialFactory.decode} along with a response value.
Override this in a subclass.
"""
raise NotImplementedError("%r did not implement makeRequest" % (
self.__class__,))
def test_interface(self):
"""
L{BasicCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(
verifyObject(ICredentialFactory, self.credentialFactory))
def test_usernamePassword(self):
"""
L{basic.BasicCredentialFactory.decode} turns a base64-encoded response
into a L{UsernamePassword} object with a password which reflects the
one which was encoded in the response.
"""
response = b64encode(b''.join([self.username, b':', self.password]))
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(IUsernamePassword.providedBy(creds))
self.assertTrue(creds.checkPassword(self.password))
self.assertFalse(creds.checkPassword(self.password + b'wrong'))
def test_incorrectPadding(self):
"""
L{basic.BasicCredentialFactory.decode} decodes a base64-encoded
response with incorrect padding.
"""
response = b64encode(b''.join([self.username, b':', self.password]))
response = response.strip(b'=')
creds = self.credentialFactory.decode(response, self.request)
self.assertTrue(verifyObject(IUsernamePassword, creds))
self.assertTrue(creds.checkPassword(self.password))
def test_invalidEncoding(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} if passed
a response which is not base64-encoded.
"""
response = b'x' # one byte cannot be valid base64 text
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode, response, self.makeRequest())
def test_invalidCredentials(self):
"""
L{basic.BasicCredentialFactory.decode} raises L{LoginFailed} when
passed a response which is not valid base64-encoded text.
"""
response = b64encode(b'123abc+/')
self.assertRaises(
error.LoginFailed,
self.credentialFactory.decode,
response, self.makeRequest())
class RequestMixin:
def makeRequest(self, method=b'GET', clientAddress=None):
"""
Create a L{DummyRequest} (change me to create a
L{twisted.web.http.Request} instead).
"""
if clientAddress is None:
clientAddress = IPv4Address("TCP", "localhost", 1234)
request = DummyRequest(b'/')
request.method = method
request.client = clientAddress
return request
class BasicAuthTests(RequestMixin, BasicAuthTestsMixin, unittest.TestCase):
"""
Basic authentication tests which use L{twisted.web.http.Request}.
"""
class DigestAuthTests(RequestMixin, unittest.TestCase):
"""
Digest authentication tests which use L{twisted.web.http.Request}.
"""
def setUp(self):
"""
Create a DigestCredentialFactory for testing
"""
self.realm = b"test realm"
self.algorithm = b"md5"
self.credentialFactory = digest.DigestCredentialFactory(
self.algorithm, self.realm)
self.request = self.makeRequest()
def test_decode(self):
"""
L{digest.DigestCredentialFactory.decode} calls the C{decode} method on
L{twisted.cred.digest.DigestCredentialFactory} with the HTTP method and
host of the request.
"""
host = b'169.254.0.1'
method = b'GET'
done = [False]
response = object()
def check(_response, _method, _host):
self.assertEqual(response, _response)
self.assertEqual(method, _method)
self.assertEqual(host, _host)
done[0] = True
self.patch(self.credentialFactory.digest, 'decode', check)
req = self.makeRequest(method, IPv4Address('TCP', host, 81))
self.credentialFactory.decode(response, req)
self.assertTrue(done[0])
def test_interface(self):
"""
L{DigestCredentialFactory} implements L{ICredentialFactory}.
"""
self.assertTrue(
verifyObject(ICredentialFactory, self.credentialFactory))
def test_getChallenge(self):
"""
The challenge issued by L{DigestCredentialFactory.getChallenge} must
include C{'qop'}, C{'realm'}, C{'algorithm'}, C{'nonce'}, and
C{'opaque'} keys. The values for the C{'realm'} and C{'algorithm'}
keys must match the values supplied to the factory's initializer.
None of the values may have newlines in them.
"""
challenge = self.credentialFactory.getChallenge(self.request)
self.assertEqual(challenge['qop'], b'auth')
self.assertEqual(challenge['realm'], b'test realm')
self.assertEqual(challenge['algorithm'], b'md5')
self.assertIn('nonce', challenge)
self.assertIn('opaque', challenge)
for v in challenge.values():
self.assertNotIn(b'\n', v)
def test_getChallengeWithoutClientIP(self):
"""
L{DigestCredentialFactory.getChallenge} can issue a challenge even if
the L{Request} it is passed returns L{None} from C{getClientIP}.
"""
request = self.makeRequest(b'GET', None)
challenge = self.credentialFactory.getChallenge(request)
self.assertEqual(challenge['qop'], b'auth')
self.assertEqual(challenge['realm'], b'test realm')
self.assertEqual(challenge['algorithm'], b'md5')
self.assertIn('nonce', challenge)
self.assertIn('opaque', challenge)
class UnauthorizedResourceTests(RequestMixin, unittest.TestCase):
"""
Tests for L{UnauthorizedResource}.
"""
def test_getChildWithDefault(self):
"""
An L{UnauthorizedResource} is every child of itself.
"""
resource = UnauthorizedResource([])
self.assertIdentical(
resource.getChildWithDefault("foo", None), resource)
self.assertIdentical(
resource.getChildWithDefault("bar", None), resource)
def _unauthorizedRenderTest(self, request):
"""
Render L{UnauthorizedResource} for the given request object and verify
that the response code is I{Unauthorized} and that a I{WWW-Authenticate}
header is set in the response containing a challenge.
"""
resource = UnauthorizedResource([
BasicCredentialFactory('example.com')])
request.render(resource)
self.assertEqual(request.responseCode, 401)
self.assertEqual(
request.responseHeaders.getRawHeaders(b'www-authenticate'),
[b'basic realm="example.com"'])
def test_render(self):
"""
L{UnauthorizedResource} renders with a 401 response code and a
I{WWW-Authenticate} header and puts a simple unauthorized message
into the response body.
"""
request = self.makeRequest()
self._unauthorizedRenderTest(request)
self.assertEqual(b'Unauthorized', b''.join(request.written))
def test_renderHEAD(self):
"""
The rendering behavior of L{UnauthorizedResource} for a I{HEAD} request
is like its handling of a I{GET} request, but no response body is
written.
"""
request = self.makeRequest(method=b'HEAD')
self._unauthorizedRenderTest(request)
self.assertEqual(b'', b''.join(request.written))
def test_renderQuotesRealm(self):
"""
The realm value included in the I{WWW-Authenticate} header set in
the response when L{UnauthorizedResounrce} is rendered has quotes
and backslashes escaped.
"""
resource = UnauthorizedResource([
BasicCredentialFactory('example\\"foo')])
request = self.makeRequest()
request.render(resource)
self.assertEqual(
request.responseHeaders.getRawHeaders(b'www-authenticate'),
[b'basic realm="example\\\\\\"foo"'])
def test_renderQuotesDigest(self):
"""
The digest value included in the I{WWW-Authenticate} header
set in the response when L{UnauthorizedResource} is rendered
has quotes and backslashes escaped.
"""
resource = UnauthorizedResource([
digest.DigestCredentialFactory(b'md5', b'example\\"foo')])
request = self.makeRequest()
request.render(resource)
authHeader = request.responseHeaders.getRawHeaders(
b'www-authenticate'
)[0]
self.assertIn(b'realm="example\\\\\\"foo"', authHeader)
self.assertIn(b'hm="md5', authHeader)
implementer(portal.IRealm)
class Realm(object):
"""
A simple L{IRealm} implementation which gives out L{WebAvatar} for any
avatarId.
@type loggedIn: C{int}
@ivar loggedIn: The number of times C{requestAvatar} has been invoked for
L{IResource}.
@type loggedOut: C{int}
@ivar loggedOut: The number of times the logout callback has been invoked.
"""
def __init__(self, avatarFactory):
self.loggedOut = 0
self.loggedIn = 0
self.avatarFactory = avatarFactory
def requestAvatar(self, avatarId, mind, *interfaces):
if IResource in interfaces:
self.loggedIn += 1
return IResource, self.avatarFactory(avatarId), self.logout
raise NotImplementedError()
def logout(self):
self.loggedOut += 1
class HTTPAuthHeaderTests(unittest.TestCase):
"""
Tests for L{HTTPAuthSessionWrapper}.
"""
makeRequest = DummyRequest
def setUp(self):
"""
Create a realm, portal, and L{HTTPAuthSessionWrapper} to use in the tests.
"""
self.username = b'foo bar'
self.password = b'bar baz'
self.avatarContent = b"contents of the avatar resource itself"
self.childName = b"foo-child"
self.childContent = b"contents of the foo child of the avatar"
self.checker = InMemoryUsernamePasswordDatabaseDontUse()
self.checker.addUser(self.username, self.password)
self.avatar = Data(self.avatarContent, 'text/plain')
self.avatar.putChild(
self.childName, Data(self.childContent, 'text/plain'))
self.avatars = {self.username: self.avatar}
self.realm = Realm(self.avatars.get)
self.portal = portal.Portal(self.realm, [self.checker])
self.credentialFactories = []
self.wrapper = HTTPAuthSessionWrapper(
self.portal, self.credentialFactories)
def _authorizedBasicLogin(self, request):
"""
Add an I{basic authorization} header to the given request and then
dispatch it, starting from C{self.wrapper} and returning the resulting
L{IResource}.
"""
authorization = b64encode(self.username + b':' + self.password)
request.requestHeaders.addRawHeader(b'authorization',
b'Basic ' + authorization)
return getChildForRequest(self.wrapper, request)
def test_getChildWithDefault(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} instance when the request does
not have the required I{Authorization} headers.
"""
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def _invalidAuthorizationTest(self, response):
"""
Create a request with the given value as the value of an
I{Authorization} header and perform resource traversal with it,
starting at C{self.wrapper}. Assert that the result is a 401 response
code. Return a L{Deferred} which fires when this is all done.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b'authorization', response)
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(result):
self.assertEqual(request.responseCode, 401)
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChildWithDefaultUnauthorizedUser(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which does not exist.
"""
return self._invalidAuthorizationTest(
b'Basic ' + b64encode(b'foo:bar'))
def test_getChildWithDefaultUnauthorizedPassword(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with a user which exists and the wrong
password.
"""
return self._invalidAuthorizationTest(
b'Basic ' + b64encode(self.username + b':bar'))
def test_getChildWithDefaultUnrecognizedScheme(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has an
I{Authorization} header with an unrecognized scheme.
"""
return self._invalidAuthorizationTest(b'Quux foo bar baz')
def test_getChildWithDefaultAuthorized(self):
"""
Resource traversal which encounters an L{HTTPAuthSessionWrapper}
results in an L{IResource} which renders the L{IResource} avatar
retrieved from the portal when the request has a valid I{Authorization}
header.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.childContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_renderAuthorized(self):
"""
Resource traversal which terminates at an L{HTTPAuthSessionWrapper}
and includes correct authentication headers results in the
L{IResource} avatar (not one of its children) retrieved from the
portal being rendered.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
# Request it exactly, not any of its children.
request = self.makeRequest([])
child = self._authorizedBasicLogin(request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [self.avatarContent])
d.addCallback(cbFinished)
request.render(child)
return d
def test_getChallengeCalledWithRequest(self):
"""
When L{HTTPAuthSessionWrapper} finds an L{ICredentialFactory} to issue
a challenge, it calls the C{getChallenge} method with the request as an
argument.
"""
@implementer(ICredentialFactory)
class DumbCredentialFactory(object):
scheme = b'dumb'
def __init__(self):
self.requests = []
def getChallenge(self, request):
self.requests.append(request)
return {}
factory = DumbCredentialFactory()
self.credentialFactories.append(factory)
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(factory.requests, [request])
d.addCallback(cbFinished)
request.render(child)
return d
def _logoutTest(self):
"""
Issue a request for an authentication-protected resource using valid
credentials and then return the C{DummyRequest} instance which was
used.
This is a helper for tests about the behavior of the logout
callback.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
class SlowerResource(Resource):
def render(self, request):
return NOT_DONE_YET
self.avatar.putChild(self.childName, SlowerResource())
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(self.realm.loggedOut, 0)
return request
def test_logout(self):
"""
The realm's logout callback is invoked after the resource is rendered.
"""
request = self._logoutTest()
request.finish()
self.assertEqual(self.realm.loggedOut, 1)
def test_logoutOnError(self):
"""
The realm's logout callback is also invoked if there is an error
generating the response (for example, if the client disconnects
early).
"""
request = self._logoutTest()
request.processingFailed(
Failure(ConnectionDone("Simulated disconnect")))
self.assertEqual(self.realm.loggedOut, 1)
def test_decodeRaises(self):
"""
Resource traversal which enouncters an L{HTTPAuthSessionWrapper}
results in an L{UnauthorizedResource} when the request has a I{Basic
Authorization} header which cannot be decoded using base64.
"""
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b'authorization', b'Basic decode should fail')
child = getChildForRequest(self.wrapper, request)
self.assertIsInstance(child, UnauthorizedResource)
def test_selectParseResponse(self):
"""
L{HTTPAuthSessionWrapper._selectParseHeader} returns a two-tuple giving
the L{ICredentialFactory} to use to parse the header and a string
containing the portion of the header which remains to be parsed.
"""
basicAuthorization = b'Basic abcdef123456'
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization),
(None, None))
factory = BasicCredentialFactory('example.com')
self.credentialFactories.append(factory)
self.assertEqual(
self.wrapper._selectParseHeader(basicAuthorization),
(factory, b'abcdef123456'))
def test_unexpectedDecodeError(self):
"""
Any unexpected exception raised by the credential factory's C{decode}
method results in a 500 response code and causes the exception to be
logged.
"""
logObserver = EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
class UnexpectedException(Exception):
pass
class BadFactory(object):
scheme = b'bad'
def getChallenge(self, client):
return {}
def decode(self, response, request):
raise UnexpectedException()
self.credentialFactories.append(BadFactory())
request = self.makeRequest([self.childName])
request.requestHeaders.addRawHeader(b'authorization', b'Bad abc')
child = getChildForRequest(self.wrapper, request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEquals(1, len(logObserver))
self.assertIsInstance(
logObserver[0]["log_failure"].value,
UnexpectedException
)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_unexpectedLoginError(self):
"""
Any unexpected failure from L{Portal.login} results in a 500 response
code and causes the failure to be logged.
"""
logObserver = EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
class UnexpectedException(Exception):
pass
class BrokenChecker(object):
credentialInterfaces = (IUsernamePassword,)
def requestAvatarId(self, credentials):
raise UnexpectedException()
self.portal.registerChecker(BrokenChecker())
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = self._authorizedBasicLogin(request)
request.render(child)
self.assertEqual(request.responseCode, 500)
self.assertEquals(1, len(logObserver))
self.assertIsInstance(
logObserver[0]["log_failure"].value,
UnexpectedException
)
self.assertEqual(len(self.flushLoggedErrors(UnexpectedException)), 1)
def test_anonymousAccess(self):
"""
Anonymous requests are allowed if a L{Portal} has an anonymous checker
registered.
"""
unprotectedContents = b"contents of the unprotected child resource"
self.avatars[ANONYMOUS] = Resource()
self.avatars[ANONYMOUS].putChild(
self.childName, Data(unprotectedContents, 'text/plain'))
self.portal.registerChecker(AllowAnonymousAccess())
self.credentialFactories.append(BasicCredentialFactory('example.com'))
request = self.makeRequest([self.childName])
child = getChildForRequest(self.wrapper, request)
d = request.notifyFinish()
def cbFinished(ignored):
self.assertEqual(request.written, [unprotectedContents])
d.addCallback(cbFinished)
request.render(child)
return d

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,573 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test for L{twisted.web.proxy}.
"""
from twisted.trial.unittest import TestCase
from twisted.test.proto_helpers import StringTransportWithDisconnection
from twisted.test.proto_helpers import MemoryReactor
from twisted.web.resource import Resource
from twisted.web.server import Site
from twisted.web.proxy import ReverseProxyResource, ProxyClientFactory
from twisted.web.proxy import ProxyClient, ProxyRequest, ReverseProxyRequest
from twisted.web.test.test_web import DummyRequest
class ReverseProxyResourceTests(TestCase):
"""
Tests for L{ReverseProxyResource}.
"""
def _testRender(self, uri, expectedURI):
"""
Check that a request pointing at C{uri} produce a new proxy connection,
with the path of this request pointing at C{expectedURI}.
"""
root = Resource()
reactor = MemoryReactor()
resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path", reactor)
root.putChild(b'index', resource)
site = Site(root)
transport = StringTransportWithDisconnection()
channel = site.buildProtocol(None)
channel.makeConnection(transport)
# Clear the timeout if the tests failed
self.addCleanup(channel.connectionLost, None)
channel.dataReceived(b"GET " +
uri +
b" HTTP/1.1\r\nAccept: text/html\r\n\r\n")
[(host, port, factory, _timeout, _bind_addr)] = reactor.tcpClients
# Check that one connection has been created, to the good host/port
self.assertEqual(host, u"127.0.0.1")
self.assertEqual(port, 1234)
# Check the factory passed to the connect, and its given path
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.headers[b"host"], b"127.0.0.1:1234")
def test_render(self):
"""
Test that L{ReverseProxyResource.render} initiates a connection to the
given server with a L{ProxyClientFactory} as parameter.
"""
return self._testRender(b"/index", b"/path")
def test_render_subpage(self):
"""
Test that L{ReverseProxyResource.render} will instantiate a child
resource that will initiate a connection to the given server
requesting the apropiate url subpath.
"""
return self._testRender(b"/index/page1", b"/path/page1")
def test_renderWithQuery(self):
"""
Test that L{ReverseProxyResource.render} passes query parameters to the
created factory.
"""
return self._testRender(b"/index?foo=bar", b"/path?foo=bar")
def test_getChild(self):
"""
The L{ReverseProxyResource.getChild} method should return a resource
instance with the same class as the originating resource, forward
port, host, and reactor values, and update the path value with the
value passed.
"""
reactor = MemoryReactor()
resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path", reactor)
child = resource.getChild(b'foo', None)
# The child should keep the same class
self.assertIsInstance(child, ReverseProxyResource)
self.assertEqual(child.path, b"/path/foo")
self.assertEqual(child.port, 1234)
self.assertEqual(child.host, u"127.0.0.1")
self.assertIdentical(child.reactor, resource.reactor)
def test_getChildWithSpecial(self):
"""
The L{ReverseProxyResource} return by C{getChild} has a path which has
already been quoted.
"""
resource = ReverseProxyResource(u"127.0.0.1", 1234, b"/path")
child = resource.getChild(b' /%', None)
self.assertEqual(child.path, b"/path/%20%2F%25")
class DummyChannel(object):
"""
A dummy HTTP channel, that does nothing but holds a transport and saves
connection lost.
@ivar transport: the transport used by the client.
@ivar lostReason: the reason saved at connection lost.
"""
def __init__(self, transport):
"""
Hold a reference to the transport.
"""
self.transport = transport
self.lostReason = None
def connectionLost(self, reason):
"""
Keep track of the connection lost reason.
"""
self.lostReason = reason
def getPeer(self):
"""
Get peer information from the transport.
"""
return self.transport.getPeer()
def getHost(self):
"""
Get host information from the transport.
"""
return self.transport.getHost()
class ProxyClientTests(TestCase):
"""
Tests for L{ProxyClient}.
"""
def _parseOutHeaders(self, content):
"""
Parse the headers out of some web content.
@param content: Bytes received from a web server.
@return: A tuple of (requestLine, headers, body). C{headers} is a dict
of headers, C{requestLine} is the first line (e.g. "POST /foo ...")
and C{body} is whatever is left.
"""
headers, body = content.split(b'\r\n\r\n')
headers = headers.split(b'\r\n')
requestLine = headers.pop(0)
return (
requestLine, dict(header.split(b': ') for header in headers), body)
def makeRequest(self, path):
"""
Make a dummy request object for the URL path.
@param path: A URL path, beginning with a slash.
@return: A L{DummyRequest}.
"""
return DummyRequest(path)
def makeProxyClient(self, request, method=b"GET", headers=None,
requestBody=b""):
"""
Make a L{ProxyClient} object used for testing.
@param request: The request to use.
@param method: The HTTP method to use, GET by default.
@param headers: The HTTP headers to use expressed as a dict. If not
provided, defaults to {'accept': 'text/html'}.
@param requestBody: The body of the request. Defaults to the empty
string.
@return: A L{ProxyClient}
"""
if headers is None:
headers = {b"accept": b"text/html"}
path = b'/' + request.postpath
return ProxyClient(
method, path, b'HTTP/1.0', headers, requestBody, request)
def connectProxy(self, proxyClient):
"""
Connect a proxy client to a L{StringTransportWithDisconnection}.
@param proxyClient: A L{ProxyClient}.
@return: The L{StringTransportWithDisconnection}.
"""
clientTransport = StringTransportWithDisconnection()
clientTransport.protocol = proxyClient
proxyClient.makeConnection(clientTransport)
return clientTransport
def assertForwardsHeaders(self, proxyClient, requestLine, headers):
"""
Assert that C{proxyClient} sends C{headers} when it connects.
@param proxyClient: A L{ProxyClient}.
@param requestLine: The request line we expect to be sent.
@param headers: A dict of headers we expect to be sent.
@return: If the assertion is successful, return the request body as
bytes.
"""
self.connectProxy(proxyClient)
requestContent = proxyClient.transport.value()
receivedLine, receivedHeaders, body = self._parseOutHeaders(
requestContent)
self.assertEqual(receivedLine, requestLine)
self.assertEqual(receivedHeaders, headers)
return body
def makeResponseBytes(self, code, message, headers, body):
lines = [b"HTTP/1.0 " + str(code).encode('ascii') + b' ' + message]
for header, values in headers:
for value in values:
lines.append(header + b': ' + value)
lines.extend([b'', body])
return b'\r\n'.join(lines)
def assertForwardsResponse(self, request, code, message, headers, body):
"""
Assert that C{request} has forwarded a response from the server.
@param request: A L{DummyRequest}.
@param code: The expected HTTP response code.
@param message: The expected HTTP message.
@param headers: The expected HTTP headers.
@param body: The expected response body.
"""
self.assertEqual(request.responseCode, code)
self.assertEqual(request.responseMessage, message)
receivedHeaders = list(request.responseHeaders.getAllRawHeaders())
receivedHeaders.sort()
expectedHeaders = headers[:]
expectedHeaders.sort()
self.assertEqual(receivedHeaders, expectedHeaders)
self.assertEqual(b''.join(request.written), body)
def _testDataForward(self, code, message, headers, body, method=b"GET",
requestBody=b"", loseConnection=True):
"""
Build a fake proxy connection, and send C{data} over it, checking that
it's forwarded to the originating request.
"""
request = self.makeRequest(b'foo')
client = self.makeProxyClient(
request, method, {b'accept': b'text/html'}, requestBody)
receivedBody = self.assertForwardsHeaders(
client, method + b' /foo HTTP/1.0',
{b'connection': b'close', b'accept': b'text/html'})
self.assertEqual(receivedBody, requestBody)
# Fake an answer
client.dataReceived(
self.makeResponseBytes(code, message, headers, body))
# Check that the response data has been forwarded back to the original
# requester.
self.assertForwardsResponse(request, code, message, headers, body)
# Check that when the response is done, the request is finished.
if loseConnection:
client.transport.loseConnection()
# Even if we didn't call loseConnection, the transport should be
# disconnected. This lets us not rely on the server to close our
# sockets for us.
self.assertFalse(client.transport.connected)
self.assertEqual(request.finished, 1)
def test_forward(self):
"""
When connected to the server, L{ProxyClient} should send the saved
request, with modifications of the headers, and then forward the result
to the parent request.
"""
return self._testDataForward(
200, b"OK", [(b"Foo", [b"bar", b"baz"])], b"Some data\r\n")
def test_postData(self):
"""
Try to post content in the request, and check that the proxy client
forward the body of the request.
"""
return self._testDataForward(
200, b"OK", [(b"Foo", [b"bar"])], b"Some data\r\n", b"POST", b"Some content")
def test_statusWithMessage(self):
"""
If the response contains a status with a message, it should be
forwarded to the parent request with all the information.
"""
return self._testDataForward(
404, b"Not Found", [], b"")
def test_contentLength(self):
"""
If the response contains a I{Content-Length} header, the inbound
request object should still only have C{finish} called on it once.
"""
data = b"foo bar baz"
return self._testDataForward(
200,
b"OK",
[(b"Content-Length", [str(len(data)).encode('ascii')])],
data)
def test_losesConnection(self):
"""
If the response contains a I{Content-Length} header, the outgoing
connection is closed when all response body data has been received.
"""
data = b"foo bar baz"
return self._testDataForward(
200,
b"OK",
[(b"Content-Length", [str(len(data)).encode('ascii')])],
data,
loseConnection=False)
def test_headersCleanups(self):
"""
The headers given at initialization should be modified:
B{proxy-connection} should be removed if present, and B{connection}
should be added.
"""
client = ProxyClient(b'GET', b'/foo', b'HTTP/1.0',
{b"accept": b"text/html", b"proxy-connection": b"foo"}, b'', None)
self.assertEqual(client.headers,
{b"accept": b"text/html", b"connection": b"close"})
def test_keepaliveNotForwarded(self):
"""
The proxy doesn't really know what to do with keepalive things from
the remote server, so we stomp over any keepalive header we get from
the client.
"""
headers = {
b"accept": b"text/html",
b'keep-alive': b'300',
b'connection': b'keep-alive',
}
expectedHeaders = headers.copy()
expectedHeaders[b'connection'] = b'close'
del expectedHeaders[b'keep-alive']
client = ProxyClient(b'GET', b'/foo', b'HTTP/1.0', headers, b'', None)
self.assertForwardsHeaders(
client, b'GET /foo HTTP/1.0', expectedHeaders)
def test_defaultHeadersOverridden(self):
"""
L{server.Request} within the proxy sets certain response headers by
default. When we get these headers back from the remote server, the
defaults are overridden rather than simply appended.
"""
request = self.makeRequest(b'foo')
request.responseHeaders.setRawHeaders(b'server', [b'old-bar'])
request.responseHeaders.setRawHeaders(b'date', [b'old-baz'])
request.responseHeaders.setRawHeaders(b'content-type', [b"old/qux"])
client = self.makeProxyClient(request, headers={b'accept': b'text/html'})
self.connectProxy(client)
headers = {
b'Server': [b'bar'],
b'Date': [b'2010-01-01'],
b'Content-Type': [b'application/x-baz'],
}
client.dataReceived(
self.makeResponseBytes(200, b"OK", headers.items(), b''))
self.assertForwardsResponse(
request, 200, b'OK', list(headers.items()), b'')
class ProxyClientFactoryTests(TestCase):
"""
Tests for L{ProxyClientFactory}.
"""
def test_connectionFailed(self):
"""
Check that L{ProxyClientFactory.clientConnectionFailed} produces
a B{501} response to the parent request.
"""
request = DummyRequest([b'foo'])
factory = ProxyClientFactory(b'GET', b'/foo', b'HTTP/1.0',
{b"accept": b"text/html"}, '', request)
factory.clientConnectionFailed(None, None)
self.assertEqual(request.responseCode, 501)
self.assertEqual(request.responseMessage, b"Gateway error")
self.assertEqual(
list(request.responseHeaders.getAllRawHeaders()),
[(b"Content-Type", [b"text/html"])])
self.assertEqual(
b''.join(request.written),
b"<H1>Could not connect</H1>")
self.assertEqual(request.finished, 1)
def test_buildProtocol(self):
"""
L{ProxyClientFactory.buildProtocol} should produce a L{ProxyClient}
with the same values of attributes (with updates on the headers).
"""
factory = ProxyClientFactory(b'GET', b'/foo', b'HTTP/1.0',
{b"accept": b"text/html"}, b'Some data',
None)
proto = factory.buildProtocol(None)
self.assertIsInstance(proto, ProxyClient)
self.assertEqual(proto.command, b'GET')
self.assertEqual(proto.rest, b'/foo')
self.assertEqual(proto.data, b'Some data')
self.assertEqual(proto.headers,
{b"accept": b"text/html", b"connection": b"close"})
class ProxyRequestTests(TestCase):
"""
Tests for L{ProxyRequest}.
"""
def _testProcess(self, uri, expectedURI, method=b"GET", data=b""):
"""
Build a request pointing at C{uri}, and check that a proxied request
is created, pointing a C{expectedURI}.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(len(data))
request.handleContentChunk(data)
request.requestReceived(method, b'http://example.com' + uri,
b'HTTP/1.0')
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], u"example.com")
self.assertEqual(reactor.tcpClients[0][1], 80)
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.command, method)
self.assertEqual(factory.version, b'HTTP/1.0')
self.assertEqual(factory.headers, {b'host': b'example.com'})
self.assertEqual(factory.data, data)
self.assertEqual(factory.rest, expectedURI)
self.assertEqual(factory.father, request)
def test_process(self):
"""
L{ProxyRequest.process} should create a connection to the given server,
with a L{ProxyClientFactory} as connection factory, with the correct
parameters:
- forward comment, version and data values
- update headers with the B{host} value
- remove the host from the URL
- pass the request as parent request
"""
return self._testProcess(b"/foo/bar", b"/foo/bar")
def test_processWithoutTrailingSlash(self):
"""
If the incoming request doesn't contain a slash,
L{ProxyRequest.process} should add one when instantiating
L{ProxyClientFactory}.
"""
return self._testProcess(b"", b"/")
def test_processWithData(self):
"""
L{ProxyRequest.process} should be able to retrieve request body and
to forward it.
"""
return self._testProcess(
b"/foo/bar", b"/foo/bar", b"POST", b"Some content")
def test_processWithPort(self):
"""
Check that L{ProxyRequest.process} correctly parse port in the incoming
URL, and create an outgoing connection with this port.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ProxyRequest(channel, False, reactor)
request.gotLength(0)
request.requestReceived(b'GET', b'http://example.com:1234/foo/bar',
b'HTTP/1.0')
# That should create one connection, with the port parsed from the URL
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], u"example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
class DummyFactory(object):
"""
A simple holder for C{host} and C{port} information.
"""
def __init__(self, host, port):
self.host = host
self.port = port
class ReverseProxyRequestTests(TestCase):
"""
Tests for L{ReverseProxyRequest}.
"""
def test_process(self):
"""
L{ReverseProxyRequest.process} should create a connection to its
factory host/port, using a L{ProxyClientFactory} instantiated with the
correct parameters, and particularly set the B{host} header to the
factory host.
"""
transport = StringTransportWithDisconnection()
channel = DummyChannel(transport)
reactor = MemoryReactor()
request = ReverseProxyRequest(channel, False, reactor)
request.factory = DummyFactory(u"example.com", 1234)
request.gotLength(0)
request.requestReceived(b'GET', b'/foo/bar', b'HTTP/1.0')
# Check that one connection has been created, to the good host/port
self.assertEqual(len(reactor.tcpClients), 1)
self.assertEqual(reactor.tcpClients[0][0], u"example.com")
self.assertEqual(reactor.tcpClients[0][1], 1234)
# Check the factory passed to the connect, and its headers
factory = reactor.tcpClients[0][2]
self.assertIsInstance(factory, ProxyClientFactory)
self.assertEqual(factory.headers, {b'host': b'example.com'})

View file

@ -0,0 +1,289 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.resource}.
"""
from twisted.trial.unittest import TestCase
from twisted.python.compat import _PY3
from twisted.web.error import UnsupportedMethod
from twisted.web.resource import (
NOT_FOUND, FORBIDDEN, Resource, ErrorPage, NoResource, ForbiddenResource,
getChildForRequest)
from twisted.web.http_headers import Headers
from twisted.web.test.requesthelper import DummyRequest
class ErrorPageTests(TestCase):
"""
Tests for L{ErrorPage}, L{NoResource}, and L{ForbiddenResource}.
"""
errorPage = ErrorPage
noResource = NoResource
forbiddenResource = ForbiddenResource
def test_getChild(self):
"""
The C{getChild} method of L{ErrorPage} returns the L{ErrorPage} it is
called on.
"""
page = self.errorPage(321, "foo", "bar")
self.assertIdentical(page.getChild(b"name", object()), page)
def _pageRenderingTest(self, page, code, brief, detail):
request = DummyRequest([b''])
template = (
u"\n"
u"<html>\n"
u" <head><title>%s - %s</title></head>\n"
u" <body>\n"
u" <h1>%s</h1>\n"
u" <p>%s</p>\n"
u" </body>\n"
u"</html>\n")
expected = template % (code, brief, brief, detail)
self.assertEqual(
page.render(request), expected.encode('utf-8'))
self.assertEqual(request.responseCode, code)
self.assertEqual(
request.responseHeaders,
Headers({b'content-type': [b'text/html; charset=utf-8']}))
def test_errorPageRendering(self):
"""
L{ErrorPage.render} returns a C{bytes} describing the error defined by
the response code and message passed to L{ErrorPage.__init__}. It also
uses that response code to set the response code on the L{Request}
passed in.
"""
code = 321
brief = "brief description text"
detail = "much longer text might go here"
page = self.errorPage(code, brief, detail)
self._pageRenderingTest(page, code, brief, detail)
def test_noResourceRendering(self):
"""
L{NoResource} sets the HTTP I{NOT FOUND} code.
"""
detail = "long message"
page = self.noResource(detail)
self._pageRenderingTest(page, NOT_FOUND, "No Such Resource", detail)
def test_forbiddenResourceRendering(self):
"""
L{ForbiddenResource} sets the HTTP I{FORBIDDEN} code.
"""
detail = "longer message"
page = self.forbiddenResource(detail)
self._pageRenderingTest(page, FORBIDDEN, "Forbidden Resource", detail)
class DynamicChild(Resource):
"""
A L{Resource} to be created on the fly by L{DynamicChildren}.
"""
def __init__(self, path, request):
Resource.__init__(self)
self.path = path
self.request = request
class DynamicChildren(Resource):
"""
A L{Resource} with dynamic children.
"""
def getChild(self, path, request):
return DynamicChild(path, request)
class BytesReturnedRenderable(Resource):
"""
A L{Resource} with minimal capabilities to render a response.
"""
def __init__(self, response):
"""
@param response: A C{bytes} object giving the value to return from
C{render_GET}.
"""
Resource.__init__(self)
self._response = response
def render_GET(self, request):
"""
Render a response to a I{GET} request by returning a short byte string
to be written by the server.
"""
return self._response
class ImplicitAllowedMethods(Resource):
"""
A L{Resource} which implicitly defines its allowed methods by defining
renderers to handle them.
"""
def render_GET(self, request):
pass
def render_PUT(self, request):
pass
class ResourceTests(TestCase):
"""
Tests for L{Resource}.
"""
def test_staticChildren(self):
"""
L{Resource.putChild} adds a I{static} child to the resource. That child
is returned from any call to L{Resource.getChildWithDefault} for the
child's path.
"""
resource = Resource()
child = Resource()
sibling = Resource()
resource.putChild(b"foo", child)
resource.putChild(b"bar", sibling)
self.assertIdentical(
child, resource.getChildWithDefault(b"foo", DummyRequest([])))
def test_dynamicChildren(self):
"""
L{Resource.getChildWithDefault} delegates to L{Resource.getChild} when
the requested path is not associated with any static child.
"""
path = b"foo"
request = DummyRequest([])
resource = DynamicChildren()
child = resource.getChildWithDefault(path, request)
self.assertIsInstance(child, DynamicChild)
self.assertEqual(child.path, path)
self.assertIdentical(child.request, request)
def test_staticChildPathType(self):
"""
Test that passing the wrong type to putChild results in a warning,
and a failure in Python 3
"""
resource = Resource()
child = Resource()
sibling = Resource()
resource.putChild(u"foo", child)
warnings = self.flushWarnings([self.test_staticChildPathType])
self.assertEqual(len(warnings), 1)
self.assertIn("Path segment must be bytes",
warnings[0]['message'])
if _PY3:
# We expect an error here because u"foo" != b"foo" on Py3k
self.assertIsInstance(
resource.getChildWithDefault(b"foo", DummyRequest([])),
ErrorPage)
resource.putChild(None, sibling)
warnings = self.flushWarnings([self.test_staticChildPathType])
self.assertEqual(len(warnings), 1)
self.assertIn("Path segment must be bytes",
warnings[0]['message'])
def test_defaultHEAD(self):
"""
When not otherwise overridden, L{Resource.render} treats a I{HEAD}
request as if it were a I{GET} request.
"""
expected = b"insert response here"
request = DummyRequest([])
request.method = b'HEAD'
resource = BytesReturnedRenderable(expected)
self.assertEqual(expected, resource.render(request))
def test_explicitAllowedMethods(self):
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to the value of the
C{allowedMethods} attribute of the L{Resource}, if it has one.
"""
expected = [b'GET', b'HEAD', b'PUT']
resource = Resource()
resource.allowedMethods = expected
request = DummyRequest([])
request.method = b'FICTIONAL'
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(set(expected), set(exc.allowedMethods))
def test_implicitAllowedMethods(self):
"""
The L{UnsupportedMethod} raised by L{Resource.render} for an unsupported
request method has a C{allowedMethods} attribute set to a list of the
methods supported by the L{Resource}, as determined by the
I{render_}-prefixed methods which it defines, if C{allowedMethods} is
not explicitly defined by the L{Resource}.
"""
expected = set([b'GET', b'HEAD', b'PUT'])
resource = ImplicitAllowedMethods()
request = DummyRequest([])
request.method = b'FICTIONAL'
exc = self.assertRaises(UnsupportedMethod, resource.render, request)
self.assertEqual(expected, set(exc.allowedMethods))
class GetChildForRequestTests(TestCase):
"""
Tests for L{getChildForRequest}.
"""
def test_exhaustedPostPath(self):
"""
L{getChildForRequest} returns whatever resource has been reached by the
time the request's C{postpath} is empty.
"""
request = DummyRequest([])
resource = Resource()
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_leafResource(self):
"""
L{getChildForRequest} returns the first resource it encounters with a
C{isLeaf} attribute set to C{True}.
"""
request = DummyRequest([b"foo", b"bar"])
resource = Resource()
resource.isLeaf = True
result = getChildForRequest(resource, request)
self.assertIdentical(resource, result)
def test_postPathToPrePath(self):
"""
As path segments from the request are traversed, they are taken from
C{postpath} and put into C{prepath}.
"""
request = DummyRequest([b"foo", b"bar"])
root = Resource()
child = Resource()
child.isLeaf = True
root.putChild(b"foo", child)
self.assertIdentical(child, getChildForRequest(root, request))
self.assertEqual(request.prepath, [b"foo"])
self.assertEqual(request.postpath, [b"bar"])

View file

@ -0,0 +1,115 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.script}.
"""
import os
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
from twisted.web.http import NOT_FOUND
from twisted.web.script import ResourceScriptDirectory, PythonScript
from twisted.web.test._util import _render
from twisted.web.test.requesthelper import DummyRequest
class ResourceScriptDirectoryTests(TestCase):
"""
Tests for L{ResourceScriptDirectory}.
"""
def test_renderNotFound(self):
"""
L{ResourceScriptDirectory.render} sets the HTTP response code to I{NOT
FOUND}.
"""
resource = ResourceScriptDirectory(self.mktemp())
request = DummyRequest([b''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_notFoundChild(self):
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders an
response with the HTTP I{NOT FOUND} status code if the indicated child
does not exist as an entry in the directory used to initialized the
L{ResourceScriptDirectory}.
"""
path = self.mktemp()
os.makedirs(path)
resource = ResourceScriptDirectory(path)
request = DummyRequest([b'foo'])
child = resource.getChild("foo", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_render(self):
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders a
response with the HTTP 200 status code and the content of the rpy's
C{request} global.
"""
tmp = FilePath(self.mktemp())
tmp.makedirs()
tmp.child("test.rpy").setContent(b"""
from twisted.web.resource import Resource
class TestResource(Resource):
isLeaf = True
def render_GET(self, request):
return b'ok'
resource = TestResource()""")
resource = ResourceScriptDirectory(tmp._asBytesPath())
request = DummyRequest([b''])
child = resource.getChild(b"test.rpy", request)
d = _render(child, request)
def cbRendered(ignored):
self.assertEqual(b"".join(request.written), b"ok")
d.addCallback(cbRendered)
return d
class PythonScriptTests(TestCase):
"""
Tests for L{PythonScript}.
"""
def test_notFoundRender(self):
"""
If the source file a L{PythonScript} is initialized with doesn't exist,
L{PythonScript.render} sets the HTTP response code to I{NOT FOUND}.
"""
resource = PythonScript(self.mktemp(), None)
request = DummyRequest([b''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_renderException(self):
"""
L{ResourceScriptDirectory.getChild} returns a resource which renders a
response with the HTTP 200 status code and the content of the rpy's
C{request} global.
"""
tmp = FilePath(self.mktemp())
tmp.makedirs()
child = tmp.child("test.epy")
child.setContent(b'raise Exception("nooo")')
resource = PythonScript(child._asBytesPath(), None)
request = DummyRequest([b''])
d = _render(resource, request)
def cbRendered(ignored):
self.assertIn(b"nooo", b"".join(request.written))
d.addCallback(cbRendered)
return d

View file

@ -0,0 +1,166 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web._stan} portion of the L{twisted.web.template}
implementation.
"""
from __future__ import absolute_import, division
from twisted.web.template import Comment, CDATA, CharRef, Tag
from twisted.trial.unittest import TestCase
from twisted.python.compat import _PY3
def proto(*a, **kw):
"""
Produce a new tag for testing.
"""
return Tag('hello')(*a, **kw)
class TagTests(TestCase):
"""
Tests for L{Tag}.
"""
def test_fillSlots(self):
"""
L{Tag.fillSlots} returns self.
"""
tag = proto()
self.assertIdentical(tag, tag.fillSlots(test='test'))
def test_cloneShallow(self):
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. If the shallow flag is C{False}, that's where it
stops.
"""
innerList = ["inner list"]
tag = proto("How are you", innerList,
hello="world", render="aSampleMethod")
tag.fillSlots(foo='bar')
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone(deep=False)
self.assertEqual(clone.attributes['hello'], 'world')
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertEqual(clone.children, ["How are you", innerList])
self.assertNotIdentical(clone.children, tag.children)
self.assertIdentical(clone.children[1], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_cloneDeep(self):
"""
L{Tag.clone} copies all attributes and children of a tag, including its
render attribute. In its normal operating mode (where the deep flag is
C{True}, as is the default), it will clone all sub-lists and sub-tags.
"""
innerTag = proto("inner")
innerList = ["inner list"]
tag = proto("How are you", innerTag, innerList,
hello="world", render="aSampleMethod")
tag.fillSlots(foo='bar')
tag.filename = "foo/bar"
tag.lineNumber = 6
tag.columnNumber = 12
clone = tag.clone()
self.assertEqual(clone.attributes['hello'], 'world')
self.assertNotIdentical(clone.attributes, tag.attributes)
self.assertNotIdentical(clone.children, tag.children)
# sanity check
self.assertIdentical(tag.children[1], innerTag)
# clone should have sub-clone
self.assertNotIdentical(clone.children[1], innerTag)
# sanity check
self.assertIdentical(tag.children[2], innerList)
# clone should have sub-clone
self.assertNotIdentical(clone.children[2], innerList)
self.assertEqual(tag.slotData, clone.slotData)
self.assertNotIdentical(tag.slotData, clone.slotData)
self.assertEqual(clone.filename, "foo/bar")
self.assertEqual(clone.lineNumber, 6)
self.assertEqual(clone.columnNumber, 12)
self.assertEqual(clone.render, "aSampleMethod")
def test_clear(self):
"""
L{Tag.clear} removes all children from a tag, but leaves its attributes
in place.
"""
tag = proto("these are", "children", "cool", andSoIs='this-attribute')
tag.clear()
self.assertEqual(tag.children, [])
self.assertEqual(tag.attributes, {'andSoIs': 'this-attribute'})
def test_suffix(self):
"""
L{Tag.__call__} accepts Python keywords with a suffixed underscore as
the DOM attribute of that literal suffix.
"""
proto = Tag('div')
tag = proto()
tag(class_='a')
self.assertEqual(tag.attributes, {'class': 'a'})
def test_commentReprPy2(self):
"""
L{Comment.__repr__} returns a value which makes it easy to see what's
in the comment.
"""
self.assertEqual(repr(Comment(u"hello there")),
"Comment(u'hello there')")
def test_cdataReprPy2(self):
"""
L{CDATA.__repr__} returns a value which makes it easy to see what's in
the comment.
"""
self.assertEqual(repr(CDATA(u"test data")),
"CDATA(u'test data')")
def test_commentReprPy3(self):
"""
L{Comment.__repr__} returns a value which makes it easy to see what's
in the comment.
"""
self.assertEqual(repr(Comment(u"hello there")),
"Comment('hello there')")
def test_cdataReprPy3(self):
"""
L{CDATA.__repr__} returns a value which makes it easy to see what's in
the comment.
"""
self.assertEqual(repr(CDATA(u"test data")),
"CDATA('test data')")
if not _PY3:
test_commentReprPy3.skip = "Only relevant on Python 3."
test_cdataReprPy3.skip = "Only relevant on Python 3."
else:
test_commentReprPy2.skip = "Only relevant on Python 2."
test_cdataReprPy2.skip = "Only relevant on Python 2."
def test_charrefRepr(self):
"""
L{CharRef.__repr__} returns a value which makes it easy to see what
character is referred to.
"""
snowman = ord(u"\N{SNOWMAN}")
self.assertEqual(repr(CharRef(snowman)), "CharRef(9731)")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,346 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.tap}.
"""
from __future__ import absolute_import, division
import os
import stat
from twisted.internet import reactor, endpoints
from twisted.internet.interfaces import IReactorUNIX
from twisted.python.filepath import FilePath
from twisted.python.reflect import requireModule
from twisted.python.threadpool import ThreadPool
from twisted.python.usage import UsageError
from twisted.spread.pb import PBServerFactory
from twisted.trial.unittest import TestCase
from twisted.web import demo
from twisted.web.distrib import ResourcePublisher, UserDirectory
from twisted.web.script import PythonScript
from twisted.web.server import Site
from twisted.web.static import Data, File
from twisted.web.tap import Options, makeService
from twisted.web.tap import makePersonalServerFactory, _AddHeadersResource
from twisted.web.test.requesthelper import DummyRequest
from twisted.web.twcgi import CGIScript
from twisted.web.wsgi import WSGIResource
application = object()
class ServiceTests(TestCase):
"""
Tests for the service creation APIs in L{twisted.web.tap}.
"""
def _pathOption(self):
"""
Helper for the I{--path} tests which creates a directory and creates
an L{Options} object which uses that directory as its static
filesystem root.
@return: A two-tuple of a L{FilePath} referring to the directory and
the value associated with the C{'root'} key in the L{Options}
instance after parsing a I{--path} option.
"""
path = FilePath(self.mktemp())
path.makedirs()
options = Options()
options.parseOptions(['--path', path.path])
root = options['root']
return path, root
def test_path(self):
"""
The I{--path} option causes L{Options} to create a root resource
which serves responses from the specified path.
"""
path, root = self._pathOption()
self.assertIsInstance(root, File)
self.assertEqual(root.path, path.path)
def test_pathServer(self):
"""
The I{--path} option to L{makeService} causes it to return a service
which will listen on the server address given by the I{--port} option.
"""
path = FilePath(self.mktemp())
path.makedirs()
port = self.mktemp()
options = Options()
options.parseOptions(['--port', 'unix:' + port, '--path', path.path])
service = makeService(options)
service.startService()
self.addCleanup(service.stopService)
self.assertIsInstance(service.services[0].factory.resource, File)
self.assertEqual(service.services[0].factory.resource.path, path.path)
self.assertTrue(os.path.exists(port))
self.assertTrue(stat.S_ISSOCK(os.stat(port).st_mode))
if not IReactorUNIX.providedBy(reactor):
test_pathServer.skip = (
"The reactor does not support UNIX domain sockets")
def test_cgiProcessor(self):
"""
The I{--path} option creates a root resource which serves a
L{CGIScript} instance for any child with the C{".cgi"} extension.
"""
path, root = self._pathOption()
path.child("foo.cgi").setContent(b"")
self.assertIsInstance(root.getChild("foo.cgi", None), CGIScript)
def test_epyProcessor(self):
"""
The I{--path} option creates a root resource which serves a
L{PythonScript} instance for any child with the C{".epy"} extension.
"""
path, root = self._pathOption()
path.child("foo.epy").setContent(b"")
self.assertIsInstance(root.getChild("foo.epy", None), PythonScript)
def test_rpyProcessor(self):
"""
The I{--path} option creates a root resource which serves the
C{resource} global defined by the Python source in any child with
the C{".rpy"} extension.
"""
path, root = self._pathOption()
path.child("foo.rpy").setContent(
b"from twisted.web.static import Data\n"
b"resource = Data('content', 'major/minor')\n")
child = root.getChild("foo.rpy", None)
self.assertIsInstance(child, Data)
self.assertEqual(child.data, 'content')
self.assertEqual(child.type, 'major/minor')
def test_makePersonalServerFactory(self):
"""
L{makePersonalServerFactory} returns a PB server factory which has
as its root object a L{ResourcePublisher}.
"""
# The fact that this pile of objects can actually be used somehow is
# verified by twisted.web.test.test_distrib.
site = Site(Data(b"foo bar", "text/plain"))
serverFactory = makePersonalServerFactory(site)
self.assertIsInstance(serverFactory, PBServerFactory)
self.assertIsInstance(serverFactory.root, ResourcePublisher)
self.assertIdentical(serverFactory.root.site, site)
def test_personalServer(self):
"""
The I{--personal} option to L{makeService} causes it to return a
service which will listen on the server address given by the I{--port}
option.
"""
port = self.mktemp()
options = Options()
options.parseOptions(['--port', 'unix:' + port, '--personal'])
service = makeService(options)
service.startService()
self.addCleanup(service.stopService)
self.assertTrue(os.path.exists(port))
self.assertTrue(stat.S_ISSOCK(os.stat(port).st_mode))
if not IReactorUNIX.providedBy(reactor):
test_personalServer.skip = (
"The reactor does not support UNIX domain sockets")
def test_defaultPersonalPath(self):
"""
If the I{--port} option not specified but the I{--personal} option is,
L{Options} defaults the port to C{UserDirectory.userSocketName} in the
user's home directory.
"""
options = Options()
options.parseOptions(['--personal'])
path = os.path.expanduser(
os.path.join('~', UserDirectory.userSocketName))
self.assertEqual(options['ports'][0],
'unix:{}'.format(path))
if not IReactorUNIX.providedBy(reactor):
test_defaultPersonalPath.skip = (
"The reactor does not support UNIX domain sockets")
def test_defaultPort(self):
"""
If the I{--port} option is not specified, L{Options} defaults the port
to C{8080}.
"""
options = Options()
options.parseOptions([])
self.assertEqual(
endpoints._parseServer(options['ports'][0], None)[:2],
('TCP', (8080, None)))
def test_twoPorts(self):
"""
If the I{--http} option is given twice, there are two listeners
"""
options = Options()
options.parseOptions(['--listen', 'tcp:8001', '--listen', 'tcp:8002'])
self.assertIn('8001', options['ports'][0])
self.assertIn('8002', options['ports'][1])
def test_wsgi(self):
"""
The I{--wsgi} option takes the fully-qualifed Python name of a WSGI
application object and creates a L{WSGIResource} at the root which
serves that application.
"""
options = Options()
options.parseOptions(['--wsgi', __name__ + '.application'])
root = options['root']
self.assertTrue(root, WSGIResource)
self.assertIdentical(root._reactor, reactor)
self.assertTrue(isinstance(root._threadpool, ThreadPool))
self.assertIdentical(root._application, application)
# The threadpool should start and stop with the reactor.
self.assertFalse(root._threadpool.started)
reactor.fireSystemEvent('startup')
self.assertTrue(root._threadpool.started)
self.assertFalse(root._threadpool.joined)
reactor.fireSystemEvent('shutdown')
self.assertTrue(root._threadpool.joined)
def test_invalidApplication(self):
"""
If I{--wsgi} is given an invalid name, L{Options.parseOptions}
raises L{UsageError}.
"""
options = Options()
for name in [__name__ + '.nosuchthing', 'foo.']:
exc = self.assertRaises(
UsageError, options.parseOptions, ['--wsgi', name])
self.assertEqual(str(exc),
"No such WSGI application: %r" % (name,))
def test_HTTPSFailureOnMissingSSL(self):
"""
An L{UsageError} is raised when C{https} is requested but there is no
support for SSL.
"""
options = Options()
exception = self.assertRaises(
UsageError, options.parseOptions, ['--https=443'])
self.assertEqual('SSL support not installed', exception.args[0])
if requireModule('OpenSSL.SSL') is not None:
test_HTTPSFailureOnMissingSSL.skip = 'SSL module is available.'
def test_HTTPSAcceptedOnAvailableSSL(self):
"""
When SSL support is present, it accepts the --https option.
"""
options = Options()
options.parseOptions(['--https=443'])
self.assertIn('ssl', options['ports'][0])
self.assertIn('443', options['ports'][0])
if requireModule('OpenSSL.SSL') is None:
test_HTTPSAcceptedOnAvailableSSL.skip = 'SSL module is not available.'
def test_add_header_parsing(self):
"""
When --add-header is specific, the value is parsed.
"""
options = Options()
options.parseOptions(
['--add-header', 'K1: V1', '--add-header', 'K2: V2']
)
self.assertEqual(options['extraHeaders'], [('K1', 'V1'), ('K2', 'V2')])
def test_add_header_resource(self):
"""
When --add-header is specified, the resource is a composition that adds
headers.
"""
options = Options()
options.parseOptions(
['--add-header', 'K1: V1', '--add-header', 'K2: V2']
)
service = makeService(options)
resource = service.services[0].factory.resource
self.assertIsInstance(resource, _AddHeadersResource)
self.assertEqual(resource._headers, [('K1', 'V1'), ('K2', 'V2')])
self.assertIsInstance(resource._originalResource, demo.Test)
def test_noTracebacksDeprecation(self):
"""
Passing --notracebacks is deprecated.
"""
options = Options()
options.parseOptions(["--notracebacks"])
makeService(options)
warnings = self.flushWarnings([self.test_noTracebacksDeprecation])
self.assertEqual(warnings[0]['category'], DeprecationWarning)
self.assertEqual(
warnings[0]['message'],
"--notracebacks was deprecated in Twisted 19.7.0"
)
self.assertEqual(len(warnings), 1)
def test_displayTracebacks(self):
"""
Passing --display-tracebacks will enable traceback rendering on the
generated Site.
"""
options = Options()
options.parseOptions(["--display-tracebacks"])
service = makeService(options)
self.assertTrue(service.services[0].factory.displayTracebacks)
def test_displayTracebacksNotGiven(self):
"""
Not passing --display-tracebacks will leave traceback rendering on the
generated Site off.
"""
options = Options()
options.parseOptions([])
service = makeService(options)
self.assertFalse(service.services[0].factory.displayTracebacks)
class AddHeadersResourceTests(TestCase):
def test_getChildWithDefault(self):
"""
When getChildWithDefault is invoked, it adds the headers to the
response.
"""
resource = _AddHeadersResource(
demo.Test(), [("K1", "V1"), ("K2", "V2"), ("K1", "V3")])
request = DummyRequest([])
resource.getChildWithDefault("", request)
self.assertEqual(
request.responseHeaders.getRawHeaders("K1"), ["V1", "V3"])
self.assertEqual(request.responseHeaders.getRawHeaders("K2"), ["V2"])

View file

@ -0,0 +1,827 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.template}
"""
from __future__ import division, absolute_import
from zope.interface.verify import verifyObject
from twisted.internet.defer import succeed, gatherResults
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.trial.util import suppress as SUPPRESS
from twisted.web.template import (
Element, TagLoader, renderer, tags, XMLFile, XMLString)
from twisted.web.iweb import ITemplateLoader
from twisted.web.error import (FlattenerError, MissingTemplateLoader,
MissingRenderMethod)
from twisted.web.template import renderElement
from twisted.web._element import UnexposedMethodError
from twisted.web.test._util import FlattenTestCase
from twisted.web.test.test_web import DummyRequest
from twisted.web.server import NOT_DONE_YET
from twisted.python.compat import NativeStringIO as StringIO
from twisted.logger import globalLogPublisher
from twisted.test.proto_helpers import EventLoggingObserver
_xmlFileSuppress = SUPPRESS(category=DeprecationWarning,
message="Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.")
class TagFactoryTests(TestCase):
"""
Tests for L{_TagFactory} through the publicly-exposed L{tags} object.
"""
def test_lookupTag(self):
"""
HTML tags can be retrieved through C{tags}.
"""
tag = tags.a
self.assertEqual(tag.tagName, "a")
def test_lookupHTML5Tag(self):
"""
Twisted supports the latest and greatest HTML tags from the HTML5
specification.
"""
tag = tags.video
self.assertEqual(tag.tagName, "video")
def test_lookupTransparentTag(self):
"""
To support transparent inclusion in templates, there is a special tag,
the transparent tag, which has no name of its own but is accessed
through the "transparent" attribute.
"""
tag = tags.transparent
self.assertEqual(tag.tagName, "")
def test_lookupInvalidTag(self):
"""
Invalid tags which are not part of HTML cause AttributeErrors when
accessed through C{tags}.
"""
self.assertRaises(AttributeError, getattr, tags, "invalid")
def test_lookupXMP(self):
"""
As a special case, the <xmp> tag is simply not available through
C{tags} or any other part of the templating machinery.
"""
self.assertRaises(AttributeError, getattr, tags, "xmp")
class ElementTests(TestCase):
"""
Tests for the awesome new L{Element} class.
"""
def test_missingTemplateLoader(self):
"""
L{Element.render} raises L{MissingTemplateLoader} if the C{loader}
attribute is L{None}.
"""
element = Element()
err = self.assertRaises(MissingTemplateLoader, element.render, None)
self.assertIdentical(err.element, element)
def test_missingTemplateLoaderRepr(self):
"""
A L{MissingTemplateLoader} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self):
return 'Pretty Repr Element'
self.assertIn('Pretty Repr Element',
repr(MissingTemplateLoader(PrettyReprElement())))
def test_missingRendererMethod(self):
"""
When called with the name which is not associated with a render method,
L{Element.lookupRenderMethod} raises L{MissingRenderMethod}.
"""
element = Element()
err = self.assertRaises(
MissingRenderMethod, element.lookupRenderMethod, "foo")
self.assertIdentical(err.element, element)
self.assertEqual(err.renderName, "foo")
def test_missingRenderMethodRepr(self):
"""
A L{MissingRenderMethod} instance can be repr()'d without error.
"""
class PrettyReprElement(Element):
def __repr__(self):
return 'Pretty Repr Element'
s = repr(MissingRenderMethod(PrettyReprElement(),
'expectedMethod'))
self.assertIn('Pretty Repr Element', s)
self.assertIn('expectedMethod', s)
def test_definedRenderer(self):
"""
When called with the name of a defined render method,
L{Element.lookupRenderMethod} returns that render method.
"""
class ElementWithRenderMethod(Element):
@renderer
def foo(self, request, tag):
return "bar"
foo = ElementWithRenderMethod().lookupRenderMethod("foo")
self.assertEqual(foo(None, None), "bar")
def test_render(self):
"""
L{Element.render} loads a document from the C{loader} attribute and
returns it.
"""
class TemplateLoader(object):
def load(self):
return "result"
class StubElement(Element):
loader = TemplateLoader()
element = StubElement()
self.assertEqual(element.render(None), "result")
def test_misuseRenderer(self):
"""
If the L{renderer} decorator is called without any arguments, it will
raise a comprehensible exception.
"""
te = self.assertRaises(TypeError, renderer)
self.assertEqual(str(te),
"expose() takes at least 1 argument (0 given)")
def test_renderGetDirectlyError(self):
"""
Called directly, without a default, L{renderer.get} raises
L{UnexposedMethodError} when it cannot find a renderer.
"""
self.assertRaises(UnexposedMethodError, renderer.get, None,
"notARenderer")
class XMLFileReprTests(TestCase):
"""
Tests for L{twisted.web.template.XMLFile}'s C{__repr__}.
"""
def test_filePath(self):
"""
An L{XMLFile} with a L{FilePath} returns a useful repr().
"""
path = FilePath("/tmp/fake.xml")
self.assertEqual('<XMLFile of %r>' % (path,), repr(XMLFile(path)))
def test_filename(self):
"""
An L{XMLFile} with a filename returns a useful repr().
"""
fname = "/tmp/fake.xml"
self.assertEqual('<XMLFile of %r>' % (fname,), repr(XMLFile(fname)))
test_filename.suppress = [_xmlFileSuppress]
def test_file(self):
"""
An L{XMLFile} with a file object returns a useful repr().
"""
fobj = StringIO("not xml")
self.assertEqual('<XMLFile of %r>' % (fobj,), repr(XMLFile(fobj)))
test_file.suppress = [_xmlFileSuppress]
class XMLLoaderTestsMixin(object):
"""
@ivar templateString: Simple template to use to exercise the loaders.
@ivar deprecatedUse: C{True} if this use of L{XMLFile} is deprecated and
should emit a C{DeprecationWarning}.
"""
loaderFactory = None
templateString = '<p>Hello, world.</p>'
def test_load(self):
"""
Verify that the loader returns a tag with the correct children.
"""
loader = self.loaderFactory()
tag, = loader.load()
warnings = self.flushWarnings(offendingFunctions=[self.loaderFactory])
if self.deprecatedUse:
self.assertEqual(len(warnings), 1)
self.assertEqual(warnings[0]['category'], DeprecationWarning)
self.assertEqual(
warnings[0]['message'],
"Passing filenames or file objects to XMLFile is "
"deprecated since Twisted 12.1. Pass a FilePath instead.")
else:
self.assertEqual(len(warnings), 0)
self.assertEqual(tag.tagName, 'p')
self.assertEqual(tag.children, [u'Hello, world.'])
def test_loadTwice(self):
"""
If {load()} can be called on a loader twice the result should be the
same.
"""
loader = self.loaderFactory()
tags1 = loader.load()
tags2 = loader.load()
self.assertEqual(tags1, tags2)
test_loadTwice.suppress = [_xmlFileSuppress]
class XMLStringLoaderTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLString}
"""
deprecatedUse = False
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with C{self.templateString}.
"""
return XMLString(self.templateString)
class XMLFileWithFilePathTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s L{FilePath} support.
"""
deprecatedUse = False
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a L{FilePath} pointing to a
file that contains C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString.encode("utf8"))
return XMLFile(fp)
class XMLFileWithFileTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated file object support.
"""
deprecatedUse = True
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a file object that contains
C{self.templateString}.
"""
return XMLFile(StringIO(self.templateString))
class XMLFileWithFilenameTests(TestCase, XMLLoaderTestsMixin):
"""
Tests for L{twisted.web.template.XMLFile}'s deprecated filename support.
"""
deprecatedUse = True
def loaderFactory(self):
"""
@return: an L{XMLString} constructed with a filename that points to a
file containing C{self.templateString}.
"""
fp = FilePath(self.mktemp())
fp.setContent(self.templateString.encode('utf8'))
return XMLFile(fp.path)
class FlattenIntegrationTests(FlattenTestCase):
"""
Tests for integration between L{Element} and
L{twisted.web._flatten.flatten}.
"""
def test_roundTrip(self):
"""
Given a series of parsable XML strings, verify that
L{twisted.web._flatten.flatten} will flatten the L{Element} back to the
input when sent on a round trip.
"""
fragments = [
b"<p>Hello, world.</p>",
b"<p><!-- hello, world --></p>",
b"<p><![CDATA[Hello, world.]]></p>",
b'<test1 xmlns:test2="urn:test2">'
b'<test2:test3></test2:test3></test1>',
b'<test1 xmlns="urn:test2"><test3></test3></test1>',
b'<p>\xe2\x98\x83</p>',
]
deferreds = [
self.assertFlattensTo(Element(loader=XMLString(xml)), xml)
for xml in fragments]
return gatherResults(deferreds)
def test_entityConversion(self):
"""
When flattening an HTML entity, it should flatten out to the utf-8
representation if possible.
"""
element = Element(loader=XMLString('<p>&#9731;</p>'))
return self.assertFlattensTo(element, b'<p>\xe2\x98\x83</p>')
def test_missingTemplateLoader(self):
"""
Rendering an Element without a loader attribute raises the appropriate
exception.
"""
return self.assertFlatteningRaises(Element(), MissingTemplateLoader)
def test_missingRenderMethod(self):
"""
Flattening an L{Element} with a C{loader} which has a tag with a render
directive fails with L{FlattenerError} if there is no available render
method to satisfy that directive.
"""
element = Element(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="unknownMethod" />
"""))
return self.assertFlatteningRaises(element, MissingRenderMethod)
def test_transparentRendering(self):
"""
A C{transparent} element should be eliminated from the DOM and rendered as
only its children.
"""
element = Element(loader=XMLString(
'<t:transparent '
'xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'Hello, world.'
'</t:transparent>'
))
return self.assertFlattensTo(element, b"Hello, world.")
def test_attrRendering(self):
"""
An Element with an attr tag renders the vaule of its attr tag as an
attribute of its containing tag.
"""
element = Element(loader=XMLString(
'<a xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:attr name="href">http://example.com</t:attr>'
'Hello, world.'
'</a>'
))
return self.assertFlattensTo(element,
b'<a href="http://example.com">Hello, world.</a>')
def test_errorToplevelAttr(self):
"""
A template with a toplevel C{attr} tag will not load; it will raise
L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
name='something'
>hello</t:attr>
""")
def test_errorUnnamedAttr(self):
"""
A template with an C{attr} tag with no C{name} attribute will not load;
it will raise L{AssertionError} if you try.
"""
self.assertRaises(
AssertionError,
XMLString,
"""<html><t:attr
xmlns:t='http://twistedmatrix.com/ns/twisted.web.template/0.1'
>hello</t:attr></html>""")
def test_lenientPrefixBehavior(self):
"""
If the parser sees a prefix it doesn't recognize on an attribute, it
will pass it on through to serialization.
"""
theInput = (
'<hello:world hello:sample="testing" '
'xmlns:hello="http://made-up.example.com/ns/not-real">'
'This is a made-up tag.</hello:world>')
element = Element(loader=XMLString(theInput))
self.assertFlattensTo(element, theInput.encode('utf8'))
def test_deferredRendering(self):
"""
An Element with a render method which returns a Deferred will render
correctly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return succeed("Hello, world.")
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""))
return self.assertFlattensTo(element, b"Hello, world.")
def test_loaderClassAttribute(self):
"""
If there is a non-None loader attribute on the class of an Element
instance but none on the instance itself, the class attribute is used.
"""
class SubElement(Element):
loader = XMLString("<p>Hello, world.</p>")
return self.assertFlattensTo(SubElement(), b"<p>Hello, world.</p>")
def test_directiveRendering(self):
"""
An Element with a valid render directive has that directive invoked and
the result added to the output.
"""
renders = []
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
renders.append((self, request))
return tag("Hello, world.")
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""))
return self.assertFlattensTo(element, b"<p>Hello, world.</p>")
def test_directiveRenderingOmittingTag(self):
"""
An Element with a render method which omits the containing tag
successfully removes that tag from the output.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return "Hello, world."
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod">
Goodbye, world.
</p>
"""))
return self.assertFlattensTo(element, b"Hello, world.")
def test_elementContainingStaticElement(self):
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return tag(Element(
loader=XMLString("<em>Hello, world.</em>")))
element = RenderfulElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="renderMethod" />
"""))
return self.assertFlattensTo(element, b"<p><em>Hello, world.</em></p>")
def test_elementUsingSlots(self):
"""
An Element which is returned by the render method of another Element is
rendered properly.
"""
class RenderfulElement(Element):
@renderer
def renderMethod(self, request, tag):
return tag.fillSlots(test2='world.')
element = RenderfulElement(loader=XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"'
' t:render="renderMethod">'
'<t:slot name="test1" default="Hello, " />'
'<t:slot name="test2" />'
'</p>'
))
return self.assertFlattensTo(element, b"<p>Hello, world.</p>")
def test_elementContainingDynamicElement(self):
"""
Directives in the document factory of an Element returned from a render
method of another Element are satisfied from the correct object: the
"inner" Element.
"""
class OuterElement(Element):
@renderer
def outerMethod(self, request, tag):
return tag(InnerElement(loader=XMLString("""
<t:ignored
xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="innerMethod" />
""")))
class InnerElement(Element):
@renderer
def innerMethod(self, request, tag):
return "Hello, world."
element = OuterElement(loader=XMLString("""
<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1"
t:render="outerMethod" />
"""))
return self.assertFlattensTo(element, b"<p>Hello, world.</p>")
def test_sameLoaderTwice(self):
"""
Rendering the output of a loader, or even the same element, should
return different output each time.
"""
sharedLoader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'<t:transparent t:render="classCounter" /> '
'<t:transparent t:render="instanceCounter" />'
'</p>')
class DestructiveElement(Element):
count = 0
instanceCount = 0
loader = sharedLoader
@renderer
def classCounter(self, request, tag):
DestructiveElement.count += 1
return tag(str(DestructiveElement.count))
@renderer
def instanceCounter(self, request, tag):
self.instanceCount += 1
return tag(str(self.instanceCount))
e1 = DestructiveElement()
e2 = DestructiveElement()
self.assertFlattensImmediately(e1, b"<p>1 1</p>")
self.assertFlattensImmediately(e1, b"<p>2 2</p>")
self.assertFlattensImmediately(e2, b"<p>3 1</p>")
class TagLoaderTests(FlattenTestCase):
"""
Tests for L{TagLoader}.
"""
def setUp(self):
self.loader = TagLoader(tags.i('test'))
def test_interface(self):
"""
An instance of L{TagLoader} provides L{ITemplateLoader}.
"""
self.assertTrue(verifyObject(ITemplateLoader, self.loader))
def test_loadsList(self):
"""
L{TagLoader.load} returns a list, per L{ITemplateLoader}.
"""
self.assertIsInstance(self.loader.load(), list)
def test_flatten(self):
"""
L{TagLoader} can be used in an L{Element}, and flattens as the tag used
to construct the L{TagLoader} would flatten.
"""
e = Element(self.loader)
self.assertFlattensImmediately(e, b'<i>test</i>')
class TestElement(Element):
"""
An L{Element} that can be rendered successfully.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'Hello, world.'
'</p>')
class TestFailureElement(Element):
"""
An L{Element} that can be used in place of L{FailureElement} to verify
that L{renderElement} can render failures properly.
"""
loader = XMLString(
'<p xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">'
'I failed.'
'</p>')
def __init__(self, failure, loader=None):
self.failure = failure
class FailingElement(Element):
"""
An element that raises an exception when rendered.
"""
def render(self, request):
a = 42
b = 0
return a // b
class FakeSite(object):
"""
A minimal L{Site} object that we can use to test displayTracebacks
"""
displayTracebacks = False
class RenderElementTests(TestCase):
"""
Test L{renderElement}
"""
def setUp(self):
"""
Set up a common L{DummyRequest} and L{FakeSite}.
"""
self.request = DummyRequest([""])
self.request.site = FakeSite()
def test_simpleRender(self):
"""
L{renderElement} returns NOT_DONE_YET and eventually
writes the rendered L{Element} to the request before finishing the
request.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
b"".join(self.request.written),
b"<!DOCTYPE html>\n"
b"<p>Hello, world.</p>")
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailure(self):
"""
L{renderElement} handles failures by writing a minimal
error message to the request and finishing it.
"""
element = FailingElement()
d = self.request.notifyFinish()
def check(_):
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
b"".join(self.request.written),
(b'<!DOCTYPE html>\n'
b'<div style="font-size:800%;'
b'background-color:#FFF;'
b'color:#F00'
b'">An error occurred while rendering the response.</div>'))
self.assertTrue(self.request.finished)
d.addCallback(check)
self.assertIdentical(NOT_DONE_YET, renderElement(self.request, element))
return d
def test_simpleFailureWithTraceback(self):
"""
L{renderElement} will render a traceback when rendering of
the element fails and our site is configured to display tracebacks.
"""
logObserver = EventLoggingObserver.createWithCleanup(
self,
globalLogPublisher
)
self.request.site.displayTracebacks = True
element = FailingElement()
d = self.request.notifyFinish()
def check(_):
self.assertEquals(1, len(logObserver))
f = logObserver[0]["log_failure"]
self.assertIsInstance(f.value, FlattenerError)
flushed = self.flushLoggedErrors(FlattenerError)
self.assertEqual(len(flushed), 1)
self.assertEqual(
b"".join(self.request.written),
b"<!DOCTYPE html>\n<p>I failed.</p>")
self.assertTrue(self.request.finished)
d.addCallback(check)
renderElement(self.request, element, _failElement=TestFailureElement)
return d
def test_nonDefaultDoctype(self):
"""
L{renderElement} will write the doctype string specified by the
doctype keyword argument.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
b"".join(self.request.written),
(b'<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
b' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">\n'
b'<p>Hello, world.</p>'))
d.addCallback(check)
renderElement(
self.request,
element,
doctype=(
b'<!DOCTYPE html PUBLIC "-//W3C//DTD XHTML 1.0 Strict//EN"'
b' "http://www.w3.org/TR/xhtml1/DTD/xhtml1-strict.dtd">'))
return d
def test_noneDoctype(self):
"""
L{renderElement} will not write out a doctype if the doctype keyword
argument is L{None}.
"""
element = TestElement()
d = self.request.notifyFinish()
def check(_):
self.assertEqual(
b"".join(self.request.written),
b'<p>Hello, world.</p>')
d.addCallback(check)
renderElement(self.request, element, doctype=None)
return d

View file

@ -0,0 +1,366 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.util}.
"""
from __future__ import absolute_import, division
import gc
from twisted.python.failure import Failure
from twisted.trial.unittest import SynchronousTestCase, TestCase
from twisted.internet import defer
from twisted.python.compat import _PY3, intToBytes, networkString
from twisted.web import resource, util
from twisted.web.error import FlattenerError
from twisted.web.http import FOUND
from twisted.web.server import Request
from twisted.web.template import TagLoader, flattenString, tags
from twisted.web.test.requesthelper import DummyChannel, DummyRequest
from twisted.web.util import DeferredResource
from twisted.web.util import _SourceFragmentElement, _FrameElement
from twisted.web.util import _StackElement, FailureElement, formatFailure
from twisted.web.util import redirectTo, _SourceLineElement
class RedirectToTests(TestCase):
"""
Tests for L{redirectTo}.
"""
def test_headersAndCode(self):
"""
L{redirectTo} will set the C{Location} and C{Content-Type} headers on
its request, and set the response code to C{FOUND}, so the browser will
be redirected.
"""
request = Request(DummyChannel(), True)
request.method = b'GET'
targetURL = b"http://target.example.com/4321"
redirectTo(targetURL, request)
self.assertEqual(request.code, FOUND)
self.assertEqual(
request.responseHeaders.getRawHeaders(b'location'), [targetURL])
self.assertEqual(
request.responseHeaders.getRawHeaders(b'content-type'),
[b'text/html; charset=utf-8'])
def test_redirectToUnicodeURL(self) :
"""
L{redirectTo} will raise TypeError if unicode object is passed in URL
"""
request = Request(DummyChannel(), True)
request.method = b'GET'
targetURL = u'http://target.example.com/4321'
self.assertRaises(TypeError, redirectTo, targetURL, request)
class FailureElementTests(TestCase):
"""
Tests for L{FailureElement} and related helpers which can render a
L{Failure} as an HTML string.
"""
def setUp(self):
"""
Create a L{Failure} which can be used by the rendering tests.
"""
def lineNumberProbeAlsoBroken():
message = "This is a problem"
raise Exception(message)
# Figure out the line number from which the exception will be raised.
self.base = lineNumberProbeAlsoBroken.__code__.co_firstlineno + 1
try:
lineNumberProbeAlsoBroken()
except:
self.failure = Failure(captureVars=True)
self.frame = self.failure.frames[-1]
def test_sourceLineElement(self):
"""
L{_SourceLineElement} renders a source line and line number.
"""
element = _SourceLineElement(
TagLoader(tags.div(
tags.span(render="lineNumber"),
tags.span(render="sourceLine"))),
50, " print 'hello'")
d = flattenString(None, element)
expected = (
u"<div><span>50</span><span>"
u" \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}print 'hello'</span></div>")
d.addCallback(
self.assertEqual, expected.encode('utf-8'))
return d
def test_sourceFragmentElement(self):
"""
L{_SourceFragmentElement} renders source lines at and around the line
number indicated by a frame object.
"""
element = _SourceFragmentElement(
TagLoader(tags.div(
tags.span(render="lineNumber"),
tags.span(render="sourceLine"),
render="sourceLines")),
self.frame)
source = [
u' \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}message = '
u'"This is a problem"',
u' \N{NO-BREAK SPACE} \N{NO-BREAK SPACE}raise Exception(message)',
u'# Figure out the line number from which the exception will be '
u'raised.',
]
d = flattenString(None, element)
if _PY3:
stringToCheckFor = ''.join([
'<div class="snippet%sLine"><span>%d</span><span>%s</span>'
'</div>' % (
["", "Highlight"][lineNumber == 1],
self.base + lineNumber,
(u" \N{NO-BREAK SPACE}" * 4 + sourceLine))
for (lineNumber, sourceLine)
in enumerate(source)]).encode("utf8")
else:
stringToCheckFor = ''.join([
'<div class="snippet%sLine"><span>%d</span><span>%s</span>'
'</div>' % (
["", "Highlight"][lineNumber == 1],
self.base + lineNumber,
(u" \N{NO-BREAK SPACE}" * 4 + sourceLine).encode('utf8'))
for (lineNumber, sourceLine)
in enumerate(source)])
d.addCallback(self.assertEqual, stringToCheckFor)
return d
def test_frameElementFilename(self):
"""
The I{filename} renderer of L{_FrameElement} renders the filename
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="filename")),
self.frame)
d = flattenString(None, element)
d.addCallback(
# __file__ differs depending on whether an up-to-date .pyc file
# already existed.
self.assertEqual,
b"<span>" + networkString(__file__.rstrip('c')) + b"</span>")
return d
def test_frameElementLineNumber(self):
"""
The I{lineNumber} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="lineNumber")),
self.frame)
d = flattenString(None, element)
d.addCallback(
self.assertEqual, b"<span>" + intToBytes(self.base + 1) + b"</span>")
return d
def test_frameElementFunction(self):
"""
The I{function} renderer of L{_FrameElement} renders the line number
associated with the frame object used to initialize the
L{_FrameElement}.
"""
element = _FrameElement(
TagLoader(tags.span(render="function")),
self.frame)
d = flattenString(None, element)
d.addCallback(
self.assertEqual, b"<span>lineNumberProbeAlsoBroken</span>")
return d
def test_frameElementSource(self):
"""
The I{source} renderer of L{_FrameElement} renders the source code near
the source filename/line number associated with the frame object used to
initialize the L{_FrameElement}.
"""
element = _FrameElement(None, self.frame)
renderer = element.lookupRenderMethod("source")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _SourceFragmentElement)
self.assertIdentical(result.frame, self.frame)
self.assertEqual([tag], result.loader.load())
def test_stackElement(self):
"""
The I{frames} renderer of L{_StackElement} renders each stack frame in
the list of frames used to initialize the L{_StackElement}.
"""
element = _StackElement(None, self.failure.frames[:2])
renderer = element.lookupRenderMethod("frames")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, list)
self.assertIsInstance(result[0], _FrameElement)
self.assertIdentical(result[0].frame, self.failure.frames[0])
self.assertIsInstance(result[1], _FrameElement)
self.assertIdentical(result[1].frame, self.failure.frames[1])
# They must not share the same tag object.
self.assertNotEqual(result[0].loader.load(), result[1].loader.load())
self.assertEqual(2, len(result))
def test_failureElementTraceback(self):
"""
The I{traceback} renderer of L{FailureElement} renders the failure's
stack frames using L{_StackElement}.
"""
element = FailureElement(self.failure)
renderer = element.lookupRenderMethod("traceback")
tag = tags.div()
result = renderer(None, tag)
self.assertIsInstance(result, _StackElement)
self.assertIdentical(result.stackFrames, self.failure.frames)
self.assertEqual([tag], result.loader.load())
def test_failureElementType(self):
"""
The I{type} renderer of L{FailureElement} renders the failure's
exception type.
"""
element = FailureElement(
self.failure, TagLoader(tags.span(render="type")))
d = flattenString(None, element)
if _PY3:
exc = b"builtins.Exception"
else:
exc = b"exceptions.Exception"
d.addCallback(
self.assertEqual, b"<span>" + exc + b"</span>")
return d
def test_failureElementValue(self):
"""
The I{value} renderer of L{FailureElement} renders the value's exception
value.
"""
element = FailureElement(
self.failure, TagLoader(tags.span(render="value")))
d = flattenString(None, element)
d.addCallback(
self.assertEqual, b'<span>This is a problem</span>')
return d
class FormatFailureTests(TestCase):
"""
Tests for L{twisted.web.util.formatFailure} which returns an HTML string
representing the L{Failure} instance passed to it.
"""
def test_flattenerError(self):
"""
If there is an error flattening the L{Failure} instance,
L{formatFailure} raises L{FlattenerError}.
"""
self.assertRaises(FlattenerError, formatFailure, object())
def test_returnsBytes(self):
"""
The return value of L{formatFailure} is a C{str} instance (not a
C{unicode} instance) with numeric character references for any non-ASCII
characters meant to appear in the output.
"""
try:
raise Exception("Fake bug")
except:
result = formatFailure(Failure())
self.assertIsInstance(result, bytes)
if _PY3:
self.assertTrue(all(ch < 128 for ch in result))
else:
self.assertTrue(all(ord(ch) < 128 for ch in result))
# Indentation happens to rely on NO-BREAK SPACE
self.assertIn(b"&#160;", result)
class SDResource(resource.Resource):
def __init__(self,default):
self.default = default
def getChildWithDefault(self, name, request):
d = defer.succeed(self.default)
resource = util.DeferredResource(d)
return resource.getChildWithDefault(name, request)
class DeferredResourceTests(SynchronousTestCase):
"""
Tests for L{DeferredResource}.
"""
def testDeferredResource(self):
r = resource.Resource()
r.isLeaf = 1
s = SDResource(r)
d = DummyRequest(['foo', 'bar', 'baz'])
resource.getChildForRequest(s, d)
self.assertEqual(d.postpath, ['bar', 'baz'])
def test_render(self):
"""
L{DeferredResource} uses the request object's C{render} method to
render the resource which is the result of the L{Deferred} being
handled.
"""
rendered = []
request = DummyRequest([])
request.render = rendered.append
result = resource.Resource()
deferredResource = DeferredResource(defer.succeed(result))
deferredResource.render(request)
self.assertEqual(rendered, [result])
def test_renderNoFailure(self):
"""
If the L{Deferred} fails, L{DeferredResource} reports the failure via
C{processingFailed}, and does not cause an unhandled error to be
logged.
"""
request = DummyRequest([])
d = request.notifyFinish()
failure = Failure(RuntimeError())
deferredResource = DeferredResource(defer.fail(failure))
deferredResource.render(request)
self.assertEqual(self.failureResultOf(d), failure)
del deferredResource
gc.collect()
errors = self.flushLoggedErrors(RuntimeError)
self.assertEqual(errors, [])

View file

@ -0,0 +1,200 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.web.vhost}.
"""
from __future__ import absolute_import, division
from twisted.internet.defer import gatherResults
from twisted.trial.unittest import TestCase
from twisted.web.http import NOT_FOUND
from twisted.web.resource import NoResource
from twisted.web.static import Data
from twisted.web.server import Site
from twisted.web.vhost import (_HostResource,
NameVirtualHost,
VHostMonsterResource)
from twisted.web.test.test_web import DummyRequest
from twisted.web.test._util import _render
class HostResourceTests(TestCase):
"""
Tests for L{_HostResource}.
"""
def test_getChild(self):
"""
L{_HostResource.getChild} returns the proper I{Resource} for the vhost
embedded in the URL. Verify that returning the proper I{Resource}
required changing the I{Host} in the header.
"""
bazroot = Data(b'root data', "")
bazuri = Data(b'uri data', "")
baztest = Data(b'test data', "")
bazuri.putChild(b'test', baztest)
bazroot.putChild(b'uri', bazuri)
hr = _HostResource()
root = NameVirtualHost()
root.default = Data(b'default data', "")
root.addHost(b'baz.com', bazroot)
request = DummyRequest([b'uri', b'test'])
request.prepath = [b'bar', b'http', b'baz.com']
request.site = Site(root)
request.isSecure = lambda: False
request.host = b''
step = hr.getChild(b'baz.com', request) # Consumes rest of path
self.assertIsInstance(step, Data)
request = DummyRequest([b'uri', b'test'])
step = root.getChild(b'uri', request)
self.assertIsInstance(step, NoResource)
class NameVirtualHostTests(TestCase):
"""
Tests for L{NameVirtualHost}.
"""
def test_renderWithoutHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not L{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data(b"correct result", "")
request = DummyRequest([''])
self.assertEqual(
virtualHostResource.render(request), b"correct result")
def test_renderWithoutHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is L{None} and there is no I{Host}
header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([''])
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_renderWithHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the resource
which is the value in the instance's C{host} dictionary corresponding
to the key indicated by the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.addHost(b'example.org', Data(b"winner", ""))
request = DummyRequest([b''])
request.requestHeaders.addRawHeader(b'host', b'example.org')
d = _render(virtualHostResource, request)
def cbRendered(ignored, request):
self.assertEqual(b''.join(request.written), b"winner")
d.addCallback(cbRendered, request)
# The port portion of the Host header should not be considered.
requestWithPort = DummyRequest([b''])
requestWithPort.requestHeaders.addRawHeader(b'host', b'example.org:8000')
dWithPort = _render(virtualHostResource, requestWithPort)
def cbRendered(ignored, requestWithPort):
self.assertEqual(b''.join(requestWithPort.written), b"winner")
dWithPort.addCallback(cbRendered, requestWithPort)
return gatherResults([d, dWithPort])
def test_renderWithUnknownHost(self):
"""
L{NameVirtualHost.render} returns the result of rendering the
instance's C{default} if it is not L{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
virtualHostResource.default = Data(b"correct data", "")
request = DummyRequest([b''])
request.requestHeaders.addRawHeader(b'host', b'example.com')
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(b''.join(request.written), b"correct data")
d.addCallback(cbRendered)
return d
def test_renderWithUnknownHostNoDefault(self):
"""
L{NameVirtualHost.render} returns a response with a status of I{NOT
FOUND} if the instance's C{default} is L{None} and there is no host
matching the value of the I{Host} header in the request.
"""
virtualHostResource = NameVirtualHost()
request = DummyRequest([''])
request.requestHeaders.addRawHeader(b'host', b'example.com')
d = _render(virtualHostResource, request)
def cbRendered(ignored):
self.assertEqual(request.responseCode, NOT_FOUND)
d.addCallback(cbRendered)
return d
def test_getChild(self):
"""
L{NameVirtualHost.getChild} returns correct I{Resource} based off
the header and modifies I{Request} to ensure proper prepath and
postpath are set.
"""
virtualHostResource = NameVirtualHost()
leafResource = Data(b"leaf data", "")
leafResource.isLeaf = True
normResource = Data(b"norm data", "")
virtualHostResource.addHost(b'leaf.example.org', leafResource)
virtualHostResource.addHost(b'norm.example.org', normResource)
request = DummyRequest([])
request.requestHeaders.addRawHeader(b'host', b'norm.example.org')
request.prepath = [b'']
self.assertIsInstance(virtualHostResource.getChild(b'', request),
NoResource)
self.assertEqual(request.prepath, [b''])
self.assertEqual(request.postpath, [])
request = DummyRequest([])
request.requestHeaders.addRawHeader(b'host', b'leaf.example.org')
request.prepath = [b'']
self.assertIsInstance(virtualHostResource.getChild(b'', request),
Data)
self.assertEqual(request.prepath, [])
self.assertEqual(request.postpath, [b''])
class VHostMonsterResourceTests(TestCase):
"""
Tests for L{VHostMonsterResource}.
"""
def test_getChild(self):
"""
L{VHostMonsterResource.getChild} returns I{_HostResource} and modifies
I{Request} with correct L{Request.isSecure}.
"""
vhm = VHostMonsterResource()
request = DummyRequest([])
self.assertIsInstance(vhm.getChild(b'http', request), _HostResource)
self.assertFalse(request.isSecure())
request = DummyRequest([])
self.assertIsInstance(vhm.getChild(b'https', request), _HostResource)
self.assertTrue(request.isSecure())

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,28 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The L{_response} module contains constants for all standard HTTP codes, along
with a mapping to the corresponding phrases.
"""
from __future__ import division, absolute_import
import string
from twisted.trial import unittest
from twisted.web import _responses
class ResponseTests(unittest.TestCase):
def test_constants(self):
"""
All constants besides C{RESPONSES} defined in L{_response} are
integers and are keys in C{RESPONSES}.
"""
for sym in dir(_responses):
if sym == 'RESPONSES':
continue
if all((c == '_' or c in string.ascii_uppercase) for c in sym):
val = getattr(_responses, sym)
self.assertIsInstance(val, int)
self.assertIn(val, _responses.RESPONSES)

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,915 @@
# -*- test-case-name: twisted.web.test.test_xmlrpc -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for XML-RPC support in L{twisted.web.xmlrpc}.
"""
from __future__ import division, absolute_import
from twisted.python.compat import nativeString, networkString, NativeStringIO
from io import BytesIO
import datetime
from twisted.trial import unittest
from twisted.web import xmlrpc
from twisted.web.xmlrpc import XMLRPC, payloadTemplate, addIntrospection
from twisted.web.xmlrpc import _QueryFactory, withRequest, xmlrpclib
from twisted.web import server, client, http, static
from twisted.internet import reactor, defer
from twisted.internet.error import ConnectionDone
from twisted.python import failure
from twisted.python.reflect import namedModule
from twisted.test.proto_helpers import MemoryReactor, EventLoggingObserver
from twisted.web.test.test_web import DummyRequest
from twisted.logger import (globalLogPublisher, FilteringLogObserver,
LogLevelFilterPredicate, LogLevel)
try:
namedModule('twisted.internet.ssl')
except ImportError:
sslSkip = "OpenSSL not present"
else:
sslSkip = None
class AsyncXMLRPCTests(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of Deferreds.
"""
def setUp(self):
self.request = DummyRequest([''])
self.request.method = 'POST'
self.request.content = NativeStringIO(
payloadTemplate % ('async', xmlrpclib.dumps(())))
result = self.result = defer.Deferred()
class AsyncResource(XMLRPC):
def xmlrpc_async(self):
return result
self.resource = AsyncResource()
def test_deferredResponse(self):
"""
If an L{XMLRPC} C{xmlrpc_*} method returns a L{defer.Deferred}, the
response to the request is the result of that L{defer.Deferred}.
"""
self.resource.render(self.request)
self.assertEqual(self.request.written, [])
self.result.callback("result")
resp = xmlrpclib.loads(b"".join(self.request.written))
self.assertEqual(resp, (('result',), None))
self.assertEqual(self.request.finished, 1)
def test_interruptedDeferredResponse(self):
"""
While waiting for the L{Deferred} returned by an L{XMLRPC} C{xmlrpc_*}
method to fire, the connection the request was issued over may close.
If this happens, neither C{write} nor C{finish} is called on the
request.
"""
self.resource.render(self.request)
self.request.processingFailed(
failure.Failure(ConnectionDone("Simulated")))
self.result.callback("result")
self.assertEqual(self.request.written, [])
self.assertEqual(self.request.finished, 0)
class TestRuntimeError(RuntimeError):
pass
class TestValueError(ValueError):
pass
class Test(XMLRPC):
# If you add xmlrpc_ methods to this class, go change test_listMethods
# below.
FAILURE = 666
NOT_FOUND = 23
SESSION_EXPIRED = 42
def xmlrpc_echo(self, arg):
return arg
# the doc string is part of the test
def xmlrpc_add(self, a, b):
"""
This function add two numbers.
"""
return a + b
xmlrpc_add.signature = [['int', 'int', 'int'],
['double', 'double', 'double']]
# the doc string is part of the test
def xmlrpc_pair(self, string, num):
"""
This function puts the two arguments in an array.
"""
return [string, num]
xmlrpc_pair.signature = [['array', 'string', 'int']]
# the doc string is part of the test
def xmlrpc_defer(self, x):
"""Help for defer."""
return defer.succeed(x)
def xmlrpc_deferFail(self):
return defer.fail(TestValueError())
# don't add a doc string, it's part of the test
def xmlrpc_fail(self):
raise TestRuntimeError
def xmlrpc_fault(self):
return xmlrpc.Fault(12, "hello")
def xmlrpc_deferFault(self):
return defer.fail(xmlrpc.Fault(17, "hi"))
def xmlrpc_snowman(self, payload):
"""
Used to test that we can pass Unicode.
"""
snowman = u"\u2603"
if snowman != payload:
return xmlrpc.Fault(13, "Payload not unicode snowman")
return snowman
def xmlrpc_complex(self):
return {"a": ["b", "c", 12, []], "D": "foo"}
def xmlrpc_dict(self, map, key):
return map[key]
xmlrpc_dict.help = 'Help for dict.'
@withRequest
def xmlrpc_withRequest(self, request, other):
"""
A method decorated with L{withRequest} which can be called by
a test to verify that the request object really is passed as
an argument.
"""
return (
# as a proof that request is a request
request.method +
# plus proof other arguments are still passed along
' ' + other)
def lookupProcedure(self, procedurePath):
try:
return XMLRPC.lookupProcedure(self, procedurePath)
except xmlrpc.NoSuchFunction:
if procedurePath.startswith("SESSION"):
raise xmlrpc.Fault(self.SESSION_EXPIRED,
"Session non-existent/expired.")
else:
raise
class TestLookupProcedure(XMLRPC):
"""
This is a resource which customizes procedure lookup to be used by the tests
of support for this customization.
"""
def echo(self, x):
return x
def lookupProcedure(self, procedureName):
"""
Lookup a procedure from a fixed set of choices, either I{echo} or
I{system.listeMethods}.
"""
if procedureName == 'echo':
return self.echo
raise xmlrpc.NoSuchFunction(
self.NOT_FOUND, 'procedure %s not found' % (procedureName,))
class TestListProcedures(XMLRPC):
"""
This is a resource which customizes procedure enumeration to be used by the
tests of support for this customization.
"""
def listProcedures(self):
"""
Return a list of a single method this resource will claim to support.
"""
return ['foo']
class TestAuthHeader(Test):
"""
This is used to get the header info so that we can test
authentication.
"""
def __init__(self):
Test.__init__(self)
self.request = None
def render(self, request):
self.request = request
return Test.render(self, request)
def xmlrpc_authinfo(self):
return self.request.getUser(), self.request.getPassword()
class TestQueryProtocol(xmlrpc.QueryProtocol):
"""
QueryProtocol for tests that saves headers received and sent,
inside the factory.
"""
def connectionMade(self):
self.factory.transport = self.transport
xmlrpc.QueryProtocol.connectionMade(self)
def handleHeader(self, key, val):
self.factory.headers[key.lower()] = val
def sendHeader(self, key, val):
"""
Keep sent headers so we can inspect them later.
"""
self.factory.sent_headers[key.lower()] = val
xmlrpc.QueryProtocol.sendHeader(self, key, val)
class TestQueryFactory(xmlrpc._QueryFactory):
"""
QueryFactory using L{TestQueryProtocol} for saving headers.
"""
protocol = TestQueryProtocol
def __init__(self, *args, **kwargs):
self.headers = {}
self.sent_headers = {}
xmlrpc._QueryFactory.__init__(self, *args, **kwargs)
class TestQueryFactoryCancel(xmlrpc._QueryFactory):
"""
QueryFactory that saves a reference to the
L{twisted.internet.interfaces.IConnector} to test connection lost.
"""
def startedConnecting(self, connector):
self.connector = connector
class XMLRPCTests(unittest.TestCase):
def setUp(self):
self.p = reactor.listenTCP(0, server.Site(Test()),
interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def tearDown(self):
self.factories = []
return self.p.stopListening()
def queryFactory(self, *args, **kwargs):
"""
Specific queryFactory for proxy that uses our custom
L{TestQueryFactory}, and save factories.
"""
factory = TestQueryFactory(*args, **kwargs)
self.factories.append(factory)
return factory
def proxy(self, factory=None):
"""
Return a new xmlrpc.Proxy for the test site created in
setUp(), using the given factory as the queryFactory, or
self.queryFactory if no factory is provided.
"""
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" % self.port))
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
def test_results(self):
inputOutput = [
("add", (2, 3), 5),
("defer", ("a",), "a"),
("dict", ({"a": 1}, "a"), 1),
("pair", ("a", 1), ["a", 1]),
("snowman", (u"\u2603"), u"\u2603"),
("complex", (), {"a": ["b", "c", 12, []], "D": "foo"})]
dl = []
for meth, args, outp in inputOutput:
d = self.proxy().callRemote(meth, *args)
d.addCallback(self.assertEqual, outp)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_headers(self):
"""
Verify that headers sent from the client side and the ones we
get back from the server side are correct.
"""
d = self.proxy().callRemote("snowman", u"\u2603")
def check_server_headers(ing):
self.assertEqual(
self.factories[0].headers[b'content-type'],
b'text/xml; charset=utf-8')
self.assertEqual(
self.factories[0].headers[b'content-length'], b'129')
def check_client_headers(ign):
self.assertEqual(
self.factories[0].sent_headers[b'user-agent'],
b'Twisted/XMLRPClib')
self.assertEqual(
self.factories[0].sent_headers[b'content-type'],
b'text/xml; charset=utf-8')
self.assertEqual(
self.factories[0].sent_headers[b'content-length'], b'155')
d.addCallback(check_server_headers)
d.addCallback(check_client_headers)
return d
def test_errors(self):
"""
Verify that for each way a method exposed via XML-RPC can fail, the
correct 'Content-type' header is set in the response and that the
client-side Deferred is errbacked with an appropriate C{Fault}
instance.
"""
logObserver = EventLoggingObserver()
filtered = FilteringLogObserver(
logObserver,
[LogLevelFilterPredicate(defaultLogLevel=LogLevel.critical)]
)
globalLogPublisher.addObserver(filtered)
self.addCleanup(lambda: globalLogPublisher.removeObserver(filtered))
dl = []
for code, methodName in [(666, "fail"), (666, "deferFail"),
(12, "fault"), (23, "noSuchMethod"),
(17, "deferFault"), (42, "SESSION_TEST")]:
d = self.proxy().callRemote(methodName)
d = self.assertFailure(d, xmlrpc.Fault)
d.addCallback(lambda exc, code=code:
self.assertEqual(exc.faultCode, code))
dl.append(d)
d = defer.DeferredList(dl, fireOnOneErrback=True)
def cb(ign):
for factory in self.factories:
self.assertEqual(factory.headers[b'content-type'],
b'text/xml; charset=utf-8')
self.assertEquals(2, len(logObserver))
f1 = logObserver[0]["log_failure"].value
f2 = logObserver[1]["log_failure"].value
if isinstance(f1, TestValueError):
self.assertIsInstance(f2, TestRuntimeError)
else:
self.assertIsInstance(f1, TestRuntimeError)
self.assertIsInstance(f2, TestValueError)
self.flushLoggedErrors(TestRuntimeError, TestValueError)
d.addCallback(cb)
return d
def test_cancel(self):
"""
A deferred from the Proxy can be cancelled, disconnecting
the L{twisted.internet.interfaces.IConnector}.
"""
def factory(*args, **kw):
factory.f = TestQueryFactoryCancel(*args, **kw)
return factory.f
d = self.proxy(factory).callRemote('add', 2, 3)
self.assertNotEqual(factory.f.connector.state, "disconnected")
d.cancel()
self.assertEqual(factory.f.connector.state, "disconnected")
d = self.assertFailure(d, defer.CancelledError)
return d
def test_errorGet(self):
"""
A classic GET on the xml server should return a NOT_ALLOWED.
"""
agent = client.Agent(reactor)
d = agent.request(b"GET", networkString("http://127.0.0.1:%d/" % (self.port,)))
def checkResponse(response):
self.assertEqual(response.code, http.NOT_ALLOWED)
d.addCallback(checkResponse)
return d
def test_errorXMLContent(self):
"""
Test that an invalid XML input returns an L{xmlrpc.Fault}.
"""
agent = client.Agent(reactor)
d = agent.request(
uri=networkString("http://127.0.0.1:%d/" % (self.port,)),
method=b"POST",
bodyProducer=client.FileBodyProducer(BytesIO(b"foo")))
d.addCallback(client.readBody)
def cb(result):
self.assertRaises(xmlrpc.Fault, xmlrpclib.loads, result)
d.addCallback(cb)
return d
def test_datetimeRoundtrip(self):
"""
If an L{xmlrpclib.DateTime} is passed as an argument to an XML-RPC
call and then returned by the server unmodified, the result should
be equal to the original object.
"""
when = xmlrpclib.DateTime()
d = self.proxy().callRemote("echo", when)
d.addCallback(self.assertEqual, when)
return d
def test_doubleEncodingError(self):
"""
If it is not possible to encode a response to the request (for example,
because L{xmlrpclib.dumps} raises an exception when encoding a
L{Fault}) the exception which prevents the response from being
generated is logged and the request object is finished anyway.
"""
logObserver = EventLoggingObserver()
filtered = FilteringLogObserver(
logObserver,
[LogLevelFilterPredicate(defaultLogLevel=LogLevel.critical)]
)
globalLogPublisher.addObserver(filtered)
self.addCleanup(lambda: globalLogPublisher.removeObserver(filtered))
d = self.proxy().callRemote("echo", "")
# *Now* break xmlrpclib.dumps. Hopefully the client already used it.
def fakeDumps(*args, **kwargs):
raise RuntimeError("Cannot encode anything at all!")
self.patch(xmlrpclib, 'dumps', fakeDumps)
# It doesn't matter how it fails, so long as it does. Also, it happens
# to fail with an implementation detail exception right now, not
# something suitable as part of a public interface.
d = self.assertFailure(d, Exception)
def cbFailed(ignored):
# The fakeDumps exception should have been logged.
self.assertEquals(1, len(logObserver))
self.assertIsInstance(
logObserver[0]["log_failure"].value,
RuntimeError
)
self.assertEqual(len(self.flushLoggedErrors(RuntimeError)), 1)
d.addCallback(cbFailed)
return d
def test_closeConnectionAfterRequest(self):
"""
The connection to the web server is closed when the request is done.
"""
d = self.proxy().callRemote('echo', '')
def responseDone(ignored):
[factory] = self.factories
self.assertFalse(factory.transport.connected)
self.assertTrue(factory.transport.disconnected)
return d.addCallback(responseDone)
def test_tcpTimeout(self):
"""
For I{HTTP} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectTCP call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy(b"http://127.0.0.1:69", connectTimeout=2.0,
reactor=reactor)
proxy.callRemote("someMethod")
self.assertEqual(reactor.tcpClients[0][3], 2.0)
def test_sslTimeout(self):
"""
For I{HTTPS} URIs, L{xmlrpc.Proxy.callRemote} passes the value it
received for the C{connectTimeout} parameter as the C{timeout} argument
to the underlying connectSSL call.
"""
reactor = MemoryReactor()
proxy = xmlrpc.Proxy(b"https://127.0.0.1:69", connectTimeout=3.0,
reactor=reactor)
proxy.callRemote("someMethod")
self.assertEqual(reactor.sslClients[0][4], 3.0)
test_sslTimeout.skip = sslSkip
class XMLRPCProxyWithoutSlashTests(XMLRPCTests):
"""
Test with proxy that doesn't add a slash.
"""
def proxy(self, factory=None):
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d" % self.port))
if factory is None:
p.queryFactory = self.queryFactory
else:
p.queryFactory = factory
return p
class XMLRPCPublicLookupProcedureTests(unittest.TestCase):
"""
Tests for L{XMLRPC}'s support of subclasses which override
C{lookupProcedure} and C{listProcedures}.
"""
def createServer(self, resource):
self.p = reactor.listenTCP(
0, server.Site(resource), interface="127.0.0.1")
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy(
networkString('http://127.0.0.1:%d' % self.port))
def test_lookupProcedure(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to find
procedures that are not defined using a C{xmlrpc_}-prefixed method name.
"""
self.createServer(TestLookupProcedure())
what = "hello"
d = self.proxy.callRemote("echo", what)
d.addCallback(self.assertEqual, what)
return d
def test_errors(self):
"""
A subclass of L{XMLRPC} can override C{lookupProcedure} to raise
L{NoSuchFunction} to indicate that a requested method is not available
to be called, signalling a fault to the XML-RPC client.
"""
self.createServer(TestLookupProcedure())
d = self.proxy.callRemote("xxxx", "hello")
d = self.assertFailure(d, xmlrpc.Fault)
return d
def test_listMethods(self):
"""
A subclass of L{XMLRPC} can override C{listProcedures} to define
Overriding listProcedures should prevent introspection from being
broken.
"""
resource = TestListProcedures()
addIntrospection(resource)
self.createServer(resource)
d = self.proxy.callRemote("system.listMethods")
def listed(procedures):
# The list will also include other introspection procedures added by
# addIntrospection. We just want to see "foo" from our customized
# listProcedures.
self.assertIn('foo', procedures)
d.addCallback(listed)
return d
class SerializationConfigMixin:
"""
Mixin which defines a couple tests which should pass when a particular flag
is passed to L{XMLRPC}.
These are not meant to be exhaustive serialization tests, since L{xmlrpclib}
does all of the actual serialization work. They are just meant to exercise
a few codepaths to make sure we are calling into xmlrpclib correctly.
@ivar flagName: A C{str} giving the name of the flag which must be passed to
L{XMLRPC} to allow the tests to pass. Subclasses should set this.
@ivar value: A value which the specified flag will allow the serialization
of. Subclasses should set this.
"""
def setUp(self):
"""
Create a new XML-RPC server with C{allowNone} set to C{True}.
"""
kwargs = {self.flagName: True}
self.p = reactor.listenTCP(
0, server.Site(Test(**kwargs)), interface="127.0.0.1")
self.addCleanup(self.p.stopListening)
self.port = self.p.getHost().port
self.proxy = xmlrpc.Proxy(
networkString("http://127.0.0.1:%d/" % (self.port,)), **kwargs)
def test_roundtripValue(self):
"""
C{self.value} can be round-tripped over an XMLRPC method call/response.
"""
d = self.proxy.callRemote('defer', self.value)
d.addCallback(self.assertEqual, self.value)
return d
def test_roundtripNestedValue(self):
"""
A C{dict} which contains C{self.value} can be round-tripped over an
XMLRPC method call/response.
"""
d = self.proxy.callRemote('defer', {'a': self.value})
d.addCallback(self.assertEqual, {'a': self.value})
return d
class XMLRPCAllowNoneTests(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing L{None} when the C{allowNone} flag is set.
"""
flagName = "allowNone"
value = None
class XMLRPCUseDateTimeTests(SerializationConfigMixin, unittest.TestCase):
"""
Tests for passing a C{datetime.datetime} instance when the C{useDateTime}
flag is set.
"""
flagName = "useDateTime"
value = datetime.datetime(2000, 12, 28, 3, 45, 59)
class XMLRPCAuthenticatedTests(XMLRPCTests):
"""
Test with authenticated proxy. We run this with the same input/output as
above.
"""
user = b"username"
password = b"asecret"
def setUp(self):
self.p = reactor.listenTCP(0, server.Site(TestAuthHeader()),
interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def test_authInfoInURL(self):
url = "http://%s:%s@127.0.0.1:%d/" % (
nativeString(self.user), nativeString(self.password), self.port)
p = xmlrpc.Proxy(networkString(url))
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_explicitAuthInfo(self):
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" % (
self.port,)), self.user, self.password)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
def test_longPassword(self):
"""
C{QueryProtocol} uses the C{base64.b64encode} function to encode user
name and password in the I{Authorization} header, so that it doesn't
embed new lines when using long inputs.
"""
longPassword = self.password * 40
p = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" % (
self.port,)), self.user, longPassword)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, longPassword])
return d
def test_explicitAuthInfoOverride(self):
p = xmlrpc.Proxy(networkString("http://wrong:info@127.0.0.1:%d/" % (
self.port,)), self.user, self.password)
d = p.callRemote("authinfo")
d.addCallback(self.assertEqual, [self.user, self.password])
return d
class XMLRPCIntrospectionTests(XMLRPCTests):
def setUp(self):
xmlrpc = Test()
addIntrospection(xmlrpc)
self.p = reactor.listenTCP(0, server.Site(xmlrpc),interface="127.0.0.1")
self.port = self.p.getHost().port
self.factories = []
def test_listMethods(self):
def cbMethods(meths):
meths.sort()
self.assertEqual(
meths,
['add', 'complex', 'defer', 'deferFail',
'deferFault', 'dict', 'echo', 'fail', 'fault',
'pair', 'snowman', 'system.listMethods',
'system.methodHelp',
'system.methodSignature', 'withRequest'])
d = self.proxy().callRemote("system.listMethods")
d.addCallback(cbMethods)
return d
def test_methodHelp(self):
inputOutputs = [
("defer", "Help for defer."),
("fail", ""),
("dict", "Help for dict.")]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodHelp", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
def test_methodSignature(self):
inputOutputs = [
("defer", ""),
("add", [['int', 'int', 'int'],
['double', 'double', 'double']]),
("pair", [['array', 'string', 'int']])]
dl = []
for meth, expected in inputOutputs:
d = self.proxy().callRemote("system.methodSignature", meth)
d.addCallback(self.assertEqual, expected)
dl.append(d)
return defer.DeferredList(dl, fireOnOneErrback=True)
class XMLRPCClientErrorHandlingTests(unittest.TestCase):
"""
Test error handling on the xmlrpc client.
"""
def setUp(self):
self.resource = static.Data(
b"This text is not a valid XML-RPC response.",
b"text/plain")
self.resource.isLeaf = True
self.port = reactor.listenTCP(0, server.Site(self.resource),
interface='127.0.0.1')
def tearDown(self):
return self.port.stopListening()
def test_erroneousResponse(self):
"""
Test that calling the xmlrpc client on a static http server raises
an exception.
"""
proxy = xmlrpc.Proxy(networkString("http://127.0.0.1:%d/" %
(self.port.getHost().port,)))
return self.assertFailure(proxy.callRemote("someMethod"), ValueError)
class QueryFactoryParseResponseTests(unittest.TestCase):
"""
Test the behaviour of L{_QueryFactory.parseResponse}.
"""
def setUp(self):
# The _QueryFactory that we are testing. We don't care about any
# of the constructor parameters.
self.queryFactory = _QueryFactory(
path=None, host=None, method='POST', user=None, password=None,
allowNone=False, args=())
# An XML-RPC response that will parse without raising an error.
self.goodContents = xmlrpclib.dumps(('',))
# An 'XML-RPC response' that will raise a parsing error.
self.badContents = 'invalid xml'
# A dummy 'reason' to pass to clientConnectionLost. We don't care
# what it is.
self.reason = failure.Failure(ConnectionDone())
def test_parseResponseCallbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as a callback
of L{_QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addCallback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.goodContents)
return d
def test_parseResponseErrbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as an errback
of L{_QueryFactory.parseResponse}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.parseResponse(self.badContents)
return d
def test_badStatusErrbackSafety(self):
"""
We can safely call L{_QueryFactory.clientConnectionLost} as an errback
of L{_QueryFactory.badStatus}.
"""
d = self.queryFactory.deferred
# The failure mode is that this callback raises an AlreadyCalled
# error. We have to add it now so that it gets called synchronously
# and triggers the race condition.
d.addErrback(self.queryFactory.clientConnectionLost, self.reason)
self.queryFactory.badStatus('status', 'message')
return d
def test_parseResponseWithoutData(self):
"""
Some server can send a response without any data:
L{_QueryFactory.parseResponse} should catch the error and call the
result errback.
"""
content = """
<methodResponse>
<params>
<param>
</param>
</params>
</methodResponse>"""
d = self.queryFactory.deferred
self.queryFactory.parseResponse(content)
return self.assertFailure(d, IndexError)
class XMLRPCWithRequestTests(unittest.TestCase):
def setUp(self):
self.resource = Test()
def test_withRequest(self):
"""
When an XML-RPC method is called and the implementation is
decorated with L{withRequest}, the request object is passed as
the first argument.
"""
request = DummyRequest('/RPC2')
request.method = "POST"
request.content = NativeStringIO(xmlrpclib.dumps(
("foo",), 'withRequest'))
def valid(n, request):
data = xmlrpclib.loads(request.written[0])
self.assertEqual(data, (('POST foo',), None))
d = request.notifyFinish().addCallback(valid, request)
self.resource.render_POST(request)
return d

View file

@ -0,0 +1,321 @@
# -*- test-case-name: twisted.web.test.test_cgi -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I hold resource classes and helper classes that deal with CGI scripts.
"""
# System Imports
import os
import urllib
# Twisted Imports
from twisted.internet import protocol
from twisted.logger import Logger
from twisted.python import filepath
from twisted.spread import pb
from twisted.web import http, resource, server, static
class CGIDirectory(resource.Resource, filepath.FilePath):
def __init__(self, pathname):
resource.Resource.__init__(self)
filepath.FilePath.__init__(self, pathname)
def getChild(self, path, request):
fnp = self.child(path)
if not fnp.exists():
return static.File.childNotFound
elif fnp.isdir():
return CGIDirectory(fnp.path)
else:
return CGIScript(fnp.path)
return resource.NoResource()
def render(self, request):
notFound = resource.NoResource(
"CGI directories do not support directory listing.")
return notFound.render(request)
class CGIScript(resource.Resource):
"""
L{CGIScript} is a resource which runs child processes according to the CGI
specification.
The implementation is complex due to the fact that it requires asynchronous
IPC with an external process with an unpleasant protocol.
"""
isLeaf = 1
def __init__(self, filename, registry=None, reactor=None):
"""
Initialize, with the name of a CGI script file.
"""
self.filename = filename
if reactor is None:
# This installs a default reactor, if None was installed before.
# We do a late import here, so that importing the current module
# won't directly trigger installing a default reactor.
from twisted.internet import reactor
self._reactor = reactor
def render(self, request):
"""
Do various things to conform to the CGI specification.
I will set up the usual slew of environment variables, then spin off a
process.
@type request: L{twisted.web.http.Request}
@param request: An HTTP request.
"""
scriptName = b"/" + b"/".join(request.prepath)
serverName = request.getRequestHostname().split(b':')[0]
env = {"SERVER_SOFTWARE": server.version,
"SERVER_NAME": serverName,
"GATEWAY_INTERFACE": "CGI/1.1",
"SERVER_PROTOCOL": request.clientproto,
"SERVER_PORT": str(request.getHost().port),
"REQUEST_METHOD": request.method,
"SCRIPT_NAME": scriptName,
"SCRIPT_FILENAME": self.filename,
"REQUEST_URI": request.uri}
ip = request.getClientAddress().host
if ip is not None:
env['REMOTE_ADDR'] = ip
pp = request.postpath
if pp:
env["PATH_INFO"] = "/" + "/".join(pp)
if hasattr(request, "content"):
# 'request.content' is either a StringIO or a TemporaryFile, and
# the file pointer is sitting at the beginning (seek(0,0))
request.content.seek(0, 2)
length = request.content.tell()
request.content.seek(0, 0)
env['CONTENT_LENGTH'] = str(length)
try:
qindex = request.uri.index(b'?')
except ValueError:
env['QUERY_STRING'] = ''
qargs = []
else:
qs = env['QUERY_STRING'] = request.uri[qindex+1:]
if '=' in qs:
qargs = []
else:
qargs = [urllib.unquote(x) for x in qs.split('+')]
# Propagate HTTP headers
for title, header in request.getAllHeaders().items():
envname = title.replace(b'-', b'_').upper()
if title not in (b'content-type', b'content-length', b'proxy'):
envname = b"HTTP_" + envname
env[envname] = header
# Propagate our environment
for key, value in os.environ.items():
if key not in env:
env[key] = value
# And they're off!
self.runProcess(env, request, qargs)
return server.NOT_DONE_YET
def runProcess(self, env, request, qargs=[]):
"""
Run the cgi script.
@type env: A L{dict} of L{str}, or L{None}
@param env: The environment variables to pass to the process that will
get spawned. See
L{twisted.internet.interfaces.IReactorProcess.spawnProcess} for
more information about environments and process creation.
@type request: L{twisted.web.http.Request}
@param request: An HTTP request.
@type qargs: A L{list} of L{str}
@param qargs: The command line arguments to pass to the process that
will get spawned.
"""
p = CGIProcessProtocol(request)
self._reactor.spawnProcess(p, self.filename, [self.filename] + qargs,
env, os.path.dirname(self.filename))
class FilteredScript(CGIScript):
"""
I am a special version of a CGI script, that uses a specific executable.
This is useful for interfacing with other scripting languages that adhere
to the CGI standard. My C{filter} attribute specifies what executable to
run, and my C{filename} init parameter describes which script to pass to
the first argument of that script.
To customize me for a particular location of a CGI interpreter, override
C{filter}.
@type filter: L{str}
@ivar filter: The absolute path to the executable.
"""
filter = '/usr/bin/cat'
def runProcess(self, env, request, qargs=[]):
"""
Run a script through the C{filter} executable.
@type env: A L{dict} of L{str}, or L{None}
@param env: The environment variables to pass to the process that will
get spawned. See
L{twisted.internet.interfaces.IReactorProcess.spawnProcess}
for more information about environments and process creation.
@type request: L{twisted.web.http.Request}
@param request: An HTTP request.
@type qargs: A L{list} of L{str}
@param qargs: The command line arguments to pass to the process that
will get spawned.
"""
p = CGIProcessProtocol(request)
self._reactor.spawnProcess(p, self.filter,
[self.filter, self.filename] + qargs, env,
os.path.dirname(self.filename))
class CGIProcessProtocol(protocol.ProcessProtocol, pb.Viewable):
handling_headers = 1
headers_written = 0
headertext = b''
errortext = b''
_log = Logger()
# Remotely relay producer interface.
def view_resumeProducing(self, issuer):
self.resumeProducing()
def view_pauseProducing(self, issuer):
self.pauseProducing()
def view_stopProducing(self, issuer):
self.stopProducing()
def resumeProducing(self):
self.transport.resumeProducing()
def pauseProducing(self):
self.transport.pauseProducing()
def stopProducing(self):
self.transport.loseConnection()
def __init__(self, request):
self.request = request
def connectionMade(self):
self.request.registerProducer(self, 1)
self.request.content.seek(0, 0)
content = self.request.content.read()
if content:
self.transport.write(content)
self.transport.closeStdin()
def errReceived(self, error):
self.errortext = self.errortext + error
def outReceived(self, output):
"""
Handle a chunk of input
"""
# First, make sure that the headers from the script are sorted
# out (we'll want to do some parsing on these later.)
if self.handling_headers:
text = self.headertext + output
headerEnds = []
for delimiter in b'\n\n', b'\r\n\r\n', b'\r\r', b'\n\r\n':
headerend = text.find(delimiter)
if headerend != -1:
headerEnds.append((headerend, delimiter))
if headerEnds:
# The script is entirely in control of response headers;
# disable the default Content-Type value normally provided by
# twisted.web.server.Request.
self.request.defaultContentType = None
headerEnds.sort()
headerend, delimiter = headerEnds[0]
self.headertext = text[:headerend]
# This is a final version of the header text.
linebreak = delimiter[:len(delimiter)//2]
headers = self.headertext.split(linebreak)
for header in headers:
br = header.find(b': ')
if br == -1:
self._log.error(
'ignoring malformed CGI header: {header!r}',
header=header)
else:
headerName = header[:br].lower()
headerText = header[br+2:]
if headerName == b'location':
self.request.setResponseCode(http.FOUND)
if headerName == b'status':
try:
# "XXX <description>" sometimes happens.
statusNum = int(headerText[:3])
except:
self._log.error("malformed status header")
else:
self.request.setResponseCode(statusNum)
else:
# Don't allow the application to control
# these required headers.
if headerName.lower() not in (b'server', b'date'):
self.request.responseHeaders.addRawHeader(
headerName, headerText)
output = text[headerend+len(delimiter):]
self.handling_headers = 0
if self.handling_headers:
self.headertext = text
if not self.handling_headers:
self.request.write(output)
def processEnded(self, reason):
if reason.value.exitCode != 0:
self._log.error("CGI {uri} exited with exit code {exitCode}",
uri=self.request.uri, exitCode=reason.value.exitCode)
if self.errortext:
self._log.error("Errors from CGI {uri}: {errorText}",
uri=self.request.uri, errorText=self.errortext)
if self.handling_headers:
self._log.error("Premature end of headers in {uri}: {headerText}",
uri=self.request.uri, headerText=self.headertext)
self.request.write(
resource.ErrorPage(http.INTERNAL_SERVER_ERROR,
"CGI Script Error",
"Premature end of script headers.").render(self.request))
self.request.unregisterProducer()
self.request.finish()

View file

@ -0,0 +1,443 @@
# -*- test-case-name: twisted.web.test.test_util -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An assortment of web server-related utilities.
"""
from __future__ import division, absolute_import
import linecache
from twisted.python import urlpath
from twisted.python.compat import _PY3, unicode, nativeString, escape
from twisted.python.reflect import fullyQualifiedName
from twisted.web import resource
from twisted.web.template import TagLoader, XMLString, Element, renderer
from twisted.web.template import flattenString
def _PRE(text):
"""
Wraps <pre> tags around some text and HTML-escape it.
This is here since once twisted.web.html was deprecated it was hard to
migrate the html.PRE from current code to twisted.web.template.
For new code consider using twisted.web.template.
@return: Escaped text wrapped in <pre> tags.
@rtype: C{str}
"""
return '<pre>%s</pre>' % (escape(text),)
def redirectTo(URL, request):
"""
Generate a redirect to the given location.
@param URL: A L{bytes} giving the location to which to redirect.
@type URL: L{bytes}
@param request: The request object to use to generate the redirect.
@type request: L{IRequest<twisted.web.iweb.IRequest>} provider
@raise TypeError: If the type of C{URL} a L{unicode} instead of L{bytes}.
@return: A C{bytes} containing HTML which tries to convince the client agent
to visit the new location even if it doesn't respect the I{FOUND}
response code. This is intended to be returned from a render method,
eg::
def render_GET(self, request):
return redirectTo(b"http://example.com/", request)
"""
if isinstance(URL, unicode) :
raise TypeError("Unicode object not allowed as URL")
request.setHeader(b"Content-Type", b"text/html; charset=utf-8")
request.redirect(URL)
content = """
<html>
<head>
<meta http-equiv=\"refresh\" content=\"0;URL=%(url)s\">
</head>
<body bgcolor=\"#FFFFFF\" text=\"#000000\">
<a href=\"%(url)s\">click here</a>
</body>
</html>
""" % {'url': nativeString(URL)}
if _PY3:
content = content.encode("utf8")
return content
class Redirect(resource.Resource):
isLeaf = True
def __init__(self, url):
resource.Resource.__init__(self)
self.url = url
def render(self, request):
return redirectTo(self.url, request)
def getChild(self, name, request):
return self
class ChildRedirector(Redirect):
isLeaf = 0
def __init__(self, url):
# XXX is this enough?
if ((url.find('://') == -1)
and (not url.startswith('..'))
and (not url.startswith('/'))):
raise ValueError("It seems you've given me a redirect (%s) that is a child of myself! That's not good, it'll cause an infinite redirect." % url)
Redirect.__init__(self, url)
def getChild(self, name, request):
newUrl = self.url
if not newUrl.endswith('/'):
newUrl += '/'
newUrl += name
return ChildRedirector(newUrl)
class ParentRedirect(resource.Resource):
"""
I redirect to URLPath.here().
"""
isLeaf = 1
def render(self, request):
return redirectTo(urlpath.URLPath.fromRequest(request).here(), request)
def getChild(self, request):
return self
class DeferredResource(resource.Resource):
"""
I wrap up a Deferred that will eventually result in a Resource
object.
"""
isLeaf = 1
def __init__(self, d):
resource.Resource.__init__(self)
self.d = d
def getChild(self, name, request):
return self
def render(self, request):
self.d.addCallback(self._cbChild, request).addErrback(
self._ebChild,request)
from twisted.web.server import NOT_DONE_YET
return NOT_DONE_YET
def _cbChild(self, child, request):
request.render(resource.getChildForRequest(child, request))
def _ebChild(self, reason, request):
request.processingFailed(reason)
class _SourceLineElement(Element):
"""
L{_SourceLineElement} is an L{IRenderable} which can render a single line of
source code.
@ivar number: A C{int} giving the line number of the source code to be
rendered.
@ivar source: A C{str} giving the source code to be rendered.
"""
def __init__(self, loader, number, source):
Element.__init__(self, loader)
self.number = number
self.source = source
@renderer
def sourceLine(self, request, tag):
"""
Render the line of source as a child of C{tag}.
"""
return tag(self.source.replace(' ', u' \N{NO-BREAK SPACE}'))
@renderer
def lineNumber(self, request, tag):
"""
Render the line number as a child of C{tag}.
"""
return tag(str(self.number))
class _SourceFragmentElement(Element):
"""
L{_SourceFragmentElement} is an L{IRenderable} which can render several lines
of source code near the line number of a particular frame object.
@ivar frame: A L{Failure<twisted.python.failure.Failure>}-style frame object
for which to load a source line to render. This is really a tuple
holding some information from a frame object. See
L{Failure.frames<twisted.python.failure.Failure>} for specifics.
"""
def __init__(self, loader, frame):
Element.__init__(self, loader)
self.frame = frame
def _getSourceLines(self):
"""
Find the source line references by C{self.frame} and yield, in source
line order, it and the previous and following lines.
@return: A generator which yields two-tuples. Each tuple gives a source
line number and the contents of that source line.
"""
filename = self.frame[1]
lineNumber = self.frame[2]
for snipLineNumber in range(lineNumber - 1, lineNumber + 2):
yield (snipLineNumber,
linecache.getline(filename, snipLineNumber).rstrip())
@renderer
def sourceLines(self, request, tag):
"""
Render the source line indicated by C{self.frame} and several
surrounding lines. The active line will be given a I{class} of
C{"snippetHighlightLine"}. Other lines will be given a I{class} of
C{"snippetLine"}.
"""
for (lineNumber, sourceLine) in self._getSourceLines():
newTag = tag.clone()
if lineNumber == self.frame[2]:
cssClass = "snippetHighlightLine"
else:
cssClass = "snippetLine"
loader = TagLoader(newTag(**{"class": cssClass}))
yield _SourceLineElement(loader, lineNumber, sourceLine)
class _FrameElement(Element):
"""
L{_FrameElement} is an L{IRenderable} which can render details about one
frame from a L{Failure<twisted.python.failure.Failure>}.
@ivar frame: A L{Failure<twisted.python.failure.Failure>}-style frame object
for which to load a source line to render. This is really a tuple
holding some information from a frame object. See
L{Failure.frames<twisted.python.failure.Failure>} for specifics.
"""
def __init__(self, loader, frame):
Element.__init__(self, loader)
self.frame = frame
@renderer
def filename(self, request, tag):
"""
Render the name of the file this frame references as a child of C{tag}.
"""
return tag(self.frame[1])
@renderer
def lineNumber(self, request, tag):
"""
Render the source line number this frame references as a child of
C{tag}.
"""
return tag(str(self.frame[2]))
@renderer
def function(self, request, tag):
"""
Render the function name this frame references as a child of C{tag}.
"""
return tag(self.frame[0])
@renderer
def source(self, request, tag):
"""
Render the source code surrounding the line this frame references,
replacing C{tag}.
"""
return _SourceFragmentElement(TagLoader(tag), self.frame)
class _StackElement(Element):
"""
L{_StackElement} renders an L{IRenderable} which can render a list of frames.
"""
def __init__(self, loader, stackFrames):
Element.__init__(self, loader)
self.stackFrames = stackFrames
@renderer
def frames(self, request, tag):
"""
Render the list of frames in this L{_StackElement}, replacing C{tag}.
"""
return [
_FrameElement(TagLoader(tag.clone()), frame)
for frame
in self.stackFrames]
class FailureElement(Element):
"""
L{FailureElement} is an L{IRenderable} which can render detailed information
about a L{Failure<twisted.python.failure.Failure>}.
@ivar failure: The L{Failure<twisted.python.failure.Failure>} instance which
will be rendered.
@since: 12.1
"""
loader = XMLString("""
<div xmlns:t="http://twistedmatrix.com/ns/twisted.web.template/0.1">
<style type="text/css">
div.error {
color: red;
font-family: Verdana, Arial, helvetica, sans-serif;
font-weight: bold;
}
div {
font-family: Verdana, Arial, helvetica, sans-serif;
}
div.stackTrace {
}
div.frame {
padding: 1em;
background: white;
border-bottom: thin black dashed;
}
div.frame:first-child {
padding: 1em;
background: white;
border-top: thin black dashed;
border-bottom: thin black dashed;
}
div.location {
}
span.function {
font-weight: bold;
font-family: "Courier New", courier, monospace;
}
div.snippet {
margin-bottom: 0.5em;
margin-left: 1em;
background: #FFFFDD;
}
div.snippetHighlightLine {
color: red;
}
span.code {
font-family: "Courier New", courier, monospace;
}
</style>
<div class="error">
<span t:render="type" />: <span t:render="value" />
</div>
<div class="stackTrace" t:render="traceback">
<div class="frame" t:render="frames">
<div class="location">
<span t:render="filename" />:<span t:render="lineNumber" /> in
<span class="function" t:render="function" />
</div>
<div class="snippet" t:render="source">
<div t:render="sourceLines">
<span class="lineno" t:render="lineNumber" />
<code class="code" t:render="sourceLine" />
</div>
</div>
</div>
</div>
<div class="error">
<span t:render="type" />: <span t:render="value" />
</div>
</div>
""")
def __init__(self, failure, loader=None):
Element.__init__(self, loader)
self.failure = failure
@renderer
def type(self, request, tag):
"""
Render the exception type as a child of C{tag}.
"""
return tag(fullyQualifiedName(self.failure.type))
@renderer
def value(self, request, tag):
"""
Render the exception value as a child of C{tag}.
"""
return tag(unicode(self.failure.value).encode('utf8'))
@renderer
def traceback(self, request, tag):
"""
Render all the frames in the wrapped
L{Failure<twisted.python.failure.Failure>}'s traceback stack, replacing
C{tag}.
"""
return _StackElement(TagLoader(tag), self.failure.frames)
def formatFailure(myFailure):
"""
Construct an HTML representation of the given failure.
Consider using L{FailureElement} instead.
@type myFailure: L{Failure<twisted.python.failure.Failure>}
@rtype: C{bytes}
@return: A string containing the HTML representation of the given failure.
"""
result = []
flattenString(None, FailureElement(myFailure)).addBoth(result.append)
if isinstance(result[0], bytes):
# Ensure the result string is all ASCII, for compatibility with the
# default encoding expected by browsers.
return result[0].decode('utf-8').encode('ascii', 'xmlcharrefreplace')
result[0].raiseException()
__all__ = [
"redirectTo", "Redirect", "ChildRedirector", "ParentRedirect",
"DeferredResource", "FailureElement", "formatFailure"]

View file

@ -0,0 +1,138 @@
# -*- test-case-name: twisted.web.
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
I am a virtual hosts implementation.
"""
from __future__ import division, absolute_import
# Twisted Imports
from twisted.python import roots
from twisted.web import resource
class VirtualHostCollection(roots.Homogenous):
"""Wrapper for virtual hosts collection.
This exists for configuration purposes.
"""
entityType = resource.Resource
def __init__(self, nvh):
self.nvh = nvh
def listStaticEntities(self):
return self.nvh.hosts.items()
def getStaticEntity(self, name):
return self.nvh.hosts.get(self)
def reallyPutEntity(self, name, entity):
self.nvh.addHost(name, entity)
def delEntity(self, name):
self.nvh.removeHost(name)
class NameVirtualHost(resource.Resource):
"""I am a resource which represents named virtual hosts.
"""
default = None
def __init__(self):
"""Initialize.
"""
resource.Resource.__init__(self)
self.hosts = {}
def listStaticEntities(self):
return resource.Resource.listStaticEntities(self) + [("Virtual Hosts", VirtualHostCollection(self))]
def getStaticEntity(self, name):
if name == "Virtual Hosts":
return VirtualHostCollection(self)
else:
return resource.Resource.getStaticEntity(self, name)
def addHost(self, name, resrc):
"""Add a host to this virtual host.
This will take a host named `name', and map it to a resource
`resrc'. For example, a setup for our virtual hosts would be::
nvh.addHost('divunal.com', divunalDirectory)
nvh.addHost('www.divunal.com', divunalDirectory)
nvh.addHost('twistedmatrix.com', twistedMatrixDirectory)
nvh.addHost('www.twistedmatrix.com', twistedMatrixDirectory)
"""
self.hosts[name] = resrc
def removeHost(self, name):
"""Remove a host."""
del self.hosts[name]
def _getResourceForRequest(self, request):
"""(Internal) Get the appropriate resource for the given host.
"""
hostHeader = request.getHeader(b'host')
if hostHeader == None:
return self.default or resource.NoResource()
else:
host = hostHeader.lower().split(b':', 1)[0]
return (self.hosts.get(host, self.default)
or resource.NoResource("host %s not in vhost map" % repr(host)))
def render(self, request):
"""Implementation of resource.Resource's render method.
"""
resrc = self._getResourceForRequest(request)
return resrc.render(request)
def getChild(self, path, request):
"""Implementation of resource.Resource's getChild method.
"""
resrc = self._getResourceForRequest(request)
if resrc.isLeaf:
request.postpath.insert(0,request.prepath.pop(-1))
return resrc
else:
return resrc.getChildWithDefault(path, request)
class _HostResource(resource.Resource):
def getChild(self, path, request):
if b':' in path:
host, port = path.split(b':', 1)
port = int(port)
else:
host, port = path, 80
request.setHost(host, port)
prefixLen = (3 + request.isSecure() + 4 + len(path) +
len(request.prepath[-3]))
request.path = b'/' + b'/'.join(request.postpath)
request.uri = request.uri[prefixLen:]
del request.prepath[:3]
return request.site.getResourceFor(request)
class VHostMonsterResource(resource.Resource):
"""
Use this to be able to record the hostname and method (http vs. https)
in the URL without disturbing your web site. If you put this resource
in a URL http://foo.com/bar then requests to
http://foo.com/bar/http/baz.com/something will be equivalent to
http://foo.com/something, except that the hostname the request will
appear to be accessing will be "baz.com". So if "baz.com" is redirecting
all requests for to foo.com, while foo.com is inaccessible from the outside,
then redirect and url generation will work correctly
"""
def getChild(self, path, request):
if path == b'http':
request.isSecure = lambda: 0
elif path == b'https':
request.isSecure = lambda: 1
return _HostResource()

View file

@ -0,0 +1,596 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An implementation of
U{Python Web Server Gateway Interface v1.0.1<http://www.python.org/dev/peps/pep-3333/>}.
"""
__metaclass__ = type
from sys import exc_info
from warnings import warn
from zope.interface import implementer
from twisted.internet.threads import blockingCallFromThread
from twisted.python.compat import reraise, Sequence
from twisted.python.failure import Failure
from twisted.web.resource import IResource
from twisted.web.server import NOT_DONE_YET
from twisted.web.http import INTERNAL_SERVER_ERROR
from twisted.logger import Logger
# PEP-3333 -- which has superseded PEP-333 -- states that, in both Python 2
# and Python 3, text strings MUST be represented using the platform's native
# string type, limited to characters defined in ISO-8859-1. Byte strings are
# used only for values read from wsgi.input, passed to write() or yielded by
# the application.
#
# Put another way:
#
# - In Python 2, all text strings and binary data are of type str/bytes and
# NEVER of type unicode. Whether the strings contain binary data or
# ISO-8859-1 text depends on context.
#
# - In Python 3, all text strings are of type str, and all binary data are of
# type bytes. Text MUST always be limited to that which can be encoded as
# ISO-8859-1, U+0000 to U+00FF inclusive.
#
# The following pair of functions -- _wsgiString() and _wsgiStringToBytes() --
# are used to make Twisted's WSGI support compliant with the standard.
if str is bytes:
def _wsgiString(string): # Python 2.
"""
Convert C{string} to an ISO-8859-1 byte string, if it is not already.
@type string: C{str}/C{bytes} or C{unicode}
@rtype: C{str}/C{bytes}
@raise UnicodeEncodeError: If C{string} contains non-ISO-8859-1 chars.
"""
if isinstance(string, str):
return string
else:
return string.encode('iso-8859-1')
def _wsgiStringToBytes(string): # Python 2.
"""
Return C{string} as is; a WSGI string is a byte string in Python 2.
@type string: C{str}/C{bytes}
@rtype: C{str}/C{bytes}
"""
return string
else:
def _wsgiString(string): # Python 3.
"""
Convert C{string} to a WSGI "bytes-as-unicode" string.
If it's a byte string, decode as ISO-8859-1. If it's a Unicode string,
round-trip it to bytes and back using ISO-8859-1 as the encoding.
@type string: C{str} or C{bytes}
@rtype: C{str}
@raise UnicodeEncodeError: If C{string} contains non-ISO-8859-1 chars.
"""
if isinstance(string, str):
return string.encode("iso-8859-1").decode('iso-8859-1')
else:
return string.decode("iso-8859-1")
def _wsgiStringToBytes(string): # Python 3.
"""
Convert C{string} from a WSGI "bytes-as-unicode" string to an
ISO-8859-1 byte string.
@type string: C{str}
@rtype: C{bytes}
@raise UnicodeEncodeError: If C{string} contains non-ISO-8859-1 chars.
"""
return string.encode("iso-8859-1")
class _ErrorStream:
"""
File-like object instances of which are used as the value for the
C{'wsgi.errors'} key in the C{environ} dictionary passed to the application
object.
This simply passes writes on to L{logging<twisted.logger>} system as
error events from the C{'wsgi'} system. In the future, it may be desirable
to expose more information in the events it logs, such as the application
object which generated the message.
"""
_log = Logger()
def write(self, data):
"""
Generate an event for the logging system with the given bytes as the
message.
This is called in a WSGI application thread, not the I/O thread.
@type data: str
@raise TypeError: On Python 3, if C{data} is not a native string. On
Python 2 a warning will be issued.
"""
if not isinstance(data, str):
if str is bytes:
warn("write() argument should be str, not %r (%s)" % (
data, type(data).__name__), category=UnicodeWarning)
else:
raise TypeError(
"write() argument must be str, not %r (%s)"
% (data, type(data).__name__))
# Note that in old style, message was a tuple. logger._legacy
# will overwrite this value if it is not properly formatted here.
self._log.error(
data,
system='wsgi',
isError=True,
message=(data,)
)
def writelines(self, iovec):
"""
Join the given lines and pass them to C{write} to be handled in the
usual way.
This is called in a WSGI application thread, not the I/O thread.
@param iovec: A C{list} of C{'\\n'}-terminated C{str} which will be
logged.
@raise TypeError: On Python 3, if C{iovec} contains any non-native
strings. On Python 2 a warning will be issued.
"""
self.write(''.join(iovec))
def flush(self):
"""
Nothing is buffered, so flushing does nothing. This method is required
to exist by PEP 333, though.
This is called in a WSGI application thread, not the I/O thread.
"""
class _InputStream:
"""
File-like object instances of which are used as the value for the
C{'wsgi.input'} key in the C{environ} dictionary passed to the application
object.
This only exists to make the handling of C{readline(-1)} consistent across
different possible underlying file-like object implementations. The other
supported methods pass through directly to the wrapped object.
"""
def __init__(self, input):
"""
Initialize the instance.
This is called in the I/O thread, not a WSGI application thread.
"""
self._wrapped = input
def read(self, size=None):
"""
Pass through to the underlying C{read}.
This is called in a WSGI application thread, not the I/O thread.
"""
# Avoid passing None because cStringIO and file don't like it.
if size is None:
return self._wrapped.read()
return self._wrapped.read(size)
def readline(self, size=None):
"""
Pass through to the underlying C{readline}, with a size of C{-1} replaced
with a size of L{None}.
This is called in a WSGI application thread, not the I/O thread.
"""
# Check for -1 because StringIO doesn't handle it correctly. Check for
# None because files and tempfiles don't accept that.
if size == -1 or size is None:
return self._wrapped.readline()
return self._wrapped.readline(size)
def readlines(self, size=None):
"""
Pass through to the underlying C{readlines}.
This is called in a WSGI application thread, not the I/O thread.
"""
# Avoid passing None because cStringIO and file don't like it.
if size is None:
return self._wrapped.readlines()
return self._wrapped.readlines(size)
def __iter__(self):
"""
Pass through to the underlying C{__iter__}.
This is called in a WSGI application thread, not the I/O thread.
"""
return iter(self._wrapped)
class _WSGIResponse:
"""
Helper for L{WSGIResource} which drives the WSGI application using a
threadpool and hooks it up to the L{http.Request}.
@ivar started: A L{bool} indicating whether or not the response status and
headers have been written to the request yet. This may only be read or
written in the WSGI application thread.
@ivar reactor: An L{IReactorThreads} provider which is used to call methods
on the request in the I/O thread.
@ivar threadpool: A L{ThreadPool} which is used to call the WSGI
application object in a non-I/O thread.
@ivar application: The WSGI application object.
@ivar request: The L{http.Request} upon which the WSGI environment is
based and to which the application's output will be sent.
@ivar environ: The WSGI environment L{dict}.
@ivar status: The HTTP response status L{str} supplied to the WSGI
I{start_response} callable by the application.
@ivar headers: A list of HTTP response headers supplied to the WSGI
I{start_response} callable by the application.
@ivar _requestFinished: A flag which indicates whether it is possible to
generate more response data or not. This is L{False} until
L{http.Request.notifyFinish} tells us the request is done,
then L{True}.
"""
_requestFinished = False
_log = Logger()
def __init__(self, reactor, threadpool, application, request):
self.started = False
self.reactor = reactor
self.threadpool = threadpool
self.application = application
self.request = request
self.request.notifyFinish().addBoth(self._finished)
if request.prepath:
scriptName = b'/' + b'/'.join(request.prepath)
else:
scriptName = b''
if request.postpath:
pathInfo = b'/' + b'/'.join(request.postpath)
else:
pathInfo = b''
parts = request.uri.split(b'?', 1)
if len(parts) == 1:
queryString = b''
else:
queryString = parts[1]
# All keys and values need to be native strings, i.e. of type str in
# *both* Python 2 and Python 3, so says PEP-3333.
self.environ = {
'REQUEST_METHOD': _wsgiString(request.method),
'REMOTE_ADDR': _wsgiString(request.getClientAddress().host),
'SCRIPT_NAME': _wsgiString(scriptName),
'PATH_INFO': _wsgiString(pathInfo),
'QUERY_STRING': _wsgiString(queryString),
'CONTENT_TYPE': _wsgiString(
request.getHeader(b'content-type') or ''),
'CONTENT_LENGTH': _wsgiString(
request.getHeader(b'content-length') or ''),
'SERVER_NAME': _wsgiString(request.getRequestHostname()),
'SERVER_PORT': _wsgiString(str(request.getHost().port)),
'SERVER_PROTOCOL': _wsgiString(request.clientproto)}
# The application object is entirely in control of response headers;
# disable the default Content-Type value normally provided by
# twisted.web.server.Request.
self.request.defaultContentType = None
for name, values in request.requestHeaders.getAllRawHeaders():
name = 'HTTP_' + _wsgiString(name).upper().replace('-', '_')
# It might be preferable for http.HTTPChannel to clear out
# newlines.
self.environ[name] = ','.join(
_wsgiString(v) for v in values).replace('\n', ' ')
self.environ.update({
'wsgi.version': (1, 0),
'wsgi.url_scheme': request.isSecure() and 'https' or 'http',
'wsgi.run_once': False,
'wsgi.multithread': True,
'wsgi.multiprocess': False,
'wsgi.errors': _ErrorStream(),
# Attend: request.content was owned by the I/O thread up until
# this point. By wrapping it and putting the result into the
# environment dictionary, it is effectively being given to
# another thread. This means that whatever it is, it has to be
# safe to access it from two different threads. The access
# *should* all be serialized (first the I/O thread writes to
# it, then the WSGI thread reads from it, then the I/O thread
# closes it). However, since the request is made available to
# arbitrary application code during resource traversal, it's
# possible that some other code might decide to use it in the
# I/O thread concurrently with its use in the WSGI thread.
# More likely than not, this will break. This seems like an
# unlikely possibility to me, but if it is to be allowed,
# something here needs to change. -exarkun
'wsgi.input': _InputStream(request.content)})
def _finished(self, ignored):
"""
Record the end of the response generation for the request being
serviced.
"""
self._requestFinished = True
def startResponse(self, status, headers, excInfo=None):
"""
The WSGI I{start_response} callable. The given values are saved until
they are needed to generate the response.
This will be called in a non-I/O thread.
"""
if self.started and excInfo is not None:
reraise(excInfo[1], excInfo[2])
# PEP-3333 mandates that status should be a native string. In practice
# this is mandated by Twisted's HTTP implementation too, so we enforce
# on both Python 2 and Python 3.
if not isinstance(status, str):
raise TypeError(
"status must be str, not %r (%s)"
% (status, type(status).__name__))
# PEP-3333 mandates that headers should be a plain list, but in
# practice we work with any sequence type and only warn when it's not
# a plain list.
if isinstance(headers, list):
pass # This is okay.
elif isinstance(headers, Sequence):
warn("headers should be a list, not %r (%s)" % (
headers, type(headers).__name__), category=RuntimeWarning)
else:
raise TypeError(
"headers must be a list, not %r (%s)"
% (headers, type(headers).__name__))
# PEP-3333 mandates that each header should be a (str, str) tuple, but
# in practice we work with any sequence type and only warn when it's
# not a plain list.
for header in headers:
if isinstance(header, tuple):
pass # This is okay.
elif isinstance(header, Sequence):
warn("header should be a (str, str) tuple, not %r (%s)" % (
header, type(header).__name__), category=RuntimeWarning)
else:
raise TypeError(
"header must be a (str, str) tuple, not %r (%s)"
% (header, type(header).__name__))
# However, the sequence MUST contain only 2 elements.
if len(header) != 2:
raise TypeError(
"header must be a (str, str) tuple, not %r"
% (header, ))
# Both elements MUST be native strings. Non-native strings will be
# rejected by the underlying HTTP machinery in any case, but we
# reject them here in order to provide a more informative error.
for elem in header:
if not isinstance(elem, str):
raise TypeError(
"header must be (str, str) tuple, not %r"
% (header, ))
self.status = status
self.headers = headers
return self.write
def write(self, data):
"""
The WSGI I{write} callable returned by the I{start_response} callable.
The given bytes will be written to the response body, possibly flushing
the status and headers first.
This will be called in a non-I/O thread.
"""
# PEP-3333 states:
#
# The server or gateway must transmit the yielded bytestrings to the
# client in an unbuffered fashion, completing the transmission of
# each bytestring before requesting another one.
#
# This write() method is used for the imperative and (indirectly) for
# the more familiar iterable-of-bytestrings WSGI mechanism. It uses
# C{blockingCallFromThread} to schedule writes. This allows exceptions
# to propagate up from the underlying HTTP implementation. However,
# that underlying implementation does not, as yet, provide any way to
# know if the written data has been transmitted, so this method
# violates the above part of PEP-3333.
#
# PEP-3333 also says that a server may:
#
# Use a different thread to ensure that the block continues to be
# transmitted while the application produces the next block.
#
# Which suggests that this is actually compliant with PEP-3333,
# because writes are done in the reactor thread.
#
# However, providing some back-pressure may nevertheless be a Good
# Thing at some point in the future.
def wsgiWrite(started):
if not started:
self._sendResponseHeaders()
self.request.write(data)
try:
return blockingCallFromThread(
self.reactor, wsgiWrite, self.started)
finally:
self.started = True
def _sendResponseHeaders(self):
"""
Set the response code and response headers on the request object, but
do not flush them. The caller is responsible for doing a write in
order for anything to actually be written out in response to the
request.
This must be called in the I/O thread.
"""
code, message = self.status.split(None, 1)
code = int(code)
self.request.setResponseCode(code, _wsgiStringToBytes(message))
for name, value in self.headers:
# Don't allow the application to control these required headers.
if name.lower() not in ('server', 'date'):
self.request.responseHeaders.addRawHeader(
_wsgiStringToBytes(name), _wsgiStringToBytes(value))
def start(self):
"""
Start the WSGI application in the threadpool.
This must be called in the I/O thread.
"""
self.threadpool.callInThread(self.run)
def run(self):
"""
Call the WSGI application object, iterate it, and handle its output.
This must be called in a non-I/O thread (ie, a WSGI application
thread).
"""
try:
appIterator = self.application(self.environ, self.startResponse)
for elem in appIterator:
if elem:
self.write(elem)
if self._requestFinished:
break
close = getattr(appIterator, 'close', None)
if close is not None:
close()
except:
def wsgiError(started, type, value, traceback):
self._log.failure(
"WSGI application error",
failure=Failure(value, type, traceback)
)
if started:
self.request.loseConnection()
else:
self.request.setResponseCode(INTERNAL_SERVER_ERROR)
self.request.finish()
self.reactor.callFromThread(wsgiError, self.started, *exc_info())
else:
def wsgiFinish(started):
if not self._requestFinished:
if not started:
self._sendResponseHeaders()
self.request.finish()
self.reactor.callFromThread(wsgiFinish, self.started)
self.started = True
@implementer(IResource)
class WSGIResource:
"""
An L{IResource} implementation which delegates responsibility for all
resources hierarchically inferior to it to a WSGI application.
@ivar _reactor: An L{IReactorThreads} provider which will be passed on to
L{_WSGIResponse} to schedule calls in the I/O thread.
@ivar _threadpool: A L{ThreadPool} which will be passed on to
L{_WSGIResponse} to run the WSGI application object.
@ivar _application: The WSGI application object.
"""
# Further resource segments are left up to the WSGI application object to
# handle.
isLeaf = True
def __init__(self, reactor, threadpool, application):
self._reactor = reactor
self._threadpool = threadpool
self._application = application
def render(self, request):
"""
Turn the request into the appropriate C{environ} C{dict} suitable to be
passed to the WSGI application object and then pass it on.
The WSGI application object is given almost complete control of the
rendering process. C{NOT_DONE_YET} will always be returned in order
and response completion will be dictated by the application object, as
will the status, headers, and the response body.
"""
response = _WSGIResponse(
self._reactor, self._threadpool, self._application, request)
response.start()
return NOT_DONE_YET
def getChildWithDefault(self, name, request):
"""
Reject attempts to retrieve a child resource. All path segments beyond
the one which refers to this resource are handled by the WSGI
application object.
"""
raise RuntimeError("Cannot get IResource children from WSGIResource")
def putChild(self, path, child):
"""
Reject attempts to add a child resource to this resource. The WSGI
application object handles all path segments beneath this resource, so
L{IResource} children can never be found.
"""
raise RuntimeError("Cannot put IResource children under WSGIResource")
__all__ = ['WSGIResource']

View file

@ -0,0 +1,591 @@
# -*- test-case-name: twisted.web.test.test_xmlrpc -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A generic resource for publishing objects via XML-RPC.
Maintainer: Itamar Shtull-Trauring
@var Fault: See L{xmlrpclib.Fault}
@type Fault: L{xmlrpclib.Fault}
"""
from __future__ import division, absolute_import
from twisted.python.compat import _PY3, intToBytes, nativeString, urllib_parse
from twisted.python.compat import unicode
# System Imports
import base64
if _PY3:
import xmlrpc.client as xmlrpclib
else:
import xmlrpclib
# Sibling Imports
from twisted.web import resource, server, http
from twisted.internet import defer, protocol, reactor
from twisted.python import reflect, failure
from twisted.logger import Logger
# These are deprecated, use the class level definitions
NOT_FOUND = 8001
FAILURE = 8002
# Useful so people don't need to import xmlrpclib directly
Fault = xmlrpclib.Fault
Binary = xmlrpclib.Binary
Boolean = xmlrpclib.Boolean
DateTime = xmlrpclib.DateTime
def withRequest(f):
"""
Decorator to cause the request to be passed as the first argument
to the method.
If an I{xmlrpc_} method is wrapped with C{withRequest}, the
request object is passed as the first argument to that method.
For example::
@withRequest
def xmlrpc_echo(self, request, s):
return s
@since: 10.2
"""
f.withRequest = True
return f
class NoSuchFunction(Fault):
"""
There is no function by the given name.
"""
class Handler:
"""
Handle a XML-RPC request and store the state for a request in progress.
Override the run() method and return result using self.result,
a Deferred.
We require this class since we're not using threads, so we can't
encapsulate state in a running function if we're going to have
to wait for results.
For example, lets say we want to authenticate against twisted.cred,
run a LDAP query and then pass its result to a database query, all
as a result of a single XML-RPC command. We'd use a Handler instance
to store the state of the running command.
"""
def __init__(self, resource, *args):
self.resource = resource # the XML-RPC resource we are connected to
self.result = defer.Deferred()
self.run(*args)
def run(self, *args):
# event driven equivalent of 'raise UnimplementedError'
self.result.errback(
NotImplementedError("Implement run() in subclasses"))
class XMLRPC(resource.Resource):
"""
A resource that implements XML-RPC.
You probably want to connect this to '/RPC2'.
Methods published can return XML-RPC serializable results, Faults,
Binary, Boolean, DateTime, Deferreds, or Handler instances.
By default methods beginning with 'xmlrpc_' are published.
Sub-handlers for prefixed methods (e.g., system.listMethods)
can be added with putSubHandler. By default, prefixes are
separated with a '.'. Override self.separator to change this.
@ivar allowNone: Permit XML translating of Python constant None.
@type allowNone: C{bool}
@ivar useDateTime: Present C{datetime} values as C{datetime.datetime}
objects?
@type useDateTime: C{bool}
"""
# Error codes for Twisted, if they conflict with yours then
# modify them at runtime.
NOT_FOUND = 8001
FAILURE = 8002
isLeaf = 1
separator = '.'
allowedMethods = (b'POST',)
_log = Logger()
def __init__(self, allowNone=False, useDateTime=False):
resource.Resource.__init__(self)
self.subHandlers = {}
self.allowNone = allowNone
self.useDateTime = useDateTime
def __setattr__(self, name, value):
self.__dict__[name] = value
def putSubHandler(self, prefix, handler):
self.subHandlers[prefix] = handler
def getSubHandler(self, prefix):
return self.subHandlers.get(prefix, None)
def getSubHandlerPrefixes(self):
return list(self.subHandlers.keys())
def render_POST(self, request):
request.content.seek(0, 0)
request.setHeader(b"content-type", b"text/xml; charset=utf-8")
try:
args, functionPath = xmlrpclib.loads(request.content.read(),
use_datetime=self.useDateTime)
except Exception as e:
f = Fault(self.FAILURE, "Can't deserialize input: %s" % (e,))
self._cbRender(f, request)
else:
try:
function = self.lookupProcedure(functionPath)
except Fault as f:
self._cbRender(f, request)
else:
# Use this list to track whether the response has failed or not.
# This will be used later on to decide if the result of the
# Deferred should be written out and Request.finish called.
responseFailed = []
request.notifyFinish().addErrback(responseFailed.append)
if getattr(function, 'withRequest', False):
d = defer.maybeDeferred(function, request, *args)
else:
d = defer.maybeDeferred(function, *args)
d.addErrback(self._ebRender)
d.addCallback(self._cbRender, request, responseFailed)
return server.NOT_DONE_YET
def _cbRender(self, result, request, responseFailed=None):
if responseFailed:
return
if isinstance(result, Handler):
result = result.result
if not isinstance(result, Fault):
result = (result,)
try:
try:
content = xmlrpclib.dumps(
result, methodresponse=True,
allow_none=self.allowNone)
except Exception as e:
f = Fault(self.FAILURE, "Can't serialize output: %s" % (e,))
content = xmlrpclib.dumps(f, methodresponse=True,
allow_none=self.allowNone)
if isinstance(content, unicode):
content = content.encode('utf8')
request.setHeader(
b"content-length", intToBytes(len(content)))
request.write(content)
except:
self._log.failure('')
request.finish()
def _ebRender(self, failure):
if isinstance(failure.value, Fault):
return failure.value
self._log.failure('', failure)
return Fault(self.FAILURE, "error")
def lookupProcedure(self, procedurePath):
"""
Given a string naming a procedure, return a callable object for that
procedure or raise NoSuchFunction.
The returned object will be called, and should return the result of the
procedure, a Deferred, or a Fault instance.
Override in subclasses if you want your own policy. The base
implementation that given C{'foo'}, C{self.xmlrpc_foo} will be returned.
If C{procedurePath} contains C{self.separator}, the sub-handler for the
initial prefix is used to search for the remaining path.
If you override C{lookupProcedure}, you may also want to override
C{listProcedures} to accurately report the procedures supported by your
resource, so that clients using the I{system.listMethods} procedure
receive accurate results.
@since: 11.1
"""
if procedurePath.find(self.separator) != -1:
prefix, procedurePath = procedurePath.split(self.separator, 1)
handler = self.getSubHandler(prefix)
if handler is None:
raise NoSuchFunction(self.NOT_FOUND,
"no such subHandler %s" % prefix)
return handler.lookupProcedure(procedurePath)
f = getattr(self, "xmlrpc_%s" % procedurePath, None)
if not f:
raise NoSuchFunction(self.NOT_FOUND,
"procedure %s not found" % procedurePath)
elif not callable(f):
raise NoSuchFunction(self.NOT_FOUND,
"procedure %s not callable" % procedurePath)
else:
return f
def listProcedures(self):
"""
Return a list of the names of all xmlrpc procedures.
@since: 11.1
"""
return reflect.prefixedMethodNames(self.__class__, 'xmlrpc_')
class XMLRPCIntrospection(XMLRPC):
"""
Implement the XML-RPC Introspection API.
By default, the methodHelp method returns the 'help' method attribute,
if it exists, otherwise the __doc__ method attribute, if it exists,
otherwise the empty string.
To enable the methodSignature method, add a 'signature' method attribute
containing a list of lists. See methodSignature's documentation for the
format. Note the type strings should be XML-RPC types, not Python types.
"""
def __init__(self, parent):
"""
Implement Introspection support for an XMLRPC server.
@param parent: the XMLRPC server to add Introspection support to.
@type parent: L{XMLRPC}
"""
XMLRPC.__init__(self)
self._xmlrpc_parent = parent
def xmlrpc_listMethods(self):
"""
Return a list of the method names implemented by this server.
"""
functions = []
todo = [(self._xmlrpc_parent, '')]
while todo:
obj, prefix = todo.pop(0)
functions.extend([prefix + name for name in obj.listProcedures()])
todo.extend([ (obj.getSubHandler(name),
prefix + name + obj.separator)
for name in obj.getSubHandlerPrefixes() ])
return functions
xmlrpc_listMethods.signature = [['array']]
def xmlrpc_methodHelp(self, method):
"""
Return a documentation string describing the use of the given method.
"""
method = self._xmlrpc_parent.lookupProcedure(method)
return (getattr(method, 'help', None)
or getattr(method, '__doc__', None) or '')
xmlrpc_methodHelp.signature = [['string', 'string']]
def xmlrpc_methodSignature(self, method):
"""
Return a list of type signatures.
Each type signature is a list of the form [rtype, type1, type2, ...]
where rtype is the return type and typeN is the type of the Nth
argument. If no signature information is available, the empty
string is returned.
"""
method = self._xmlrpc_parent.lookupProcedure(method)
return getattr(method, 'signature', None) or ''
xmlrpc_methodSignature.signature = [['array', 'string'],
['string', 'string']]
def addIntrospection(xmlrpc):
"""
Add Introspection support to an XMLRPC server.
@param parent: the XMLRPC server to add Introspection support to.
@type parent: L{XMLRPC}
"""
xmlrpc.putSubHandler('system', XMLRPCIntrospection(xmlrpc))
class QueryProtocol(http.HTTPClient):
def connectionMade(self):
self._response = None
self.sendCommand(b'POST', self.factory.path)
self.sendHeader(b'User-Agent', b'Twisted/XMLRPClib')
self.sendHeader(b'Host', self.factory.host)
self.sendHeader(b'Content-type', b'text/xml; charset=utf-8')
payload = self.factory.payload
self.sendHeader(b'Content-length', intToBytes(len(payload)))
if self.factory.user:
auth = b':'.join([self.factory.user, self.factory.password])
authHeader = b''.join([b'Basic ', base64.b64encode(auth)])
self.sendHeader(b'Authorization', authHeader)
self.endHeaders()
self.transport.write(payload)
def handleStatus(self, version, status, message):
if status != b'200':
self.factory.badStatus(status, message)
def handleResponse(self, contents):
"""
Handle the XML-RPC response received from the server.
Specifically, disconnect from the server and store the XML-RPC
response so that it can be properly handled when the disconnect is
finished.
"""
self.transport.loseConnection()
self._response = contents
def connectionLost(self, reason):
"""
The connection to the server has been lost.
If we have a full response from the server, then parse it and fired a
Deferred with the return value or C{Fault} that the server gave us.
"""
http.HTTPClient.connectionLost(self, reason)
if self._response is not None:
response, self._response = self._response, None
self.factory.parseResponse(response)
payloadTemplate = """<?xml version="1.0"?>
<methodCall>
<methodName>%s</methodName>
%s
</methodCall>
"""
class _QueryFactory(protocol.ClientFactory):
"""
XML-RPC Client Factory
@ivar path: The path portion of the URL to which to post method calls.
@type path: L{bytes}
@ivar host: The value to use for the Host HTTP header.
@type host: L{bytes}
@ivar user: The username with which to authenticate with the server
when making calls.
@type user: L{bytes} or L{None}
@ivar password: The password with which to authenticate with the server
when making calls.
@type password: L{bytes} or L{None}
@ivar useDateTime: Accept datetime values as datetime.datetime objects.
also passed to the underlying xmlrpclib implementation. Defaults to
C{False}.
@type useDateTime: C{bool}
"""
deferred = None
protocol = QueryProtocol
def __init__(self, path, host, method, user=None, password=None,
allowNone=False, args=(), canceller=None, useDateTime=False):
"""
@param method: The name of the method to call.
@type method: C{str}
@param allowNone: allow the use of None values in parameters. It's
passed to the underlying xmlrpclib implementation. Defaults to
C{False}.
@type allowNone: C{bool} or L{None}
@param args: the arguments to pass to the method.
@type args: C{tuple}
@param canceller: A 1-argument callable passed to the deferred as the
canceller callback.
@type canceller: callable or L{None}
"""
self.path, self.host = path, host
self.user, self.password = user, password
self.payload = payloadTemplate % (method,
xmlrpclib.dumps(args, allow_none=allowNone))
if isinstance(self.payload, unicode):
self.payload = self.payload.encode('utf8')
self.deferred = defer.Deferred(canceller)
self.useDateTime = useDateTime
def parseResponse(self, contents):
if not self.deferred:
return
try:
response = xmlrpclib.loads(contents,
use_datetime=self.useDateTime)[0][0]
except:
deferred, self.deferred = self.deferred, None
deferred.errback(failure.Failure())
else:
deferred, self.deferred = self.deferred, None
deferred.callback(response)
def clientConnectionLost(self, _, reason):
if self.deferred is not None:
deferred, self.deferred = self.deferred, None
deferred.errback(reason)
clientConnectionFailed = clientConnectionLost
def badStatus(self, status, message):
deferred, self.deferred = self.deferred, None
deferred.errback(ValueError(status, message))
class Proxy:
"""
A Proxy for making remote XML-RPC calls.
Pass the URL of the remote XML-RPC server to the constructor.
Use C{proxy.callRemote('foobar', *args)} to call remote method
'foobar' with *args.
@ivar user: The username with which to authenticate with the server
when making calls. If specified, overrides any username information
embedded in C{url}. If not specified, a value may be taken from
C{url} if present.
@type user: L{bytes} or L{None}
@ivar password: The password with which to authenticate with the server
when making calls. If specified, overrides any password information
embedded in C{url}. If not specified, a value may be taken from
C{url} if present.
@type password: L{bytes} or L{None}
@ivar allowNone: allow the use of None values in parameters. It's
passed to the underlying L{xmlrpclib} implementation. Defaults to
C{False}.
@type allowNone: C{bool} or L{None}
@ivar useDateTime: Accept datetime values as datetime.datetime objects.
also passed to the underlying L{xmlrpclib} implementation. Defaults to
C{False}.
@type useDateTime: C{bool}
@ivar connectTimeout: Number of seconds to wait before assuming the
connection has failed.
@type connectTimeout: C{float}
@ivar _reactor: The reactor used to create connections.
@type _reactor: Object providing L{twisted.internet.interfaces.IReactorTCP}
@ivar queryFactory: Object returning a factory for XML-RPC protocol. Mainly
useful for tests.
"""
queryFactory = _QueryFactory
def __init__(self, url, user=None, password=None, allowNone=False,
useDateTime=False, connectTimeout=30.0, reactor=reactor):
"""
@param url: The URL to which to post method calls. Calls will be made
over SSL if the scheme is HTTPS. If netloc contains username or
password information, these will be used to authenticate, as long as
the C{user} and C{password} arguments are not specified.
@type url: L{bytes}
"""
scheme, netloc, path, params, query, fragment = urllib_parse.urlparse(
url)
netlocParts = netloc.split(b'@')
if len(netlocParts) == 2:
userpass = netlocParts.pop(0).split(b':')
self.user = userpass.pop(0)
try:
self.password = userpass.pop(0)
except:
self.password = None
else:
self.user = self.password = None
hostport = netlocParts[0].split(b':')
self.host = hostport.pop(0)
try:
self.port = int(hostport.pop(0))
except:
self.port = None
self.path = path
if self.path in [b'', None]:
self.path = b'/'
self.secure = (scheme == b'https')
if user is not None:
self.user = user
if password is not None:
self.password = password
self.allowNone = allowNone
self.useDateTime = useDateTime
self.connectTimeout = connectTimeout
self._reactor = reactor
def callRemote(self, method, *args):
"""
Call remote XML-RPC C{method} with given arguments.
@return: a L{defer.Deferred} that will fire with the method response,
or a failure if the method failed. Generally, the failure type will
be L{Fault}, but you can also have an C{IndexError} on some buggy
servers giving empty responses.
If the deferred is cancelled before the request completes, the
connection is closed and the deferred will fire with a
L{defer.CancelledError}.
"""
def cancel(d):
factory.deferred = None
connector.disconnect()
factory = self.queryFactory(
self.path, self.host, method, self.user,
self.password, self.allowNone, args, cancel, self.useDateTime)
if self.secure:
from twisted.internet import ssl
connector = self._reactor.connectSSL(
nativeString(self.host), self.port or 443,
factory, ssl.ClientContextFactory(),
timeout=self.connectTimeout)
else:
connector = self._reactor.connectTCP(
nativeString(self.host), self.port or 80, factory,
timeout=self.connectTimeout)
return factory.deferred
__all__ = [
"XMLRPC", "Handler", "NoSuchFunction", "Proxy",
"Fault", "Binary", "Boolean", "DateTime"]