Ausgabe der neuen DB Einträge
This commit is contained in:
parent
bad48e1627
commit
cfbbb9ee3d
2399 changed files with 843193 additions and 43 deletions
|
|
@ -0,0 +1,15 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Twisted Protocols: A collection of internet protocol implementations.
|
||||
"""
|
||||
|
||||
from incremental import Version
|
||||
from twisted.python.deprecate import deprecatedModuleAttribute
|
||||
|
||||
|
||||
deprecatedModuleAttribute(
|
||||
Version('Twisted', 17, 9, 0),
|
||||
"There is no replacement for this module.",
|
||||
"twisted.protocols", "dict")
|
||||
2897
venv/lib/python3.9/site-packages/twisted/protocols/amp.py
Normal file
2897
venv/lib/python3.9/site-packages/twisted/protocols/amp.py
Normal file
File diff suppressed because it is too large
Load diff
953
venv/lib/python3.9/site-packages/twisted/protocols/basic.py
Normal file
953
venv/lib/python3.9/site-packages/twisted/protocols/basic.py
Normal file
|
|
@ -0,0 +1,953 @@
|
|||
# -*- test-case-name: twisted.protocols.test.test_basic -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Basic protocols, such as line-oriented, netstring, and int prefixed strings.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division
|
||||
|
||||
# System imports
|
||||
import re
|
||||
from struct import pack, unpack, calcsize
|
||||
from io import BytesIO
|
||||
import math
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
# Twisted imports
|
||||
from twisted.python.compat import _PY3
|
||||
from twisted.internet import protocol, defer, interfaces
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
# Unfortunately we cannot use regular string formatting on Python 3; see
|
||||
# http://bugs.python.org/issue3982 for details.
|
||||
if _PY3:
|
||||
def _formatNetstring(data):
|
||||
return b''.join([str(len(data)).encode("ascii"), b':', data, b','])
|
||||
else:
|
||||
def _formatNetstring(data):
|
||||
return b'%d:%s,' % (len(data), data)
|
||||
_formatNetstring.__doc__ = """
|
||||
Convert some C{bytes} into netstring format.
|
||||
|
||||
@param data: C{bytes} that will be reformatted.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
DEBUG = 0
|
||||
|
||||
class NetstringParseError(ValueError):
|
||||
"""
|
||||
The incoming data is not in valid Netstring format.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class IncompleteNetstring(Exception):
|
||||
"""
|
||||
Not enough data to complete a netstring.
|
||||
"""
|
||||
|
||||
|
||||
class NetstringReceiver(protocol.Protocol):
|
||||
"""
|
||||
A protocol that sends and receives netstrings.
|
||||
|
||||
See U{http://cr.yp.to/proto/netstrings.txt} for the specification of
|
||||
netstrings. Every netstring starts with digits that specify the length
|
||||
of the data. This length specification is separated from the data by
|
||||
a colon. The data is terminated with a comma.
|
||||
|
||||
Override L{stringReceived} to handle received netstrings. This
|
||||
method is called with the netstring payload as a single argument
|
||||
whenever a complete netstring is received.
|
||||
|
||||
Security features:
|
||||
1. Messages are limited in size, useful if you don't want
|
||||
someone sending you a 500MB netstring (change C{self.MAX_LENGTH}
|
||||
to the maximum length you wish to accept).
|
||||
2. The connection is lost if an illegal message is received.
|
||||
|
||||
@ivar MAX_LENGTH: Defines the maximum length of netstrings that can be
|
||||
received.
|
||||
@type MAX_LENGTH: C{int}
|
||||
|
||||
@ivar _LENGTH: A pattern describing all strings that contain a netstring
|
||||
length specification. Examples for length specifications are C{b'0:'},
|
||||
C{b'12:'}, and C{b'179:'}. C{b'007:'} is not a valid length
|
||||
specification, since leading zeros are not allowed.
|
||||
@type _LENGTH: C{re.Match}
|
||||
|
||||
@ivar _LENGTH_PREFIX: A pattern describing all strings that contain
|
||||
the first part of a netstring length specification (without the
|
||||
trailing comma). Examples are '0', '12', and '179'. '007' does not
|
||||
start a netstring length specification, since leading zeros are
|
||||
not allowed.
|
||||
@type _LENGTH_PREFIX: C{re.Match}
|
||||
|
||||
@ivar _PARSING_LENGTH: Indicates that the C{NetstringReceiver} is in
|
||||
the state of parsing the length portion of a netstring.
|
||||
@type _PARSING_LENGTH: C{int}
|
||||
|
||||
@ivar _PARSING_PAYLOAD: Indicates that the C{NetstringReceiver} is in
|
||||
the state of parsing the payload portion (data and trailing comma)
|
||||
of a netstring.
|
||||
@type _PARSING_PAYLOAD: C{int}
|
||||
|
||||
@ivar brokenPeer: Indicates if the connection is still functional
|
||||
@type brokenPeer: C{int}
|
||||
|
||||
@ivar _state: Indicates if the protocol is consuming the length portion
|
||||
(C{PARSING_LENGTH}) or the payload (C{PARSING_PAYLOAD}) of a netstring
|
||||
@type _state: C{int}
|
||||
|
||||
@ivar _remainingData: Holds the chunk of data that has not yet been consumed
|
||||
@type _remainingData: C{string}
|
||||
|
||||
@ivar _payload: Holds the payload portion of a netstring including the
|
||||
trailing comma
|
||||
@type _payload: C{BytesIO}
|
||||
|
||||
@ivar _expectedPayloadSize: Holds the payload size plus one for the trailing
|
||||
comma.
|
||||
@type _expectedPayloadSize: C{int}
|
||||
"""
|
||||
MAX_LENGTH = 99999
|
||||
_LENGTH = re.compile(br'(0|[1-9]\d*)(:)')
|
||||
|
||||
_LENGTH_PREFIX = re.compile(br'(0|[1-9]\d*)$')
|
||||
|
||||
# Some error information for NetstringParseError instances.
|
||||
_MISSING_LENGTH = ("The received netstring does not start with a "
|
||||
"length specification.")
|
||||
_OVERFLOW = ("The length specification of the received netstring "
|
||||
"cannot be represented in Python - it causes an "
|
||||
"OverflowError!")
|
||||
_TOO_LONG = ("The received netstring is longer than the maximum %s "
|
||||
"specified by self.MAX_LENGTH")
|
||||
_MISSING_COMMA = "The received netstring is not terminated by a comma."
|
||||
|
||||
# The following constants are used for determining if the NetstringReceiver
|
||||
# is parsing the length portion of a netstring, or the payload.
|
||||
_PARSING_LENGTH, _PARSING_PAYLOAD = range(2)
|
||||
|
||||
def makeConnection(self, transport):
|
||||
"""
|
||||
Initializes the protocol.
|
||||
"""
|
||||
protocol.Protocol.makeConnection(self, transport)
|
||||
self._remainingData = b""
|
||||
self._currentPayloadSize = 0
|
||||
self._payload = BytesIO()
|
||||
self._state = self._PARSING_LENGTH
|
||||
self._expectedPayloadSize = 0
|
||||
self.brokenPeer = 0
|
||||
|
||||
|
||||
def sendString(self, string):
|
||||
"""
|
||||
Sends a netstring.
|
||||
|
||||
Wraps up C{string} by adding length information and a
|
||||
trailing comma; writes the result to the transport.
|
||||
|
||||
@param string: The string to send. The necessary framing (length
|
||||
prefix, etc) will be added.
|
||||
@type string: C{bytes}
|
||||
"""
|
||||
self.transport.write(_formatNetstring(string))
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Receives some characters of a netstring.
|
||||
|
||||
Whenever a complete netstring is received, this method extracts
|
||||
its payload and calls L{stringReceived} to process it.
|
||||
|
||||
@param data: A chunk of data representing a (possibly partial)
|
||||
netstring
|
||||
@type data: C{bytes}
|
||||
"""
|
||||
self._remainingData += data
|
||||
while self._remainingData:
|
||||
try:
|
||||
self._consumeData()
|
||||
except IncompleteNetstring:
|
||||
break
|
||||
except NetstringParseError:
|
||||
self._handleParseError()
|
||||
break
|
||||
|
||||
|
||||
def stringReceived(self, string):
|
||||
"""
|
||||
Override this for notification when each complete string is received.
|
||||
|
||||
@param string: The complete string which was received with all
|
||||
framing (length prefix, etc) removed.
|
||||
@type string: C{bytes}
|
||||
|
||||
@raise NotImplementedError: because the method has to be implemented
|
||||
by the child class.
|
||||
"""
|
||||
raise NotImplementedError()
|
||||
|
||||
|
||||
def _maxLengthSize(self):
|
||||
"""
|
||||
Calculate and return the string size of C{self.MAX_LENGTH}.
|
||||
|
||||
@return: The size of the string representation for C{self.MAX_LENGTH}
|
||||
@rtype: C{float}
|
||||
"""
|
||||
return math.ceil(math.log10(self.MAX_LENGTH)) + 1
|
||||
|
||||
|
||||
def _consumeData(self):
|
||||
"""
|
||||
Consumes the content of C{self._remainingData}.
|
||||
|
||||
@raise IncompleteNetstring: if C{self._remainingData} does not
|
||||
contain enough data to complete the current netstring.
|
||||
@raise NetstringParseError: if the received data do not
|
||||
form a valid netstring.
|
||||
"""
|
||||
if self._state == self._PARSING_LENGTH:
|
||||
self._consumeLength()
|
||||
self._prepareForPayloadConsumption()
|
||||
if self._state == self._PARSING_PAYLOAD:
|
||||
self._consumePayload()
|
||||
|
||||
|
||||
def _consumeLength(self):
|
||||
"""
|
||||
Consumes the length portion of C{self._remainingData}.
|
||||
|
||||
@raise IncompleteNetstring: if C{self._remainingData} contains
|
||||
a partial length specification (digits without trailing
|
||||
comma).
|
||||
@raise NetstringParseError: if the received data do not form a valid
|
||||
netstring.
|
||||
"""
|
||||
lengthMatch = self._LENGTH.match(self._remainingData)
|
||||
if not lengthMatch:
|
||||
self._checkPartialLengthSpecification()
|
||||
raise IncompleteNetstring()
|
||||
self._processLength(lengthMatch)
|
||||
|
||||
|
||||
def _checkPartialLengthSpecification(self):
|
||||
"""
|
||||
Makes sure that the received data represents a valid number.
|
||||
|
||||
Checks if C{self._remainingData} represents a number smaller or
|
||||
equal to C{self.MAX_LENGTH}.
|
||||
|
||||
@raise NetstringParseError: if C{self._remainingData} is no
|
||||
number or is too big (checked by L{_extractLength}).
|
||||
"""
|
||||
partialLengthMatch = self._LENGTH_PREFIX.match(self._remainingData)
|
||||
if not partialLengthMatch:
|
||||
raise NetstringParseError(self._MISSING_LENGTH)
|
||||
lengthSpecification = (partialLengthMatch.group(1))
|
||||
self._extractLength(lengthSpecification)
|
||||
|
||||
|
||||
def _processLength(self, lengthMatch):
|
||||
"""
|
||||
Processes the length definition of a netstring.
|
||||
|
||||
Extracts and stores in C{self._expectedPayloadSize} the number
|
||||
representing the netstring size. Removes the prefix
|
||||
representing the length specification from
|
||||
C{self._remainingData}.
|
||||
|
||||
@raise NetstringParseError: if the received netstring does not
|
||||
start with a number or the number is bigger than
|
||||
C{self.MAX_LENGTH}.
|
||||
@param lengthMatch: A regular expression match object matching
|
||||
a netstring length specification
|
||||
@type lengthMatch: C{re.Match}
|
||||
"""
|
||||
endOfNumber = lengthMatch.end(1)
|
||||
startOfData = lengthMatch.end(2)
|
||||
lengthString = self._remainingData[:endOfNumber]
|
||||
# Expect payload plus trailing comma:
|
||||
self._expectedPayloadSize = self._extractLength(lengthString) + 1
|
||||
self._remainingData = self._remainingData[startOfData:]
|
||||
|
||||
|
||||
def _extractLength(self, lengthAsString):
|
||||
"""
|
||||
Attempts to extract the length information of a netstring.
|
||||
|
||||
@raise NetstringParseError: if the number is bigger than
|
||||
C{self.MAX_LENGTH}.
|
||||
@param lengthAsString: A chunk of data starting with a length
|
||||
specification
|
||||
@type lengthAsString: C{bytes}
|
||||
@return: The length of the netstring
|
||||
@rtype: C{int}
|
||||
"""
|
||||
self._checkStringSize(lengthAsString)
|
||||
length = int(lengthAsString)
|
||||
if length > self.MAX_LENGTH:
|
||||
raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,))
|
||||
return length
|
||||
|
||||
|
||||
def _checkStringSize(self, lengthAsString):
|
||||
"""
|
||||
Checks the sanity of lengthAsString.
|
||||
|
||||
Checks if the size of the length specification exceeds the
|
||||
size of the string representing self.MAX_LENGTH. If this is
|
||||
not the case, the number represented by lengthAsString is
|
||||
certainly bigger than self.MAX_LENGTH, and a
|
||||
NetstringParseError can be raised.
|
||||
|
||||
This method should make sure that netstrings with extremely
|
||||
long length specifications are refused before even attempting
|
||||
to convert them to an integer (which might trigger a
|
||||
MemoryError).
|
||||
"""
|
||||
if len(lengthAsString) > self._maxLengthSize():
|
||||
raise NetstringParseError(self._TOO_LONG % (self.MAX_LENGTH,))
|
||||
|
||||
|
||||
def _prepareForPayloadConsumption(self):
|
||||
"""
|
||||
Sets up variables necessary for consuming the payload of a netstring.
|
||||
"""
|
||||
self._state = self._PARSING_PAYLOAD
|
||||
self._currentPayloadSize = 0
|
||||
self._payload.seek(0)
|
||||
self._payload.truncate()
|
||||
|
||||
|
||||
def _consumePayload(self):
|
||||
"""
|
||||
Consumes the payload portion of C{self._remainingData}.
|
||||
|
||||
If the payload is complete, checks for the trailing comma and
|
||||
processes the payload. If not, raises an L{IncompleteNetstring}
|
||||
exception.
|
||||
|
||||
@raise IncompleteNetstring: if the payload received so far
|
||||
contains fewer characters than expected.
|
||||
@raise NetstringParseError: if the payload does not end with a
|
||||
comma.
|
||||
"""
|
||||
self._extractPayload()
|
||||
if self._currentPayloadSize < self._expectedPayloadSize:
|
||||
raise IncompleteNetstring()
|
||||
self._checkForTrailingComma()
|
||||
self._state = self._PARSING_LENGTH
|
||||
self._processPayload()
|
||||
|
||||
|
||||
def _extractPayload(self):
|
||||
"""
|
||||
Extracts payload information from C{self._remainingData}.
|
||||
|
||||
Splits C{self._remainingData} at the end of the netstring. The
|
||||
first part becomes C{self._payload}, the second part is stored
|
||||
in C{self._remainingData}.
|
||||
|
||||
If the netstring is not yet complete, the whole content of
|
||||
C{self._remainingData} is moved to C{self._payload}.
|
||||
"""
|
||||
if self._payloadComplete():
|
||||
remainingPayloadSize = (self._expectedPayloadSize -
|
||||
self._currentPayloadSize)
|
||||
self._payload.write(self._remainingData[:remainingPayloadSize])
|
||||
self._remainingData = self._remainingData[remainingPayloadSize:]
|
||||
self._currentPayloadSize = self._expectedPayloadSize
|
||||
else:
|
||||
self._payload.write(self._remainingData)
|
||||
self._currentPayloadSize += len(self._remainingData)
|
||||
self._remainingData = b""
|
||||
|
||||
|
||||
def _payloadComplete(self):
|
||||
"""
|
||||
Checks if enough data have been received to complete the netstring.
|
||||
|
||||
@return: C{True} iff the received data contain at least as many
|
||||
characters as specified in the length section of the
|
||||
netstring
|
||||
@rtype: C{bool}
|
||||
"""
|
||||
return (len(self._remainingData) + self._currentPayloadSize >=
|
||||
self._expectedPayloadSize)
|
||||
|
||||
|
||||
def _processPayload(self):
|
||||
"""
|
||||
Processes the actual payload with L{stringReceived}.
|
||||
|
||||
Strips C{self._payload} of the trailing comma and calls
|
||||
L{stringReceived} with the result.
|
||||
"""
|
||||
self.stringReceived(self._payload.getvalue()[:-1])
|
||||
|
||||
|
||||
def _checkForTrailingComma(self):
|
||||
"""
|
||||
Checks if the netstring has a trailing comma at the expected position.
|
||||
|
||||
@raise NetstringParseError: if the last payload character is
|
||||
anything but a comma.
|
||||
"""
|
||||
if self._payload.getvalue()[-1:] != b",":
|
||||
raise NetstringParseError(self._MISSING_COMMA)
|
||||
|
||||
|
||||
def _handleParseError(self):
|
||||
"""
|
||||
Terminates the connection and sets the flag C{self.brokenPeer}.
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
self.brokenPeer = 1
|
||||
|
||||
|
||||
|
||||
class LineOnlyReceiver(protocol.Protocol):
|
||||
"""
|
||||
A protocol that receives only lines.
|
||||
|
||||
This is purely a speed optimisation over LineReceiver, for the
|
||||
cases that raw mode is known to be unnecessary.
|
||||
|
||||
@cvar delimiter: The line-ending delimiter to use. By default this is
|
||||
C{b'\\r\\n'}.
|
||||
@cvar MAX_LENGTH: The maximum length of a line to allow (If a
|
||||
sent line is longer than this, the connection is dropped).
|
||||
Default is 16384.
|
||||
"""
|
||||
_buffer = b''
|
||||
delimiter = b'\r\n'
|
||||
MAX_LENGTH = 16384
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Translates bytes into lines, and calls lineReceived.
|
||||
"""
|
||||
lines = (self._buffer+data).split(self.delimiter)
|
||||
self._buffer = lines.pop(-1)
|
||||
for line in lines:
|
||||
if self.transport.disconnecting:
|
||||
# this is necessary because the transport may be told to lose
|
||||
# the connection by a line within a larger packet, and it is
|
||||
# important to disregard all the lines in that packet following
|
||||
# the one that told it to close.
|
||||
return
|
||||
if len(line) > self.MAX_LENGTH:
|
||||
return self.lineLengthExceeded(line)
|
||||
else:
|
||||
self.lineReceived(line)
|
||||
if len(self._buffer) > self.MAX_LENGTH:
|
||||
return self.lineLengthExceeded(self._buffer)
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
"""
|
||||
Override this for when each line is received.
|
||||
|
||||
@param line: The line which was received with the delimiter removed.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def sendLine(self, line):
|
||||
"""
|
||||
Sends a line to the other end of the connection.
|
||||
|
||||
@param line: The line to send, not including the delimiter.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
return self.transport.writeSequence((line, self.delimiter))
|
||||
|
||||
|
||||
def lineLengthExceeded(self, line):
|
||||
"""
|
||||
Called when the maximum line length has been reached.
|
||||
Override if it needs to be dealt with in some special way.
|
||||
"""
|
||||
return self.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
class _PauseableMixin:
|
||||
paused = False
|
||||
|
||||
def pauseProducing(self):
|
||||
self.paused = True
|
||||
self.transport.pauseProducing()
|
||||
|
||||
|
||||
def resumeProducing(self):
|
||||
self.paused = False
|
||||
self.transport.resumeProducing()
|
||||
self.dataReceived(b'')
|
||||
|
||||
|
||||
def stopProducing(self):
|
||||
self.paused = True
|
||||
self.transport.stopProducing()
|
||||
|
||||
|
||||
|
||||
class LineReceiver(protocol.Protocol, _PauseableMixin):
|
||||
"""
|
||||
A protocol that receives lines and/or raw data, depending on mode.
|
||||
|
||||
In line mode, each line that's received becomes a callback to
|
||||
L{lineReceived}. In raw data mode, each chunk of raw data becomes a
|
||||
callback to L{LineReceiver.rawDataReceived}.
|
||||
The L{setLineMode} and L{setRawMode} methods switch between the two modes.
|
||||
|
||||
This is useful for line-oriented protocols such as IRC, HTTP, POP, etc.
|
||||
|
||||
@cvar delimiter: The line-ending delimiter to use. By default this is
|
||||
C{b'\\r\\n'}.
|
||||
@cvar MAX_LENGTH: The maximum length of a line to allow (If a
|
||||
sent line is longer than this, the connection is dropped).
|
||||
Default is 16384.
|
||||
"""
|
||||
line_mode = 1
|
||||
_buffer = b''
|
||||
_busyReceiving = False
|
||||
delimiter = b'\r\n'
|
||||
MAX_LENGTH = 16384
|
||||
|
||||
def clearLineBuffer(self):
|
||||
"""
|
||||
Clear buffered data.
|
||||
|
||||
@return: All of the cleared buffered data.
|
||||
@rtype: C{bytes}
|
||||
"""
|
||||
b, self._buffer = self._buffer, b""
|
||||
return b
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Protocol.dataReceived.
|
||||
Translates bytes into lines, and calls lineReceived (or
|
||||
rawDataReceived, depending on mode.)
|
||||
"""
|
||||
if self._busyReceiving:
|
||||
self._buffer += data
|
||||
return
|
||||
|
||||
try:
|
||||
self._busyReceiving = True
|
||||
self._buffer += data
|
||||
while self._buffer and not self.paused:
|
||||
if self.line_mode:
|
||||
try:
|
||||
line, self._buffer = self._buffer.split(
|
||||
self.delimiter, 1)
|
||||
except ValueError:
|
||||
if len(self._buffer) >= (self.MAX_LENGTH
|
||||
+ len(self.delimiter)):
|
||||
line, self._buffer = self._buffer, b''
|
||||
return self.lineLengthExceeded(line)
|
||||
return
|
||||
else:
|
||||
lineLength = len(line)
|
||||
if lineLength > self.MAX_LENGTH:
|
||||
exceeded = line + self.delimiter + self._buffer
|
||||
self._buffer = b''
|
||||
return self.lineLengthExceeded(exceeded)
|
||||
why = self.lineReceived(line)
|
||||
if (why or self.transport and
|
||||
self.transport.disconnecting):
|
||||
return why
|
||||
else:
|
||||
data = self._buffer
|
||||
self._buffer = b''
|
||||
why = self.rawDataReceived(data)
|
||||
if why:
|
||||
return why
|
||||
finally:
|
||||
self._busyReceiving = False
|
||||
|
||||
|
||||
def setLineMode(self, extra=b''):
|
||||
"""
|
||||
Sets the line-mode of this receiver.
|
||||
|
||||
If you are calling this from a rawDataReceived callback,
|
||||
you can pass in extra unhandled data, and that data will
|
||||
be parsed for lines. Further data received will be sent
|
||||
to lineReceived rather than rawDataReceived.
|
||||
|
||||
Do not pass extra data if calling this function from
|
||||
within a lineReceived callback.
|
||||
"""
|
||||
self.line_mode = 1
|
||||
if extra:
|
||||
return self.dataReceived(extra)
|
||||
|
||||
|
||||
def setRawMode(self):
|
||||
"""
|
||||
Sets the raw mode of this receiver.
|
||||
Further data received will be sent to rawDataReceived rather
|
||||
than lineReceived.
|
||||
"""
|
||||
self.line_mode = 0
|
||||
|
||||
|
||||
def rawDataReceived(self, data):
|
||||
"""
|
||||
Override this for when raw data is received.
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
"""
|
||||
Override this for when each line is received.
|
||||
|
||||
@param line: The line which was received with the delimiter removed.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def sendLine(self, line):
|
||||
"""
|
||||
Sends a line to the other end of the connection.
|
||||
|
||||
@param line: The line to send, not including the delimiter.
|
||||
@type line: C{bytes}
|
||||
"""
|
||||
return self.transport.write(line + self.delimiter)
|
||||
|
||||
|
||||
def lineLengthExceeded(self, line):
|
||||
"""
|
||||
Called when the maximum line length has been reached.
|
||||
Override if it needs to be dealt with in some special way.
|
||||
|
||||
The argument 'line' contains the remainder of the buffer, starting
|
||||
with (at least some part) of the line which is too long. This may
|
||||
be more than one line, or may be only the initial portion of the
|
||||
line.
|
||||
"""
|
||||
return self.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
class StringTooLongError(AssertionError):
|
||||
"""
|
||||
Raised when trying to send a string too long for a length prefixed
|
||||
protocol.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class _RecvdCompatHack(object):
|
||||
"""
|
||||
Emulates the to-be-deprecated C{IntNStringReceiver.recvd} attribute.
|
||||
|
||||
The C{recvd} attribute was where the working buffer for buffering and
|
||||
parsing netstrings was kept. It was updated each time new data arrived and
|
||||
each time some of that data was parsed and delivered to application code.
|
||||
The piecemeal updates to its string value were expensive and have been
|
||||
removed from C{IntNStringReceiver} in the normal case. However, for
|
||||
applications directly reading this attribute, this descriptor restores that
|
||||
behavior. It only copies the working buffer when necessary (ie, when
|
||||
accessed). This avoids the cost for applications not using the data.
|
||||
|
||||
This is a custom descriptor rather than a property, because we still need
|
||||
the default __set__ behavior in both new-style and old-style subclasses.
|
||||
"""
|
||||
def __get__(self, oself, type=None):
|
||||
return oself._unprocessed[oself._compatibilityOffset:]
|
||||
|
||||
|
||||
|
||||
class IntNStringReceiver(protocol.Protocol, _PauseableMixin):
|
||||
"""
|
||||
Generic class for length prefixed protocols.
|
||||
|
||||
@ivar _unprocessed: bytes received, but not yet broken up into messages /
|
||||
sent to stringReceived. _compatibilityOffset must be updated when this
|
||||
value is updated so that the C{recvd} attribute can be generated
|
||||
correctly.
|
||||
@type _unprocessed: C{bytes}
|
||||
|
||||
@ivar structFormat: format used for struct packing/unpacking. Define it in
|
||||
subclass.
|
||||
@type structFormat: C{str}
|
||||
|
||||
@ivar prefixLength: length of the prefix, in bytes. Define it in subclass,
|
||||
using C{struct.calcsize(structFormat)}
|
||||
@type prefixLength: C{int}
|
||||
|
||||
@ivar _compatibilityOffset: the offset within C{_unprocessed} to the next
|
||||
message to be parsed. (used to generate the recvd attribute)
|
||||
@type _compatibilityOffset: C{int}
|
||||
"""
|
||||
|
||||
MAX_LENGTH = 99999
|
||||
_unprocessed = b""
|
||||
_compatibilityOffset = 0
|
||||
|
||||
# Backwards compatibility support for applications which directly touch the
|
||||
# "internal" parse buffer.
|
||||
recvd = _RecvdCompatHack()
|
||||
|
||||
def stringReceived(self, string):
|
||||
"""
|
||||
Override this for notification when each complete string is received.
|
||||
|
||||
@param string: The complete string which was received with all
|
||||
framing (length prefix, etc) removed.
|
||||
@type string: C{bytes}
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
def lengthLimitExceeded(self, length):
|
||||
"""
|
||||
Callback invoked when a length prefix greater than C{MAX_LENGTH} is
|
||||
received. The default implementation disconnects the transport.
|
||||
Override this.
|
||||
|
||||
@param length: The length prefix which was received.
|
||||
@type length: C{int}
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Convert int prefixed strings into calls to stringReceived.
|
||||
"""
|
||||
# Try to minimize string copying (via slices) by keeping one buffer
|
||||
# containing all the data we have so far and a separate offset into that
|
||||
# buffer.
|
||||
alldata = self._unprocessed + data
|
||||
currentOffset = 0
|
||||
prefixLength = self.prefixLength
|
||||
fmt = self.structFormat
|
||||
self._unprocessed = alldata
|
||||
|
||||
while len(alldata) >= (currentOffset + prefixLength) and not self.paused:
|
||||
messageStart = currentOffset + prefixLength
|
||||
length, = unpack(fmt, alldata[currentOffset:messageStart])
|
||||
if length > self.MAX_LENGTH:
|
||||
self._unprocessed = alldata
|
||||
self._compatibilityOffset = currentOffset
|
||||
self.lengthLimitExceeded(length)
|
||||
return
|
||||
messageEnd = messageStart + length
|
||||
if len(alldata) < messageEnd:
|
||||
break
|
||||
|
||||
# Here we have to slice the working buffer so we can send just the
|
||||
# netstring into the stringReceived callback.
|
||||
packet = alldata[messageStart:messageEnd]
|
||||
currentOffset = messageEnd
|
||||
self._compatibilityOffset = currentOffset
|
||||
self.stringReceived(packet)
|
||||
|
||||
# Check to see if the backwards compat "recvd" attribute got written
|
||||
# to by application code. If so, drop the current data buffer and
|
||||
# switch to the new buffer given by that attribute's value.
|
||||
if 'recvd' in self.__dict__:
|
||||
alldata = self.__dict__.pop('recvd')
|
||||
self._unprocessed = alldata
|
||||
self._compatibilityOffset = currentOffset = 0
|
||||
if alldata:
|
||||
continue
|
||||
return
|
||||
|
||||
# Slice off all the data that has been processed, avoiding holding onto
|
||||
# memory to store it, and update the compatibility attributes to reflect
|
||||
# that change.
|
||||
self._unprocessed = alldata[currentOffset:]
|
||||
self._compatibilityOffset = 0
|
||||
|
||||
|
||||
def sendString(self, string):
|
||||
"""
|
||||
Send a prefixed string to the other end of the connection.
|
||||
|
||||
@param string: The string to send. The necessary framing (length
|
||||
prefix, etc) will be added.
|
||||
@type string: C{bytes}
|
||||
"""
|
||||
if len(string) >= 2 ** (8 * self.prefixLength):
|
||||
raise StringTooLongError(
|
||||
"Try to send %s bytes whereas maximum is %s" % (
|
||||
len(string), 2 ** (8 * self.prefixLength)))
|
||||
self.transport.write(
|
||||
pack(self.structFormat, len(string)) + string)
|
||||
|
||||
|
||||
|
||||
class Int32StringReceiver(IntNStringReceiver):
|
||||
"""
|
||||
A receiver for int32-prefixed strings.
|
||||
|
||||
An int32 string is a string prefixed by 4 bytes, the 32-bit length of
|
||||
the string encoded in network byte order.
|
||||
|
||||
This class publishes the same interface as NetstringReceiver.
|
||||
"""
|
||||
structFormat = "!I"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
|
||||
|
||||
class Int16StringReceiver(IntNStringReceiver):
|
||||
"""
|
||||
A receiver for int16-prefixed strings.
|
||||
|
||||
An int16 string is a string prefixed by 2 bytes, the 16-bit length of
|
||||
the string encoded in network byte order.
|
||||
|
||||
This class publishes the same interface as NetstringReceiver.
|
||||
"""
|
||||
structFormat = "!H"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
|
||||
|
||||
class Int8StringReceiver(IntNStringReceiver):
|
||||
"""
|
||||
A receiver for int8-prefixed strings.
|
||||
|
||||
An int8 string is a string prefixed by 1 byte, the 8-bit length of
|
||||
the string.
|
||||
|
||||
This class publishes the same interface as NetstringReceiver.
|
||||
"""
|
||||
structFormat = "!B"
|
||||
prefixLength = calcsize(structFormat)
|
||||
|
||||
|
||||
|
||||
class StatefulStringProtocol:
|
||||
"""
|
||||
A stateful string protocol.
|
||||
|
||||
This is a mixin for string protocols (L{Int32StringReceiver},
|
||||
L{NetstringReceiver}) which translates L{stringReceived} into a callback
|
||||
(prefixed with C{'proto_'}) depending on state.
|
||||
|
||||
The state C{'done'} is special; if a C{proto_*} method returns it, the
|
||||
connection will be closed immediately.
|
||||
|
||||
@ivar state: Current state of the protocol. Defaults to C{'init'}.
|
||||
@type state: C{str}
|
||||
"""
|
||||
|
||||
state = 'init'
|
||||
|
||||
def stringReceived(self, string):
|
||||
"""
|
||||
Choose a protocol phase function and call it.
|
||||
|
||||
Call back to the appropriate protocol phase; this begins with
|
||||
the function C{proto_init} and moves on to C{proto_*} depending on
|
||||
what each C{proto_*} function returns. (For example, if
|
||||
C{self.proto_init} returns 'foo', then C{self.proto_foo} will be the
|
||||
next function called when a protocol message is received.
|
||||
"""
|
||||
try:
|
||||
pto = 'proto_' + self.state
|
||||
statehandler = getattr(self, pto)
|
||||
except AttributeError:
|
||||
log.msg('callback', self.state, 'not found')
|
||||
else:
|
||||
self.state = statehandler(string)
|
||||
if self.state == 'done':
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
@implementer(interfaces.IProducer)
|
||||
class FileSender:
|
||||
"""
|
||||
A producer that sends the contents of a file to a consumer.
|
||||
|
||||
This is a helper for protocols that, at some point, will take a
|
||||
file-like object, read its contents, and write them out to the network,
|
||||
optionally performing some transformation on the bytes in between.
|
||||
"""
|
||||
|
||||
CHUNK_SIZE = 2 ** 14
|
||||
|
||||
lastSent = ''
|
||||
deferred = None
|
||||
|
||||
def beginFileTransfer(self, file, consumer, transform=None):
|
||||
"""
|
||||
Begin transferring a file
|
||||
|
||||
@type file: Any file-like object
|
||||
@param file: The file object to read data from
|
||||
|
||||
@type consumer: Any implementor of IConsumer
|
||||
@param consumer: The object to write data to
|
||||
|
||||
@param transform: A callable taking one string argument and returning
|
||||
the same. All bytes read from the file are passed through this before
|
||||
being written to the consumer.
|
||||
|
||||
@rtype: C{Deferred}
|
||||
@return: A deferred whose callback will be invoked when the file has
|
||||
been completely written to the consumer. The last byte written to the
|
||||
consumer is passed to the callback.
|
||||
"""
|
||||
self.file = file
|
||||
self.consumer = consumer
|
||||
self.transform = transform
|
||||
|
||||
self.deferred = deferred = defer.Deferred()
|
||||
self.consumer.registerProducer(self, False)
|
||||
return deferred
|
||||
|
||||
|
||||
def resumeProducing(self):
|
||||
chunk = ''
|
||||
if self.file:
|
||||
chunk = self.file.read(self.CHUNK_SIZE)
|
||||
if not chunk:
|
||||
self.file = None
|
||||
self.consumer.unregisterProducer()
|
||||
if self.deferred:
|
||||
self.deferred.callback(self.lastSent)
|
||||
self.deferred = None
|
||||
return
|
||||
|
||||
if self.transform:
|
||||
chunk = self.transform(chunk)
|
||||
self.consumer.write(chunk)
|
||||
self.lastSent = chunk[-1:]
|
||||
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
|
||||
def stopProducing(self):
|
||||
if self.deferred:
|
||||
self.deferred.errback(
|
||||
Exception("Consumer asked us to stop producing"))
|
||||
self.deferred = None
|
||||
415
venv/lib/python3.9/site-packages/twisted/protocols/dict.py
Normal file
415
venv/lib/python3.9/site-packages/twisted/protocols/dict.py
Normal file
|
|
@ -0,0 +1,415 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Dict client protocol implementation.
|
||||
|
||||
@author: Pavel Pergamenshchik
|
||||
"""
|
||||
|
||||
from twisted.protocols import basic
|
||||
from twisted.internet import defer, protocol
|
||||
from twisted.python import log
|
||||
from io import BytesIO
|
||||
|
||||
def parseParam(line):
|
||||
"""Chew one dqstring or atom from beginning of line and return (param, remaningline)"""
|
||||
if line == b'':
|
||||
return (None, b'')
|
||||
elif line[0:1] != b'"': # atom
|
||||
mode = 1
|
||||
else: # dqstring
|
||||
mode = 2
|
||||
res = b""
|
||||
io = BytesIO(line)
|
||||
if mode == 2: # skip the opening quote
|
||||
io.read(1)
|
||||
while 1:
|
||||
a = io.read(1)
|
||||
if a == b'"':
|
||||
if mode == 2:
|
||||
io.read(1) # skip the separating space
|
||||
return (res, io.read())
|
||||
elif a == b'\\':
|
||||
a = io.read(1)
|
||||
if a == b'':
|
||||
return (None, line) # unexpected end of string
|
||||
elif a == b'':
|
||||
if mode == 1:
|
||||
return (res, io.read())
|
||||
else:
|
||||
return (None, line) # unexpected end of string
|
||||
elif a == b' ':
|
||||
if mode == 1:
|
||||
return (res, io.read())
|
||||
res += a
|
||||
|
||||
|
||||
|
||||
def makeAtom(line):
|
||||
"""Munch a string into an 'atom'"""
|
||||
# FIXME: proper quoting
|
||||
return filter(lambda x: not (x in map(chr, range(33)+[34, 39, 92])), line)
|
||||
|
||||
|
||||
|
||||
def makeWord(s):
|
||||
mustquote = range(33)+[34, 39, 92]
|
||||
result = []
|
||||
for c in s:
|
||||
if ord(c) in mustquote:
|
||||
result.append(b"\\")
|
||||
result.append(c)
|
||||
s = b"".join(result)
|
||||
return s
|
||||
|
||||
|
||||
|
||||
def parseText(line):
|
||||
if len(line) == 1 and line == b'.':
|
||||
return None
|
||||
else:
|
||||
if len(line) > 1 and line[0:2] == b'..':
|
||||
line = line[1:]
|
||||
return line
|
||||
|
||||
|
||||
|
||||
class Definition:
|
||||
"""A word definition"""
|
||||
def __init__(self, name, db, dbdesc, text):
|
||||
self.name = name
|
||||
self.db = db
|
||||
self.dbdesc = dbdesc
|
||||
self.text = text # list of strings not terminated by newline
|
||||
|
||||
|
||||
|
||||
class DictClient(basic.LineReceiver):
|
||||
"""dict (RFC2229) client"""
|
||||
|
||||
data = None # multiline data
|
||||
MAX_LENGTH = 1024
|
||||
state = None
|
||||
mode = None
|
||||
result = None
|
||||
factory = None
|
||||
|
||||
def __init__(self):
|
||||
self.data = None
|
||||
self.result = None
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.state = "conn"
|
||||
self.mode = "command"
|
||||
|
||||
|
||||
def sendLine(self, line):
|
||||
"""Throw up if the line is longer than 1022 characters"""
|
||||
if len(line) > self.MAX_LENGTH - 2:
|
||||
raise ValueError("DictClient tried to send a too long line")
|
||||
basic.LineReceiver.sendLine(self, line)
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
try:
|
||||
line = line.decode("utf-8")
|
||||
except UnicodeError: # garbage received, skip
|
||||
return
|
||||
if self.mode == "text": # we are receiving textual data
|
||||
code = "text"
|
||||
else:
|
||||
if len(line) < 4:
|
||||
log.msg("DictClient got invalid line from server -- %s" % line)
|
||||
self.protocolError("Invalid line from server")
|
||||
self.transport.LoseConnection()
|
||||
return
|
||||
code = int(line[:3])
|
||||
line = line[4:]
|
||||
method = getattr(self, 'dictCode_%s_%s' % (code, self.state), self.dictCode_default)
|
||||
method(line)
|
||||
|
||||
|
||||
def dictCode_default(self, line):
|
||||
"""Unknown message"""
|
||||
log.msg("DictClient got unexpected message from server -- %s" % line)
|
||||
self.protocolError("Unexpected server message")
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def dictCode_221_ready(self, line):
|
||||
"""We are about to get kicked off, do nothing"""
|
||||
pass
|
||||
|
||||
|
||||
def dictCode_220_conn(self, line):
|
||||
"""Greeting message"""
|
||||
self.state = "ready"
|
||||
self.dictConnected()
|
||||
|
||||
|
||||
def dictCode_530_conn(self):
|
||||
self.protocolError("Access denied")
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def dictCode_420_conn(self):
|
||||
self.protocolError("Server temporarily unavailable")
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def dictCode_421_conn(self):
|
||||
self.protocolError("Server shutting down at operator request")
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def sendDefine(self, database, word):
|
||||
"""Send a dict DEFINE command"""
|
||||
assert self.state == "ready", "DictClient.sendDefine called when not in ready state"
|
||||
self.result = None # these two are just in case. In "ready" state, result and data
|
||||
self.data = None # should be None
|
||||
self.state = "define"
|
||||
command = "DEFINE %s %s" % (makeAtom(database.encode("UTF-8")), makeWord(word.encode("UTF-8")))
|
||||
self.sendLine(command)
|
||||
|
||||
|
||||
def sendMatch(self, database, strategy, word):
|
||||
"""Send a dict MATCH command"""
|
||||
assert self.state == "ready", "DictClient.sendMatch called when not in ready state"
|
||||
self.result = None
|
||||
self.data = None
|
||||
self.state = "match"
|
||||
command = "MATCH %s %s %s" % (makeAtom(database), makeAtom(strategy), makeAtom(word))
|
||||
self.sendLine(command.encode("UTF-8"))
|
||||
|
||||
def dictCode_550_define(self, line):
|
||||
"""Invalid database"""
|
||||
self.mode = "ready"
|
||||
self.defineFailed("Invalid database")
|
||||
|
||||
|
||||
def dictCode_550_match(self, line):
|
||||
"""Invalid database"""
|
||||
self.mode = "ready"
|
||||
self.matchFailed("Invalid database")
|
||||
|
||||
|
||||
def dictCode_551_match(self, line):
|
||||
"""Invalid strategy"""
|
||||
self.mode = "ready"
|
||||
self.matchFailed("Invalid strategy")
|
||||
|
||||
|
||||
def dictCode_552_define(self, line):
|
||||
"""No match"""
|
||||
self.mode = "ready"
|
||||
self.defineFailed("No match")
|
||||
|
||||
|
||||
def dictCode_552_match(self, line):
|
||||
"""No match"""
|
||||
self.mode = "ready"
|
||||
self.matchFailed("No match")
|
||||
|
||||
|
||||
def dictCode_150_define(self, line):
|
||||
"""n definitions retrieved"""
|
||||
self.result = []
|
||||
|
||||
|
||||
def dictCode_151_define(self, line):
|
||||
"""Definition text follows"""
|
||||
self.mode = "text"
|
||||
(word, line) = parseParam(line)
|
||||
(db, line) = parseParam(line)
|
||||
(dbdesc, line) = parseParam(line)
|
||||
if not (word and db and dbdesc):
|
||||
self.protocolError("Invalid server response")
|
||||
self.transport.loseConnection()
|
||||
else:
|
||||
self.result.append(Definition(word, db, dbdesc, []))
|
||||
self.data = []
|
||||
|
||||
|
||||
def dictCode_152_match(self, line):
|
||||
"""n matches found, text follows"""
|
||||
self.mode = "text"
|
||||
self.result = []
|
||||
self.data = []
|
||||
|
||||
|
||||
def dictCode_text_define(self, line):
|
||||
"""A line of definition text received"""
|
||||
res = parseText(line)
|
||||
if res == None:
|
||||
self.mode = "command"
|
||||
self.result[-1].text = self.data
|
||||
self.data = None
|
||||
else:
|
||||
self.data.append(line)
|
||||
|
||||
|
||||
def dictCode_text_match(self, line):
|
||||
"""One line of match text received"""
|
||||
def l(s):
|
||||
p1, t = parseParam(s)
|
||||
p2, t = parseParam(t)
|
||||
return (p1, p2)
|
||||
res = parseText(line)
|
||||
if res == None:
|
||||
self.mode = "command"
|
||||
self.result = map(l, self.data)
|
||||
self.data = None
|
||||
else:
|
||||
self.data.append(line)
|
||||
|
||||
|
||||
def dictCode_250_define(self, line):
|
||||
"""ok"""
|
||||
t = self.result
|
||||
self.result = None
|
||||
self.state = "ready"
|
||||
self.defineDone(t)
|
||||
|
||||
|
||||
def dictCode_250_match(self, line):
|
||||
"""ok"""
|
||||
t = self.result
|
||||
self.result = None
|
||||
self.state = "ready"
|
||||
self.matchDone(t)
|
||||
|
||||
|
||||
def protocolError(self, reason):
|
||||
"""override to catch unexpected dict protocol conditions"""
|
||||
pass
|
||||
|
||||
|
||||
def dictConnected(self):
|
||||
"""override to be notified when the server is ready to accept commands"""
|
||||
pass
|
||||
|
||||
|
||||
def defineFailed(self, reason):
|
||||
"""override to catch reasonable failure responses to DEFINE"""
|
||||
pass
|
||||
|
||||
|
||||
def defineDone(self, result):
|
||||
"""override to catch successful DEFINE"""
|
||||
pass
|
||||
|
||||
|
||||
def matchFailed(self, reason):
|
||||
"""override to catch resonable failure responses to MATCH"""
|
||||
pass
|
||||
|
||||
|
||||
def matchDone(self, result):
|
||||
"""override to catch successful MATCH"""
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class InvalidResponse(Exception):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class DictLookup(DictClient):
|
||||
"""Utility class for a single dict transaction. To be used with DictLookupFactory"""
|
||||
|
||||
def protocolError(self, reason):
|
||||
if not self.factory.done:
|
||||
self.factory.d.errback(InvalidResponse(reason))
|
||||
self.factory.clientDone()
|
||||
|
||||
|
||||
def dictConnected(self):
|
||||
if self.factory.queryType == "define":
|
||||
self.sendDefine(*self.factory.param)
|
||||
elif self.factory.queryType == "match":
|
||||
self.sendMatch(*self.factory.param)
|
||||
|
||||
|
||||
def defineFailed(self, reason):
|
||||
self.factory.d.callback([])
|
||||
self.factory.clientDone()
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def defineDone(self, result):
|
||||
self.factory.d.callback(result)
|
||||
self.factory.clientDone()
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def matchFailed(self, reason):
|
||||
self.factory.d.callback([])
|
||||
self.factory.clientDone()
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def matchDone(self, result):
|
||||
self.factory.d.callback(result)
|
||||
self.factory.clientDone()
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
class DictLookupFactory(protocol.ClientFactory):
|
||||
"""Utility factory for a single dict transaction"""
|
||||
protocol = DictLookup
|
||||
done = None
|
||||
|
||||
def __init__(self, queryType, param, d):
|
||||
self.queryType = queryType
|
||||
self.param = param
|
||||
self.d = d
|
||||
self.done = 0
|
||||
|
||||
|
||||
def clientDone(self):
|
||||
"""Called by client when done."""
|
||||
self.done = 1
|
||||
del self.d
|
||||
|
||||
|
||||
def clientConnectionFailed(self, connector, error):
|
||||
self.d.errback(error)
|
||||
|
||||
|
||||
def clientConnectionLost(self, connector, error):
|
||||
if not self.done:
|
||||
self.d.errback(error)
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
p = self.protocol()
|
||||
p.factory = self
|
||||
return p
|
||||
|
||||
|
||||
|
||||
def define(host, port, database, word):
|
||||
"""Look up a word using a dict server"""
|
||||
d = defer.Deferred()
|
||||
factory = DictLookupFactory("define", (database, word), d)
|
||||
|
||||
from twisted.internet import reactor
|
||||
reactor.connectTCP(host, port, factory)
|
||||
return d
|
||||
|
||||
|
||||
|
||||
def match(host, port, database, strategy, word):
|
||||
"""Match a word using a dict server"""
|
||||
d = defer.Deferred()
|
||||
factory = DictLookupFactory("match", (database, strategy, word), d)
|
||||
|
||||
from twisted.internet import reactor
|
||||
reactor.connectTCP(host, port, factory)
|
||||
return d
|
||||
|
||||
42
venv/lib/python3.9/site-packages/twisted/protocols/finger.py
Normal file
42
venv/lib/python3.9/site-packages/twisted/protocols/finger.py
Normal file
|
|
@ -0,0 +1,42 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""The Finger User Information Protocol (RFC 1288)"""
|
||||
|
||||
from twisted.protocols import basic
|
||||
|
||||
class Finger(basic.LineReceiver):
|
||||
|
||||
def lineReceived(self, line):
|
||||
parts = line.split()
|
||||
if not parts:
|
||||
parts = [b'']
|
||||
if len(parts) == 1:
|
||||
slash_w = 0
|
||||
else:
|
||||
slash_w = 1
|
||||
user = parts[-1]
|
||||
if b'@' in user:
|
||||
hostPlace = user.rfind(b'@')
|
||||
user = user[:hostPlace]
|
||||
host = user[hostPlace+1:]
|
||||
return self.forwardQuery(slash_w, user, host)
|
||||
if user:
|
||||
return self.getUser(slash_w, user)
|
||||
else:
|
||||
return self.getDomain(slash_w)
|
||||
|
||||
def _refuseMessage(self, message):
|
||||
self.transport.write(message + b"\n")
|
||||
self.transport.loseConnection()
|
||||
|
||||
def forwardQuery(self, slash_w, user, host):
|
||||
self._refuseMessage(b'Finger forwarding service denied')
|
||||
|
||||
def getDomain(self, slash_w):
|
||||
self._refuseMessage(b'Finger online list denied')
|
||||
|
||||
def getUser(self, slash_w, user):
|
||||
self.transport.write(b'Login: ' + user + b'\n')
|
||||
self._refuseMessage(b'No such user')
|
||||
3374
venv/lib/python3.9/site-packages/twisted/protocols/ftp.py
Normal file
3374
venv/lib/python3.9/site-packages/twisted/protocols/ftp.py
Normal file
File diff suppressed because it is too large
Load diff
|
|
@ -0,0 +1,13 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
HAProxy PROXY protocol implementations.
|
||||
"""
|
||||
|
||||
from ._wrapper import proxyEndpoint
|
||||
|
||||
__all__ = [
|
||||
'proxyEndpoint',
|
||||
]
|
||||
|
|
@ -0,0 +1,52 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
HAProxy specific exceptions.
|
||||
"""
|
||||
|
||||
import contextlib
|
||||
import sys
|
||||
|
||||
from twisted.python import compat
|
||||
|
||||
|
||||
class InvalidProxyHeader(Exception):
|
||||
"""
|
||||
The provided PROXY protocol header is invalid.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class InvalidNetworkProtocol(InvalidProxyHeader):
|
||||
"""
|
||||
The network protocol was not one of TCP4 TCP6 or UNKNOWN.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class MissingAddressData(InvalidProxyHeader):
|
||||
"""
|
||||
The address data is missing or incomplete.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
@contextlib.contextmanager
|
||||
def convertError(sourceType, targetType):
|
||||
"""
|
||||
Convert an error into a different error type.
|
||||
|
||||
@param sourceType: The type of exception that should be caught and
|
||||
converted.
|
||||
@type sourceType: L{Exception}
|
||||
|
||||
@param targetType: The type of exception to which the original should be
|
||||
converted.
|
||||
@type targetType: L{Exception}
|
||||
"""
|
||||
try:
|
||||
yield None
|
||||
except sourceType:
|
||||
compat.reraise(targetType(), sys.exc_info()[-1])
|
||||
|
|
@ -0,0 +1,36 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
IProxyInfo implementation.
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from ._interfaces import IProxyInfo
|
||||
|
||||
|
||||
@implementer(IProxyInfo)
|
||||
class ProxyInfo(object):
|
||||
"""
|
||||
A data container for parsed PROXY protocol information.
|
||||
|
||||
@ivar header: The raw header bytes extracted from the connection.
|
||||
@type header: bytes
|
||||
@ivar source: The connection source address.
|
||||
@type source: L{twisted.internet.interfaces.IAddress}
|
||||
@ivar destination: The connection destination address.
|
||||
@type destination: L{twisted.internet.interfaces.IAddress}
|
||||
"""
|
||||
|
||||
__slots__ = (
|
||||
'header',
|
||||
'source',
|
||||
'destination',
|
||||
)
|
||||
|
||||
def __init__(self, header, source, destination):
|
||||
self.header = header
|
||||
self.source = source
|
||||
self.destination = destination
|
||||
|
|
@ -0,0 +1,64 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Interfaces used by the PROXY protocol modules.
|
||||
"""
|
||||
|
||||
import zope.interface
|
||||
|
||||
|
||||
class IProxyInfo(zope.interface.Interface):
|
||||
"""
|
||||
Data container for PROXY protocol header data.
|
||||
"""
|
||||
|
||||
header = zope.interface.Attribute(
|
||||
"The raw byestring that represents the PROXY protocol header.",
|
||||
)
|
||||
source = zope.interface.Attribute(
|
||||
"An L{twisted.internet.interfaces.IAddress} representing the "
|
||||
"connection source."
|
||||
)
|
||||
destination = zope.interface.Attribute(
|
||||
"An L{twisted.internet.interfaces.IAddress} representing the "
|
||||
"connection destination."
|
||||
)
|
||||
|
||||
|
||||
|
||||
class IProxyParser(zope.interface.Interface):
|
||||
"""
|
||||
Streaming parser that handles PROXY protocol headers.
|
||||
"""
|
||||
|
||||
def feed(self, data):
|
||||
"""
|
||||
Consume a chunk of data and attempt to parse it.
|
||||
|
||||
@param data: A bytestring.
|
||||
@type data: bytes
|
||||
|
||||
@return: A two-tuple containing, in order, an L{IProxyInfo} and any
|
||||
bytes fed to the parser that followed the end of the header. Both
|
||||
of these values are None until a complete header is parsed.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytes fed to the parser create an
|
||||
invalid PROXY header.
|
||||
"""
|
||||
|
||||
|
||||
def parse(self, line):
|
||||
"""
|
||||
Parse a bytestring as a full PROXY protocol header line.
|
||||
|
||||
@param line: A bytestring that represents a valid HAProxy PROXY
|
||||
protocol header line.
|
||||
@type line: bytes
|
||||
|
||||
@return: An L{IProxyInfo} containing the parsed data.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytestring does not represent a
|
||||
valid PROXY header.
|
||||
"""
|
||||
|
|
@ -0,0 +1,71 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test.test_parser -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Parser for 'haproxy:' string endpoint.
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
from twisted.plugin import IPlugin
|
||||
|
||||
from twisted.internet.endpoints import (
|
||||
quoteStringArgument, serverFromString, IStreamServerEndpointStringParser
|
||||
)
|
||||
from twisted.python.compat import iteritems
|
||||
|
||||
from . import proxyEndpoint
|
||||
|
||||
|
||||
def unparseEndpoint(args, kwargs):
|
||||
"""
|
||||
Un-parse the already-parsed args and kwargs back into endpoint syntax.
|
||||
|
||||
@param args: C{:}-separated arguments
|
||||
@type args: L{tuple} of native L{str}
|
||||
|
||||
@param kwargs: C{:} and then C{=}-separated keyword arguments
|
||||
|
||||
@type arguments: L{tuple} of native L{str}
|
||||
|
||||
@return: a string equivalent to the original format which this was parsed
|
||||
as.
|
||||
@rtype: native L{str}
|
||||
"""
|
||||
|
||||
description = ':'.join(
|
||||
[quoteStringArgument(str(arg)) for arg in args] +
|
||||
sorted(['%s=%s' % (quoteStringArgument(str(key)),
|
||||
quoteStringArgument(str(value)))
|
||||
for key, value in iteritems(kwargs)
|
||||
]))
|
||||
return description
|
||||
|
||||
|
||||
|
||||
@implementer(IPlugin, IStreamServerEndpointStringParser)
|
||||
class HAProxyServerParser(object):
|
||||
"""
|
||||
Stream server endpoint string parser for the HAProxyServerEndpoint type.
|
||||
|
||||
@ivar prefix: See L{IStreamServerEndpointStringParser.prefix}.
|
||||
"""
|
||||
prefix = "haproxy"
|
||||
|
||||
def parseStreamServer(self, reactor, *args, **kwargs):
|
||||
"""
|
||||
Parse a stream server endpoint from a reactor and string-only arguments
|
||||
and keyword arguments.
|
||||
|
||||
@param reactor: The reactor.
|
||||
|
||||
@param args: The parsed string arguments.
|
||||
|
||||
@param kwargs: The parsed keyword arguments.
|
||||
|
||||
@return: a stream server endpoint
|
||||
@rtype: L{IStreamServerEndpoint}
|
||||
"""
|
||||
subdescription = unparseEndpoint(args, kwargs)
|
||||
wrappedEndpoint = serverFromString(reactor, subdescription)
|
||||
return proxyEndpoint(wrappedEndpoint)
|
||||
|
|
@ -0,0 +1,143 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test.test_v1parser -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
IProxyParser implementation for version one of the PROXY protocol.
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import address
|
||||
|
||||
from ._exceptions import (
|
||||
convertError, InvalidProxyHeader, InvalidNetworkProtocol,
|
||||
MissingAddressData
|
||||
)
|
||||
from . import _info
|
||||
from . import _interfaces
|
||||
|
||||
|
||||
|
||||
@implementer(_interfaces.IProxyParser)
|
||||
class V1Parser(object):
|
||||
"""
|
||||
PROXY protocol version one header parser.
|
||||
|
||||
Version one of the PROXY protocol is a human readable format represented
|
||||
by a single, newline delimited binary string that contains all of the
|
||||
relevant source and destination data.
|
||||
"""
|
||||
|
||||
PROXYSTR = b'PROXY'
|
||||
UNKNOWN_PROTO = b'UNKNOWN'
|
||||
TCP4_PROTO = b'TCP4'
|
||||
TCP6_PROTO = b'TCP6'
|
||||
ALLOWED_NET_PROTOS = (
|
||||
TCP4_PROTO,
|
||||
TCP6_PROTO,
|
||||
UNKNOWN_PROTO,
|
||||
)
|
||||
NEWLINE = b'\r\n'
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = b''
|
||||
|
||||
|
||||
def feed(self, data):
|
||||
"""
|
||||
Consume a chunk of data and attempt to parse it.
|
||||
|
||||
@param data: A bytestring.
|
||||
@type data: L{bytes}
|
||||
|
||||
@return: A two-tuple containing, in order, a
|
||||
L{_interfaces.IProxyInfo} and any bytes fed to the
|
||||
parser that followed the end of the header. Both of these values
|
||||
are None until a complete header is parsed.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytes fed to the parser create an
|
||||
invalid PROXY header.
|
||||
"""
|
||||
self.buffer += data
|
||||
if len(self.buffer) > 107 and self.NEWLINE not in self.buffer:
|
||||
raise InvalidProxyHeader()
|
||||
lines = (self.buffer).split(self.NEWLINE, 1)
|
||||
if not len(lines) > 1:
|
||||
return (None, None)
|
||||
self.buffer = b''
|
||||
remaining = lines.pop()
|
||||
header = lines.pop()
|
||||
info = self.parse(header)
|
||||
return (info, remaining)
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse(cls, line):
|
||||
"""
|
||||
Parse a bytestring as a full PROXY protocol header line.
|
||||
|
||||
@param line: A bytestring that represents a valid HAProxy PROXY
|
||||
protocol header line.
|
||||
@type line: bytes
|
||||
|
||||
@return: A L{_interfaces.IProxyInfo} containing the parsed data.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytestring does not represent a
|
||||
valid PROXY header.
|
||||
|
||||
@raises InvalidNetworkProtocol: When no protocol can be parsed or is
|
||||
not one of the allowed values.
|
||||
|
||||
@raises MissingAddressData: When the protocol is TCP* but the header
|
||||
does not contain a complete set of addresses and ports.
|
||||
"""
|
||||
originalLine = line
|
||||
proxyStr = None
|
||||
networkProtocol = None
|
||||
sourceAddr = None
|
||||
sourcePort = None
|
||||
destAddr = None
|
||||
destPort = None
|
||||
|
||||
with convertError(ValueError, InvalidProxyHeader):
|
||||
proxyStr, line = line.split(b' ', 1)
|
||||
|
||||
if proxyStr != cls.PROXYSTR:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
with convertError(ValueError, InvalidNetworkProtocol):
|
||||
networkProtocol, line = line.split(b' ', 1)
|
||||
|
||||
if networkProtocol not in cls.ALLOWED_NET_PROTOS:
|
||||
raise InvalidNetworkProtocol()
|
||||
|
||||
if networkProtocol == cls.UNKNOWN_PROTO:
|
||||
|
||||
return _info.ProxyInfo(originalLine, None, None)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
sourceAddr, line = line.split(b' ', 1)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
destAddr, line = line.split(b' ', 1)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
sourcePort, line = line.split(b' ', 1)
|
||||
|
||||
with convertError(ValueError, MissingAddressData):
|
||||
destPort = line.split(b' ')[0]
|
||||
|
||||
if networkProtocol == cls.TCP4_PROTO:
|
||||
|
||||
return _info.ProxyInfo(
|
||||
originalLine,
|
||||
address.IPv4Address('TCP', sourceAddr, int(sourcePort)),
|
||||
address.IPv4Address('TCP', destAddr, int(destPort)),
|
||||
)
|
||||
|
||||
return _info.ProxyInfo(
|
||||
originalLine,
|
||||
address.IPv6Address('TCP', sourceAddr, int(sourcePort)),
|
||||
address.IPv6Address('TCP', destAddr, int(destPort)),
|
||||
)
|
||||
|
|
@ -0,0 +1,215 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test.test_v2parser -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
IProxyParser implementation for version two of the PROXY protocol.
|
||||
"""
|
||||
|
||||
import binascii
|
||||
import struct
|
||||
|
||||
from constantly import Values, ValueConstant
|
||||
|
||||
from zope.interface import implementer
|
||||
from twisted.internet import address
|
||||
from twisted.python import compat
|
||||
|
||||
from ._exceptions import (
|
||||
convertError, InvalidProxyHeader, InvalidNetworkProtocol,
|
||||
MissingAddressData
|
||||
)
|
||||
from . import _info
|
||||
from . import _interfaces
|
||||
|
||||
class NetFamily(Values):
|
||||
"""
|
||||
Values for the 'family' field.
|
||||
"""
|
||||
UNSPEC = ValueConstant(0x00)
|
||||
INET = ValueConstant(0x10)
|
||||
INET6 = ValueConstant(0x20)
|
||||
UNIX = ValueConstant(0x30)
|
||||
|
||||
|
||||
|
||||
class NetProtocol(Values):
|
||||
"""
|
||||
Values for 'protocol' field.
|
||||
"""
|
||||
UNSPEC = ValueConstant(0)
|
||||
STREAM = ValueConstant(1)
|
||||
DGRAM = ValueConstant(2)
|
||||
|
||||
|
||||
_HIGH = 0b11110000
|
||||
_LOW = 0b00001111
|
||||
_LOCALCOMMAND = 'LOCAL'
|
||||
_PROXYCOMMAND = 'PROXY'
|
||||
|
||||
@implementer(_interfaces.IProxyParser)
|
||||
class V2Parser(object):
|
||||
"""
|
||||
PROXY protocol version two header parser.
|
||||
|
||||
Version two of the PROXY protocol is a binary format.
|
||||
"""
|
||||
|
||||
PREFIX = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
|
||||
VERSIONS = [32]
|
||||
COMMANDS = {0: _LOCALCOMMAND, 1: _PROXYCOMMAND}
|
||||
ADDRESSFORMATS = {
|
||||
# TCP4
|
||||
17: '!4s4s2H',
|
||||
18: '!4s4s2H',
|
||||
# TCP6
|
||||
33: '!16s16s2H',
|
||||
34: '!16s16s2H',
|
||||
# UNIX
|
||||
49: '!108s108s',
|
||||
50: '!108s108s',
|
||||
}
|
||||
|
||||
def __init__(self):
|
||||
self.buffer = b''
|
||||
|
||||
|
||||
def feed(self, data):
|
||||
"""
|
||||
Consume a chunk of data and attempt to parse it.
|
||||
|
||||
@param data: A bytestring.
|
||||
@type data: bytes
|
||||
|
||||
@return: A two-tuple containing, in order, a L{_interfaces.IProxyInfo}
|
||||
and any bytes fed to the parser that followed the end of the
|
||||
header. Both of these values are None until a complete header is
|
||||
parsed.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytes fed to the parser create an
|
||||
invalid PROXY header.
|
||||
"""
|
||||
self.buffer += data
|
||||
if len(self.buffer) < 16:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
size = struct.unpack('!H', self.buffer[14:16])[0] + 16
|
||||
if len(self.buffer) < size:
|
||||
return (None, None)
|
||||
|
||||
header, remaining = self.buffer[:size], self.buffer[size:]
|
||||
self.buffer = b''
|
||||
info = self.parse(header)
|
||||
return (info, remaining)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _bytesToIPv4(bytestring):
|
||||
"""
|
||||
Convert packed 32-bit IPv4 address bytes into a dotted-quad ASCII bytes
|
||||
representation of that address.
|
||||
|
||||
@param bytestring: 4 octets representing an IPv4 address.
|
||||
@type bytestring: L{bytes}
|
||||
|
||||
@return: a dotted-quad notation IPv4 address.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return b'.'.join(
|
||||
('%i' % (ord(b),)).encode('ascii')
|
||||
for b in compat.iterbytes(bytestring)
|
||||
)
|
||||
|
||||
|
||||
@staticmethod
|
||||
def _bytesToIPv6(bytestring):
|
||||
"""
|
||||
Convert packed 128-bit IPv6 address bytes into a colon-separated ASCII
|
||||
bytes representation of that address.
|
||||
|
||||
@param bytestring: 16 octets representing an IPv6 address.
|
||||
@type bytestring: L{bytes}
|
||||
|
||||
@return: a dotted-quad notation IPv6 address.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
hexString = binascii.b2a_hex(bytestring)
|
||||
return b':'.join(
|
||||
('%x' % (int(hexString[b:b+4], 16),)).encode('ascii')
|
||||
for b in range(0, 32, 4)
|
||||
)
|
||||
|
||||
|
||||
@classmethod
|
||||
def parse(cls, line):
|
||||
"""
|
||||
Parse a bytestring as a full PROXY protocol header.
|
||||
|
||||
@param line: A bytestring that represents a valid HAProxy PROXY
|
||||
protocol version 2 header.
|
||||
@type line: bytes
|
||||
|
||||
@return: A L{_interfaces.IProxyInfo} containing the
|
||||
parsed data.
|
||||
|
||||
@raises InvalidProxyHeader: If the bytestring does not represent a
|
||||
valid PROXY header.
|
||||
"""
|
||||
prefix = line[:12]
|
||||
addrInfo = None
|
||||
with convertError(IndexError, InvalidProxyHeader):
|
||||
# Use single value slices to ensure bytestring values are returned
|
||||
# instead of int in PY3.
|
||||
versionCommand = ord(line[12:13])
|
||||
familyProto = ord(line[13:14])
|
||||
|
||||
if prefix != cls.PREFIX:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
version, command = versionCommand & _HIGH, versionCommand & _LOW
|
||||
if version not in cls.VERSIONS or command not in cls.COMMANDS:
|
||||
raise InvalidProxyHeader()
|
||||
|
||||
if cls.COMMANDS[command] == _LOCALCOMMAND:
|
||||
return _info.ProxyInfo(line, None, None)
|
||||
|
||||
family, netproto = familyProto & _HIGH, familyProto & _LOW
|
||||
with convertError(ValueError, InvalidNetworkProtocol):
|
||||
family = NetFamily.lookupByValue(family)
|
||||
netproto = NetProtocol.lookupByValue(netproto)
|
||||
if (
|
||||
family is NetFamily.UNSPEC or
|
||||
netproto is NetProtocol.UNSPEC
|
||||
):
|
||||
return _info.ProxyInfo(line, None, None)
|
||||
|
||||
addressFormat = cls.ADDRESSFORMATS[familyProto]
|
||||
addrInfo = line[16:16+struct.calcsize(addressFormat)]
|
||||
if family is NetFamily.UNIX:
|
||||
with convertError(struct.error, MissingAddressData):
|
||||
source, dest = struct.unpack(addressFormat, addrInfo)
|
||||
return _info.ProxyInfo(
|
||||
line,
|
||||
address.UNIXAddress(source.rstrip(b'\x00')),
|
||||
address.UNIXAddress(dest.rstrip(b'\x00')),
|
||||
)
|
||||
|
||||
addrType = 'TCP'
|
||||
if netproto is NetProtocol.DGRAM:
|
||||
addrType = 'UDP'
|
||||
addrCls = address.IPv4Address
|
||||
addrParser = cls._bytesToIPv4
|
||||
if family is NetFamily.INET6:
|
||||
addrCls = address.IPv6Address
|
||||
addrParser = cls._bytesToIPv6
|
||||
|
||||
with convertError(struct.error, MissingAddressData):
|
||||
info = struct.unpack(addressFormat, addrInfo)
|
||||
source, dest, sPort, dPort = info
|
||||
|
||||
return _info.ProxyInfo(
|
||||
line,
|
||||
addrCls(addrType, addrParser(source), sPort),
|
||||
addrCls(addrType, addrParser(dest), dPort),
|
||||
)
|
||||
|
|
@ -0,0 +1,106 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test.test_wrapper -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Protocol wrapper that provides HAProxy PROXY protocol support.
|
||||
"""
|
||||
|
||||
from twisted.protocols import policies
|
||||
from twisted.internet import interfaces
|
||||
from twisted.internet.endpoints import _WrapperServerEndpoint
|
||||
|
||||
from ._exceptions import InvalidProxyHeader
|
||||
from ._v1parser import V1Parser
|
||||
from ._v2parser import V2Parser
|
||||
|
||||
|
||||
|
||||
class HAProxyProtocolWrapper(policies.ProtocolWrapper, object):
|
||||
"""
|
||||
A Protocol wrapper that provides HAProxy support.
|
||||
|
||||
This protocol reads the PROXY stream header, v1 or v2, parses the provided
|
||||
connection data, and modifies the behavior of getPeer and getHost to return
|
||||
the data provided by the PROXY header.
|
||||
"""
|
||||
|
||||
def __init__(self, factory, wrappedProtocol):
|
||||
policies.ProtocolWrapper.__init__(self, factory, wrappedProtocol)
|
||||
self._proxyInfo = None
|
||||
self._parser = None
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
if self._proxyInfo is not None:
|
||||
return self.wrappedProtocol.dataReceived(data)
|
||||
|
||||
if self._parser is None:
|
||||
if (
|
||||
len(data) >= 16 and
|
||||
data[:12] == V2Parser.PREFIX and
|
||||
ord(data[12:13]) & 0b11110000 == 0x20
|
||||
):
|
||||
self._parser = V2Parser()
|
||||
elif len(data) >= 8 and data[:5] == V1Parser.PROXYSTR:
|
||||
self._parser = V1Parser()
|
||||
else:
|
||||
self.loseConnection()
|
||||
return None
|
||||
|
||||
try:
|
||||
self._proxyInfo, remaining = self._parser.feed(data)
|
||||
if remaining:
|
||||
self.wrappedProtocol.dataReceived(remaining)
|
||||
except InvalidProxyHeader:
|
||||
self.loseConnection()
|
||||
|
||||
|
||||
def getPeer(self):
|
||||
if self._proxyInfo and self._proxyInfo.source:
|
||||
return self._proxyInfo.source
|
||||
return self.transport.getPeer()
|
||||
|
||||
|
||||
def getHost(self):
|
||||
if self._proxyInfo and self._proxyInfo.destination:
|
||||
return self._proxyInfo.destination
|
||||
return self.transport.getHost()
|
||||
|
||||
|
||||
|
||||
class HAProxyWrappingFactory(policies.WrappingFactory):
|
||||
"""
|
||||
A Factory wrapper that adds PROXY protocol support to connections.
|
||||
"""
|
||||
protocol = HAProxyProtocolWrapper
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Annotate the wrapped factory's log prefix with some text indicating
|
||||
the PROXY protocol is in use.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
if interfaces.ILoggingContext.providedBy(self.wrappedFactory):
|
||||
logPrefix = self.wrappedFactory.logPrefix()
|
||||
else:
|
||||
logPrefix = self.wrappedFactory.__class__.__name__
|
||||
return "%s (PROXY)" % (logPrefix,)
|
||||
|
||||
|
||||
|
||||
def proxyEndpoint(wrappedEndpoint):
|
||||
"""
|
||||
Wrap an endpoint with PROXY protocol support, so that the transport's
|
||||
C{getHost} and C{getPeer} methods reflect the attributes of the proxied
|
||||
connection rather than the underlying connection.
|
||||
|
||||
@param wrappedEndpoint: The underlying listening endpoint.
|
||||
@type wrappedEndpoint: L{IStreamServerEndpoint}
|
||||
|
||||
@return: a new listening endpoint that speaks the PROXY protocol.
|
||||
@rtype: L{IStreamServerEndpoint}
|
||||
"""
|
||||
return _WrapperServerEndpoint(wrappedEndpoint, HAProxyWrappingFactory)
|
||||
|
|
@ -0,0 +1,7 @@
|
|||
# -*- test-case-name: twisted.protocols.haproxy.test -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Unit tests for L{twisted.protocols.haproxy}.
|
||||
"""
|
||||
|
|
@ -0,0 +1,132 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Tests for L{twisted.protocols.haproxy._parser}.
|
||||
"""
|
||||
|
||||
from twisted.trial.unittest import SynchronousTestCase as TestCase
|
||||
from twisted.test.proto_helpers import MemoryReactor
|
||||
from twisted.internet.endpoints import (
|
||||
_WrapperServerEndpoint, TCP4ServerEndpoint, TCP6ServerEndpoint,
|
||||
UNIXServerEndpoint, serverFromString, _parse as parseEndpoint
|
||||
)
|
||||
|
||||
from .._wrapper import HAProxyWrappingFactory
|
||||
from .._parser import unparseEndpoint
|
||||
|
||||
|
||||
|
||||
class UnparseEndpointTests(TestCase):
|
||||
"""
|
||||
Tests to ensure that un-parsing an endpoint string round trips through
|
||||
escaping properly.
|
||||
"""
|
||||
|
||||
def check(self, input):
|
||||
"""
|
||||
Check that the input unparses into the output, raising an assertion
|
||||
error if it doesn't.
|
||||
|
||||
@param input: an input in endpoint-string-description format. (To
|
||||
ensure determinism, keyword arguments should be in alphabetical
|
||||
order.)
|
||||
@type input: native L{str}
|
||||
"""
|
||||
self.assertEqual(unparseEndpoint(*parseEndpoint(input)), input)
|
||||
|
||||
|
||||
def test_basicUnparse(self):
|
||||
"""
|
||||
An individual word.
|
||||
"""
|
||||
self.check("word")
|
||||
|
||||
|
||||
def test_multipleArguments(self):
|
||||
"""
|
||||
Multiple arguments.
|
||||
"""
|
||||
self.check("one:two")
|
||||
|
||||
|
||||
def test_keywords(self):
|
||||
"""
|
||||
Keyword arguments.
|
||||
"""
|
||||
self.check("aleph=one:bet=two")
|
||||
|
||||
|
||||
def test_colonInArgument(self):
|
||||
"""
|
||||
Escaped ":".
|
||||
"""
|
||||
self.check("hello\\:colon\\:world")
|
||||
|
||||
|
||||
def test_colonInKeywordValue(self):
|
||||
"""
|
||||
Escaped ":" in keyword value.
|
||||
"""
|
||||
self.check("hello=\\:")
|
||||
|
||||
|
||||
def test_colonInKeywordName(self):
|
||||
"""
|
||||
Escaped ":" in keyword name.
|
||||
"""
|
||||
self.check("\\:=hello")
|
||||
|
||||
|
||||
|
||||
class HAProxyServerParserTests(TestCase):
|
||||
"""
|
||||
Tests that the parser generates the correct endpoints.
|
||||
"""
|
||||
|
||||
def onePrefix(self, description, expectedClass):
|
||||
"""
|
||||
Test the C{haproxy} enpdoint prefix against one sub-endpoint type.
|
||||
|
||||
@param description: A string endpoint description beginning with
|
||||
C{haproxy}.
|
||||
@type description: native L{str}
|
||||
|
||||
@param expectedClass: the expected sub-endpoint class given the
|
||||
description.
|
||||
@type expectedClass: L{type}
|
||||
|
||||
@return: the parsed endpoint
|
||||
@rtype: L{IStreamServerEndpoint}
|
||||
|
||||
@raise twisted.trial.unittest.Failtest: if the parsed endpoint doesn't
|
||||
match expectations.
|
||||
"""
|
||||
reactor = MemoryReactor()
|
||||
endpoint = serverFromString(reactor, description)
|
||||
self.assertIsInstance(endpoint, _WrapperServerEndpoint)
|
||||
self.assertIsInstance(endpoint._wrappedEndpoint, expectedClass)
|
||||
self.assertIs(endpoint._wrapperFactory, HAProxyWrappingFactory)
|
||||
return endpoint
|
||||
|
||||
|
||||
def test_tcp4(self):
|
||||
"""
|
||||
Test if the parser generates a wrapped TCP4 endpoint.
|
||||
"""
|
||||
self.onePrefix('haproxy:tcp:8080', TCP4ServerEndpoint)
|
||||
|
||||
|
||||
def test_tcp6(self):
|
||||
"""
|
||||
Test if the parser generates a wrapped TCP6 endpoint.
|
||||
"""
|
||||
self.onePrefix('haproxy:tcp6:8080', TCP6ServerEndpoint)
|
||||
|
||||
|
||||
def test_unix(self):
|
||||
"""
|
||||
Test if the parser generates a wrapped UNIX endpoint.
|
||||
"""
|
||||
self.onePrefix('haproxy:unix:address=/tmp/socket', UNIXServerEndpoint)
|
||||
|
||||
|
|
@ -0,0 +1,153 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.protocols.haproxy.V1Parser}.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import address
|
||||
|
||||
from .._exceptions import (
|
||||
InvalidProxyHeader, InvalidNetworkProtocol, MissingAddressData
|
||||
)
|
||||
from .. import _v1parser
|
||||
|
||||
|
||||
class V1ParserTests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.V1Parser} behaviour.
|
||||
"""
|
||||
|
||||
def test_missingPROXYHeaderValue(self):
|
||||
"""
|
||||
Test that an exception is raised when the PROXY header is missing.
|
||||
"""
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v1parser.V1Parser.parse,
|
||||
b'NOTPROXY ',
|
||||
)
|
||||
|
||||
|
||||
def test_invalidNetworkProtocol(self):
|
||||
"""
|
||||
Test that an exception is raised when the proto is not TCP or UNKNOWN.
|
||||
"""
|
||||
self.assertRaises(
|
||||
InvalidNetworkProtocol,
|
||||
_v1parser.V1Parser.parse,
|
||||
b'PROXY WUTPROTO ',
|
||||
)
|
||||
|
||||
|
||||
def test_missingSourceData(self):
|
||||
"""
|
||||
Test that an exception is raised when the proto has no source data.
|
||||
"""
|
||||
self.assertRaises(
|
||||
MissingAddressData,
|
||||
_v1parser.V1Parser.parse,
|
||||
b'PROXY TCP4 ',
|
||||
)
|
||||
|
||||
|
||||
def test_missingDestData(self):
|
||||
"""
|
||||
Test that an exception is raised when the proto has no destination.
|
||||
"""
|
||||
self.assertRaises(
|
||||
MissingAddressData,
|
||||
_v1parser.V1Parser.parse,
|
||||
b'PROXY TCP4 127.0.0.1 8080 8888',
|
||||
)
|
||||
|
||||
|
||||
def test_fullParsingSuccess(self):
|
||||
"""
|
||||
Test that parsing is successful for a PROXY header.
|
||||
"""
|
||||
info = _v1parser.V1Parser.parse(
|
||||
b'PROXY TCP4 127.0.0.1 127.0.0.1 8080 8888',
|
||||
)
|
||||
self.assertIsInstance(info.source, address.IPv4Address)
|
||||
self.assertEqual(info.source.host, b'127.0.0.1')
|
||||
self.assertEqual(info.source.port, 8080)
|
||||
self.assertEqual(info.destination.host, b'127.0.0.1')
|
||||
self.assertEqual(info.destination.port, 8888)
|
||||
|
||||
|
||||
def test_fullParsingSuccess_IPv6(self):
|
||||
"""
|
||||
Test that parsing is successful for an IPv6 PROXY header.
|
||||
"""
|
||||
info = _v1parser.V1Parser.parse(
|
||||
b'PROXY TCP6 ::1 ::1 8080 8888',
|
||||
)
|
||||
self.assertIsInstance(info.source, address.IPv6Address)
|
||||
self.assertEqual(info.source.host, b'::1')
|
||||
self.assertEqual(info.source.port, 8080)
|
||||
self.assertEqual(info.destination.host, b'::1')
|
||||
self.assertEqual(info.destination.port, 8888)
|
||||
|
||||
|
||||
def test_fullParsingSuccess_UNKNOWN(self):
|
||||
"""
|
||||
Test that parsing is successful for a UNKNOWN PROXY header.
|
||||
"""
|
||||
info = _v1parser.V1Parser.parse(
|
||||
b'PROXY UNKNOWN anything could go here',
|
||||
)
|
||||
self.assertIsNone(info.source)
|
||||
self.assertIsNone(info.destination)
|
||||
|
||||
|
||||
def test_feedParsing(self):
|
||||
"""
|
||||
Test that parsing happens when fed a complete line.
|
||||
"""
|
||||
parser = _v1parser.V1Parser()
|
||||
info, remaining = parser.feed(b'PROXY TCP4 127.0.0.1 127.0.0.1 ')
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
info, remaining = parser.feed(b'8080 8888')
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
info, remaining = parser.feed(b'\r\n')
|
||||
self.assertFalse(remaining)
|
||||
self.assertIsInstance(info.source, address.IPv4Address)
|
||||
self.assertEqual(info.source.host, b'127.0.0.1')
|
||||
self.assertEqual(info.source.port, 8080)
|
||||
self.assertEqual(info.destination.host, b'127.0.0.1')
|
||||
self.assertEqual(info.destination.port, 8888)
|
||||
|
||||
|
||||
def test_feedParsingTooLong(self):
|
||||
"""
|
||||
Test that parsing fails if no newline is found in 108 bytes.
|
||||
"""
|
||||
parser = _v1parser.V1Parser()
|
||||
info, remaining = parser.feed(b'PROXY TCP4 127.0.0.1 127.0.0.1 ')
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
info, remaining = parser.feed(b'8080 8888')
|
||||
self.assertFalse(info)
|
||||
self.assertFalse(remaining)
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
parser.feed,
|
||||
b' ' * 100,
|
||||
)
|
||||
|
||||
|
||||
def test_feedParsingOverflow(self):
|
||||
"""
|
||||
Test that parsing leaves overflow bytes in the buffer.
|
||||
"""
|
||||
parser = _v1parser.V1Parser()
|
||||
info, remaining = parser.feed(
|
||||
b'PROXY TCP4 127.0.0.1 127.0.0.1 8080 8888\r\nHTTP/1.1 GET /\r\n',
|
||||
)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(remaining, b'HTTP/1.1 GET /\r\n')
|
||||
self.assertFalse(parser.buffer)
|
||||
|
|
@ -0,0 +1,380 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.protocols.haproxy.V2Parser}.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import address
|
||||
|
||||
from .._exceptions import InvalidProxyHeader
|
||||
from .. import _v2parser
|
||||
|
||||
V2_SIGNATURE = b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
|
||||
|
||||
def _makeHeaderIPv6(sig=V2_SIGNATURE, verCom=b'\x21', famProto=b'\x21',
|
||||
addrLength=b'\x00\x24',
|
||||
addrs=((b'\x00' * 15) + b'\x01') * 2,
|
||||
ports=b'\x1F\x90\x22\xB8'):
|
||||
"""
|
||||
Construct a version 2 IPv6 header with custom bytes.
|
||||
|
||||
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
|
||||
@type sig: L{bytes}
|
||||
|
||||
@param verCom: Protocol version and command. Defaults to V2 PROXY.
|
||||
@type verCom: L{bytes}
|
||||
|
||||
@param famProto: Address family and protocol. Defaults to AF_INET6/STREAM.
|
||||
@type famProto: L{bytes}
|
||||
|
||||
@param addrLength: Network-endian byte length of payload. Defaults to
|
||||
description of default addrs/ports.
|
||||
@type addrLength: L{bytes}
|
||||
|
||||
@param addrs: Address payload. Defaults to C{::1} for source and
|
||||
destination.
|
||||
@type addrs: L{bytes}
|
||||
|
||||
@param ports: Source and destination ports. Defaults to 8080 for source
|
||||
8888 for destination.
|
||||
@type ports: L{bytes}
|
||||
|
||||
@return: A packet with header, addresses, and ports.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return sig + verCom + famProto + addrLength + addrs + ports
|
||||
|
||||
|
||||
|
||||
def _makeHeaderIPv4(sig=V2_SIGNATURE, verCom=b'\x21', famProto=b'\x11',
|
||||
addrLength=b'\x00\x0C',
|
||||
addrs=b'\x7F\x00\x00\x01\x7F\x00\x00\x01',
|
||||
ports=b'\x1F\x90\x22\xB8'):
|
||||
"""
|
||||
Construct a version 2 IPv4 header with custom bytes.
|
||||
|
||||
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
|
||||
@type sig: L{bytes}
|
||||
|
||||
@param verCom: Protocol version and command. Defaults to V2 PROXY.
|
||||
@type verCom: L{bytes}
|
||||
|
||||
@param famProto: Address family and protocol. Defaults to AF_INET/STREAM.
|
||||
@type famProto: L{bytes}
|
||||
|
||||
@param addrLength: Network-endian byte length of payload. Defaults to
|
||||
description of default addrs/ports.
|
||||
@type addrLength: L{bytes}
|
||||
|
||||
@param addrs: Address payload. Defaults to 127.0.0.1 for source and
|
||||
destination.
|
||||
@type addrs: L{bytes}
|
||||
|
||||
@param ports: Source and destination ports. Defaults to 8080 for source
|
||||
8888 for destination.
|
||||
@type ports: L{bytes}
|
||||
|
||||
@return: A packet with header, addresses, and ports.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return sig + verCom + famProto + addrLength + addrs + ports
|
||||
|
||||
|
||||
|
||||
def _makeHeaderUnix(sig=V2_SIGNATURE, verCom=b'\x21', famProto=b'\x31',
|
||||
addrLength=b'\x00\xD8',
|
||||
addrs=(b'\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F'
|
||||
b'\x6D\x79\x73\x6F\x63\x6B\x65\x74\x73\x2F\x73\x6F'
|
||||
b'\x63\x6B' + (b'\x00' * 82)) * 2):
|
||||
"""
|
||||
Construct a version 2 IPv4 header with custom bytes.
|
||||
|
||||
@param sig: The protocol signature; defaults to valid L{V2_SIGNATURE}.
|
||||
@type sig: L{bytes}
|
||||
|
||||
@param verCom: Protocol version and command. Defaults to V2 PROXY.
|
||||
@type verCom: L{bytes}
|
||||
|
||||
@param famProto: Address family and protocol. Defaults to AF_UNIX/STREAM.
|
||||
@type famProto: L{bytes}
|
||||
|
||||
@param addrLength: Network-endian byte length of payload. Defaults to 108
|
||||
bytes for 2 null terminated paths.
|
||||
@type addrLength: L{bytes}
|
||||
|
||||
@param addrs: Address payload. Defaults to C{/home/tests/mysockets/sock}
|
||||
for source and destination paths.
|
||||
@type addrs: L{bytes}
|
||||
|
||||
@return: A packet with header, addresses, and8 ports.
|
||||
@rtype: L{bytes}
|
||||
"""
|
||||
return sig + verCom + famProto + addrLength + addrs
|
||||
|
||||
|
||||
|
||||
class V2ParserTests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.V2Parser} behaviour.
|
||||
"""
|
||||
|
||||
def test_happyPathIPv4(self):
|
||||
"""
|
||||
Test if a well formed IPv4 header is parsed without error.
|
||||
"""
|
||||
header = _makeHeaderIPv4()
|
||||
self.assertTrue(_v2parser.V2Parser.parse(header))
|
||||
|
||||
|
||||
def test_happyPathIPv6(self):
|
||||
"""
|
||||
Test if a well formed IPv6 header is parsed without error.
|
||||
"""
|
||||
header = _makeHeaderIPv6()
|
||||
self.assertTrue(_v2parser.V2Parser.parse(header))
|
||||
|
||||
|
||||
def test_happyPathUnix(self):
|
||||
"""
|
||||
Test if a well formed UNIX header is parsed without error.
|
||||
"""
|
||||
header = _makeHeaderUnix()
|
||||
self.assertTrue(_v2parser.V2Parser.parse(header))
|
||||
|
||||
|
||||
def test_invalidSignature(self):
|
||||
"""
|
||||
Test if an invalid signature block raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(sig=b'\x00'*12)
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
|
||||
def test_invalidVersion(self):
|
||||
"""
|
||||
Test if an invalid version raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b'\x11')
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
|
||||
def test_invalidCommand(self):
|
||||
"""
|
||||
Test if an invalid command raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b'\x23')
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
|
||||
def test_invalidFamily(self):
|
||||
"""
|
||||
Test if an invalid family raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b'\x40')
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
|
||||
def test_invalidProto(self):
|
||||
"""
|
||||
Test if an invalid protocol raises InvalidProxyError.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b'\x24')
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
_v2parser.V2Parser.parse,
|
||||
header,
|
||||
)
|
||||
|
||||
|
||||
def test_localCommandIpv4(self):
|
||||
"""
|
||||
Test that local does not return endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b'\x20')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_localCommandIpv6(self):
|
||||
"""
|
||||
Test that local does not return endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(verCom=b'\x20')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_localCommandUnix(self):
|
||||
"""
|
||||
Test that local does not return endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(verCom=b'\x20')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_proxyCommandIpv4(self):
|
||||
"""
|
||||
Test that proxy returns endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(verCom=b'\x21')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertTrue(info.source)
|
||||
self.assertIsInstance(info.source, address.IPv4Address)
|
||||
self.assertTrue(info.destination)
|
||||
self.assertIsInstance(info.destination, address.IPv4Address)
|
||||
|
||||
|
||||
def test_proxyCommandIpv6(self):
|
||||
"""
|
||||
Test that proxy returns endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(verCom=b'\x21')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertTrue(info.source)
|
||||
self.assertIsInstance(info.source, address.IPv6Address)
|
||||
self.assertTrue(info.destination)
|
||||
self.assertIsInstance(info.destination, address.IPv6Address)
|
||||
|
||||
|
||||
def test_proxyCommandUnix(self):
|
||||
"""
|
||||
Test that proxy returns endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(verCom=b'\x21')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertTrue(info.source)
|
||||
self.assertIsInstance(info.source, address.UNIXAddress)
|
||||
self.assertTrue(info.destination)
|
||||
self.assertIsInstance(info.destination, address.UNIXAddress)
|
||||
|
||||
|
||||
def test_unspecFamilyIpv4(self):
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b'\x01')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_unspecFamilyIpv6(self):
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(famProto=b'\x01')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_unspecFamilyUnix(self):
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(famProto=b'\x01')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_unspecProtoIpv4(self):
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv4 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv4(famProto=b'\x10')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_unspecProtoIpv6(self):
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for IPv6 connections.
|
||||
"""
|
||||
header = _makeHeaderIPv6(famProto=b'\x20')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_unspecProtoUnix(self):
|
||||
"""
|
||||
Test that UNSPEC does not return endpoint data for UNIX connections.
|
||||
"""
|
||||
header = _makeHeaderUnix(famProto=b'\x30')
|
||||
info = _v2parser.V2Parser.parse(header)
|
||||
self.assertFalse(info.source)
|
||||
self.assertFalse(info.destination)
|
||||
|
||||
|
||||
def test_overflowIpv4(self):
|
||||
"""
|
||||
Test that overflow bits are preserved during feed parsing for IPv4.
|
||||
"""
|
||||
testValue = b'TEST DATA\r\n\r\nTEST DATA'
|
||||
header = _makeHeaderIPv4() + testValue
|
||||
parser = _v2parser.V2Parser()
|
||||
info, overflow = parser.feed(header)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(overflow, testValue)
|
||||
|
||||
|
||||
def test_overflowIpv6(self):
|
||||
"""
|
||||
Test that overflow bits are preserved during feed parsing for IPv6.
|
||||
"""
|
||||
testValue = b'TEST DATA\r\n\r\nTEST DATA'
|
||||
header = _makeHeaderIPv6() + testValue
|
||||
parser = _v2parser.V2Parser()
|
||||
info, overflow = parser.feed(header)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(overflow, testValue)
|
||||
|
||||
|
||||
def test_overflowUnix(self):
|
||||
"""
|
||||
Test that overflow bits are preserved during feed parsing for Unix.
|
||||
"""
|
||||
testValue = b'TEST DATA\r\n\r\nTEST DATA'
|
||||
header = _makeHeaderUnix() + testValue
|
||||
parser = _v2parser.V2Parser()
|
||||
info, overflow = parser.feed(header)
|
||||
self.assertTrue(info)
|
||||
self.assertEqual(overflow, testValue)
|
||||
|
||||
|
||||
def test_segmentTooSmall(self):
|
||||
"""
|
||||
Test that an initial payload of less than 16 bytes fails.
|
||||
"""
|
||||
testValue = b'NEEDMOREDATA'
|
||||
parser = _v2parser.V2Parser()
|
||||
self.assertRaises(
|
||||
InvalidProxyHeader,
|
||||
parser.feed,
|
||||
testValue,
|
||||
)
|
||||
|
||||
|
|
@ -0,0 +1,367 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Test cases for L{twisted.protocols.haproxy.HAProxyProtocol}.
|
||||
"""
|
||||
|
||||
from twisted.trial import unittest
|
||||
from twisted.internet import address
|
||||
from twisted.internet.protocol import Protocol, Factory
|
||||
from twisted.test.proto_helpers import StringTransportWithDisconnection
|
||||
|
||||
from .._wrapper import HAProxyWrappingFactory
|
||||
|
||||
|
||||
|
||||
class StaticProtocol(Protocol):
|
||||
"""
|
||||
Protocol stand-in that maintains test state.
|
||||
"""
|
||||
|
||||
def __init__(self):
|
||||
self.source = None
|
||||
self.destination = None
|
||||
self.data = b''
|
||||
self.disconnected = False
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.source = self.transport.getPeer()
|
||||
self.destination = self.transport.getHost()
|
||||
self.data += data
|
||||
|
||||
|
||||
|
||||
class HAProxyWrappingFactoryV1Tests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v1 PROXY
|
||||
headers.
|
||||
"""
|
||||
|
||||
def test_invalidHeaderDisconnects(self):
|
||||
"""
|
||||
Test if invalid headers result in connectionLost events.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address('TCP', b'127.1.1.1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
transport.protocol = proto
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'NOTPROXY anything can go here\r\n')
|
||||
self.assertFalse(transport.connected)
|
||||
|
||||
|
||||
def test_invalidPartialHeaderDisconnects(self):
|
||||
"""
|
||||
Test if invalid headers result in connectionLost events.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address('TCP', b'127.1.1.1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
transport.protocol = proto
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'PROXY TCP4 1.1.1.1\r\n')
|
||||
proto.dataReceived(b'2.2.2.2 8080\r\n')
|
||||
self.assertFalse(transport.connected)
|
||||
|
||||
|
||||
def test_validIPv4HeaderResolves_getPeerHost(self):
|
||||
"""
|
||||
Test if IPv4 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address('TCP', b'127.0.0.1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'PROXY TCP4 1.1.1.1 2.2.2.2 8080 8888\r\n')
|
||||
self.assertEqual(proto.getPeer().host, b'1.1.1.1')
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
b'1.1.1.1',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, b'2.2.2.2')
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
b'2.2.2.2',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
|
||||
def test_validIPv6HeaderResolves_getPeerHost(self):
|
||||
"""
|
||||
Test if IPv6 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'PROXY TCP6 ::1 ::2 8080 8888\r\n')
|
||||
self.assertEqual(proto.getPeer().host, b'::1')
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
b'::1',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, b'::2')
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
b'::2',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocol(self):
|
||||
"""
|
||||
Test if non-header bytes are passed to the wrapped protocol.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'PROXY TCP6 ::1 ::2 8080 8888\r\nHTTP/1.1 / GET')
|
||||
self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')
|
||||
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocolChunks(self):
|
||||
"""
|
||||
Test if header streaming passes extra data appropriately.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'PROXY TCP6 ::1 ::2 ')
|
||||
proto.dataReceived(b'8080 8888\r\nHTTP/1.1 / GET')
|
||||
self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')
|
||||
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocolAfter(self):
|
||||
"""
|
||||
Test if wrapper writes all data to wrapped protocol after parsing.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'PROXY TCP6 ::1 ::2 ')
|
||||
proto.dataReceived(b'8080 8888\r\nHTTP/1.1 / GET')
|
||||
proto.dataReceived(b'\r\n\r\n')
|
||||
self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET\r\n\r\n')
|
||||
|
||||
|
||||
|
||||
class HAProxyWrappingFactoryV2Tests(unittest.TestCase):
|
||||
"""
|
||||
Test L{twisted.protocols.haproxy.HAProxyWrappingFactory} with v2 PROXY
|
||||
headers.
|
||||
"""
|
||||
|
||||
IPV4HEADER = (
|
||||
# V2 Signature
|
||||
b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
|
||||
# V2 PROXY command
|
||||
b'\x21'
|
||||
# AF_INET/STREAM
|
||||
b'\x11'
|
||||
# 12 bytes for 2 IPv4 addresses and two ports
|
||||
b'\x00\x0C'
|
||||
# 127.0.0.1 for source and destination
|
||||
b'\x7F\x00\x00\x01\x7F\x00\x00\x01'
|
||||
# 8080 for source 8888 for destination
|
||||
b'\x1F\x90\x22\xB8'
|
||||
)
|
||||
IPV6HEADER = (
|
||||
# V2 Signature
|
||||
b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
|
||||
# V2 PROXY command
|
||||
b'\x21'
|
||||
# AF_INET6/STREAM
|
||||
b'\x21'
|
||||
# 16 bytes for 2 IPv6 addresses and two ports
|
||||
b'\x00\x24'
|
||||
# ::1 for source and destination
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01'
|
||||
b'\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x00\x01'
|
||||
# 8080 for source 8888 for destination
|
||||
b'\x1F\x90\x22\xB8'
|
||||
)
|
||||
|
||||
_SOCK_PATH = (
|
||||
b'\x2F\x68\x6F\x6D\x65\x2F\x74\x65\x73\x74\x73\x2F\x6D\x79\x73\x6F'
|
||||
b'\x63\x6B\x65\x74\x73\x2F\x73\x6F\x63\x6B' + (b'\x00' * 82)
|
||||
)
|
||||
UNIXHEADER = (
|
||||
# V2 Signature
|
||||
b'\x0D\x0A\x0D\x0A\x00\x0D\x0A\x51\x55\x49\x54\x0A'
|
||||
# V2 PROXY command
|
||||
b'\x21'
|
||||
# AF_UNIX/STREAM
|
||||
b'\x31'
|
||||
# 108 bytes for 2 null terminated paths
|
||||
b'\x00\xD8'
|
||||
# /home/tests/mysockets/sock for source and destination paths
|
||||
) + _SOCK_PATH + _SOCK_PATH
|
||||
|
||||
def test_invalidHeaderDisconnects(self):
|
||||
"""
|
||||
Test if invalid headers result in connectionLost events.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
transport.protocol = proto
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(b'\x00' + self.IPV4HEADER[1:])
|
||||
self.assertFalse(transport.connected)
|
||||
|
||||
|
||||
def test_validIPv4HeaderResolves_getPeerHost(self):
|
||||
"""
|
||||
Test if IPv4 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address('TCP', b'127.0.0.1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV4HEADER)
|
||||
self.assertEqual(proto.getPeer().host, b'127.0.0.1')
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
b'127.0.0.1',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, b'127.0.0.1')
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
b'127.0.0.1',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
|
||||
def test_validIPv6HeaderResolves_getPeerHost(self):
|
||||
"""
|
||||
Test if IPv6 headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv4Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV6HEADER)
|
||||
self.assertEqual(proto.getPeer().host, b'0:0:0:0:0:0:0:1')
|
||||
self.assertEqual(proto.getPeer().port, 8080)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().host,
|
||||
b'0:0:0:0:0:0:0:1',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().port,
|
||||
8080,
|
||||
)
|
||||
self.assertEqual(proto.getHost().host, b'0:0:0:0:0:0:0:1')
|
||||
self.assertEqual(proto.getHost().port, 8888)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().host,
|
||||
b'0:0:0:0:0:0:0:1',
|
||||
)
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().port,
|
||||
8888,
|
||||
)
|
||||
|
||||
|
||||
def test_validUNIXHeaderResolves_getPeerHost(self):
|
||||
"""
|
||||
Test if UNIX headers result in the correct host and peer data.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.UNIXAddress(b'/home/test/sockets/server.sock'),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.UNIXHEADER)
|
||||
self.assertEqual(proto.getPeer().name, b'/home/tests/mysockets/sock')
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getPeer().name,
|
||||
b'/home/tests/mysockets/sock',
|
||||
)
|
||||
self.assertEqual(proto.getHost().name, b'/home/tests/mysockets/sock')
|
||||
self.assertEqual(
|
||||
proto.wrappedProtocol.transport.getHost().name,
|
||||
b'/home/tests/mysockets/sock',
|
||||
)
|
||||
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocol(self):
|
||||
"""
|
||||
Test if non-header bytes are passed to the wrapped protocol.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV6HEADER + b'HTTP/1.1 / GET')
|
||||
self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')
|
||||
|
||||
|
||||
def test_overflowBytesSentToWrappedProtocolChunks(self):
|
||||
"""
|
||||
Test if header streaming passes extra data appropriately.
|
||||
"""
|
||||
factory = HAProxyWrappingFactory(Factory.forProtocol(StaticProtocol))
|
||||
proto = factory.buildProtocol(
|
||||
address.IPv6Address('TCP', b'::1', 8080),
|
||||
)
|
||||
transport = StringTransportWithDisconnection()
|
||||
proto.makeConnection(transport)
|
||||
proto.dataReceived(self.IPV6HEADER[:18])
|
||||
proto.dataReceived(self.IPV6HEADER[18:] + b'HTTP/1.1 / GET')
|
||||
self.assertEqual(proto.wrappedProtocol.data, b'HTTP/1.1 / GET')
|
||||
295
venv/lib/python3.9/site-packages/twisted/protocols/htb.py
Normal file
295
venv/lib/python3.9/site-packages/twisted/protocols/htb.py
Normal file
|
|
@ -0,0 +1,295 @@
|
|||
# -*- test-case-name: twisted.test.test_htb -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
"""
|
||||
Hierarchical Token Bucket traffic shaping.
|
||||
|
||||
Patterned after U{Martin Devera's Hierarchical Token Bucket traffic
|
||||
shaper for the Linux kernel<http://luxik.cdi.cz/~devik/qos/htb/>}.
|
||||
|
||||
@seealso: U{HTB Linux queuing discipline manual - user guide
|
||||
<http://luxik.cdi.cz/~devik/qos/htb/manual/userg.htm>}
|
||||
@seealso: U{Token Bucket Filter in Linux Advanced Routing & Traffic Control
|
||||
HOWTO<http://lartc.org/howto/lartc.qdisc.classless.html#AEN682>}
|
||||
"""
|
||||
|
||||
|
||||
# TODO: Investigate whether we should be using os.times()[-1] instead of
|
||||
# time.time. time.time, it has been pointed out, can go backwards. Is
|
||||
# the same true of os.times?
|
||||
from time import time
|
||||
from zope.interface import implementer, Interface
|
||||
|
||||
from twisted.protocols import pcp
|
||||
|
||||
|
||||
class Bucket:
|
||||
"""
|
||||
Implementation of a Token bucket.
|
||||
|
||||
A bucket can hold a certain number of tokens and it drains over time.
|
||||
|
||||
@cvar maxburst: The maximum number of tokens that the bucket can
|
||||
hold at any given time. If this is L{None}, the bucket has
|
||||
an infinite size.
|
||||
@type maxburst: C{int}
|
||||
@cvar rate: The rate at which the bucket drains, in number
|
||||
of tokens per second. If the rate is L{None}, the bucket
|
||||
drains instantaneously.
|
||||
@type rate: C{int}
|
||||
"""
|
||||
|
||||
maxburst = None
|
||||
rate = None
|
||||
|
||||
_refcount = 0
|
||||
|
||||
def __init__(self, parentBucket=None):
|
||||
"""
|
||||
Create a L{Bucket} that may have a parent L{Bucket}.
|
||||
|
||||
@param parentBucket: If a parent Bucket is specified,
|
||||
all L{add} and L{drip} operations on this L{Bucket}
|
||||
will be applied on the parent L{Bucket} as well.
|
||||
@type parentBucket: L{Bucket}
|
||||
"""
|
||||
self.content = 0
|
||||
self.parentBucket = parentBucket
|
||||
self.lastDrip = time()
|
||||
|
||||
|
||||
def add(self, amount):
|
||||
"""
|
||||
Adds tokens to the L{Bucket} and its C{parentBucket}.
|
||||
|
||||
This will add as many of the C{amount} tokens as will fit into both
|
||||
this L{Bucket} and its C{parentBucket}.
|
||||
|
||||
@param amount: The number of tokens to try to add.
|
||||
@type amount: C{int}
|
||||
|
||||
@returns: The number of tokens that actually fit.
|
||||
@returntype: C{int}
|
||||
"""
|
||||
self.drip()
|
||||
if self.maxburst is None:
|
||||
allowable = amount
|
||||
else:
|
||||
allowable = min(amount, self.maxburst - self.content)
|
||||
|
||||
if self.parentBucket is not None:
|
||||
allowable = self.parentBucket.add(allowable)
|
||||
self.content += allowable
|
||||
return allowable
|
||||
|
||||
|
||||
def drip(self):
|
||||
"""
|
||||
Let some of the bucket drain.
|
||||
|
||||
The L{Bucket} drains at the rate specified by the class
|
||||
variable C{rate}.
|
||||
|
||||
@returns: C{True} if the bucket is empty after this drip.
|
||||
@returntype: C{bool}
|
||||
"""
|
||||
if self.parentBucket is not None:
|
||||
self.parentBucket.drip()
|
||||
|
||||
if self.rate is None:
|
||||
self.content = 0
|
||||
else:
|
||||
now = time()
|
||||
deltaTime = now - self.lastDrip
|
||||
deltaTokens = deltaTime * self.rate
|
||||
self.content = max(0, self.content - deltaTokens)
|
||||
self.lastDrip = now
|
||||
return self.content == 0
|
||||
|
||||
|
||||
class IBucketFilter(Interface):
|
||||
def getBucketFor(*somethings, **some_kw):
|
||||
"""
|
||||
Return a L{Bucket} corresponding to the provided parameters.
|
||||
|
||||
@returntype: L{Bucket}
|
||||
"""
|
||||
|
||||
@implementer(IBucketFilter)
|
||||
class HierarchicalBucketFilter:
|
||||
"""
|
||||
Filter things into buckets that can be nested.
|
||||
|
||||
@cvar bucketFactory: Class of buckets to make.
|
||||
@type bucketFactory: L{Bucket}
|
||||
@cvar sweepInterval: Seconds between sweeping out the bucket cache.
|
||||
@type sweepInterval: C{int}
|
||||
"""
|
||||
bucketFactory = Bucket
|
||||
sweepInterval = None
|
||||
|
||||
def __init__(self, parentFilter=None):
|
||||
self.buckets = {}
|
||||
self.parentFilter = parentFilter
|
||||
self.lastSweep = time()
|
||||
|
||||
def getBucketFor(self, *a, **kw):
|
||||
"""
|
||||
Find or create a L{Bucket} corresponding to the provided parameters.
|
||||
|
||||
Any parameters are passed on to L{getBucketKey}, from them it
|
||||
decides which bucket you get.
|
||||
|
||||
@returntype: L{Bucket}
|
||||
"""
|
||||
if ((self.sweepInterval is not None)
|
||||
and ((time() - self.lastSweep) > self.sweepInterval)):
|
||||
self.sweep()
|
||||
|
||||
if self.parentFilter:
|
||||
parentBucket = self.parentFilter.getBucketFor(self, *a, **kw)
|
||||
else:
|
||||
parentBucket = None
|
||||
|
||||
key = self.getBucketKey(*a, **kw)
|
||||
bucket = self.buckets.get(key)
|
||||
if bucket is None:
|
||||
bucket = self.bucketFactory(parentBucket)
|
||||
self.buckets[key] = bucket
|
||||
return bucket
|
||||
|
||||
def getBucketKey(self, *a, **kw):
|
||||
"""
|
||||
Construct a key based on the input parameters to choose a L{Bucket}.
|
||||
|
||||
The default implementation returns the same key for all
|
||||
arguments. Override this method to provide L{Bucket} selection.
|
||||
|
||||
@returns: Something to be used as a key in the bucket cache.
|
||||
"""
|
||||
return None
|
||||
|
||||
def sweep(self):
|
||||
"""
|
||||
Remove empty buckets.
|
||||
"""
|
||||
for key, bucket in self.buckets.items():
|
||||
bucket_is_empty = bucket.drip()
|
||||
if (bucket._refcount == 0) and bucket_is_empty:
|
||||
del self.buckets[key]
|
||||
|
||||
self.lastSweep = time()
|
||||
|
||||
|
||||
class FilterByHost(HierarchicalBucketFilter):
|
||||
"""
|
||||
A Hierarchical Bucket filter with a L{Bucket} for each host.
|
||||
"""
|
||||
sweepInterval = 60 * 20
|
||||
|
||||
def getBucketKey(self, transport):
|
||||
return transport.getPeer()[1]
|
||||
|
||||
|
||||
class FilterByServer(HierarchicalBucketFilter):
|
||||
"""
|
||||
A Hierarchical Bucket filter with a L{Bucket} for each service.
|
||||
"""
|
||||
sweepInterval = None
|
||||
|
||||
def getBucketKey(self, transport):
|
||||
return transport.getHost()[2]
|
||||
|
||||
|
||||
class ShapedConsumer(pcp.ProducerConsumerProxy):
|
||||
"""
|
||||
Wraps a C{Consumer} and shapes the rate at which it receives data.
|
||||
"""
|
||||
# Providing a Pull interface means I don't have to try to schedule
|
||||
# traffic with callLaters.
|
||||
iAmStreaming = False
|
||||
|
||||
def __init__(self, consumer, bucket):
|
||||
pcp.ProducerConsumerProxy.__init__(self, consumer)
|
||||
self.bucket = bucket
|
||||
self.bucket._refcount += 1
|
||||
|
||||
def _writeSomeData(self, data):
|
||||
# In practice, this actually results in obscene amounts of
|
||||
# overhead, as a result of generating lots and lots of packets
|
||||
# with twelve-byte payloads. We may need to do a version of
|
||||
# this with scheduled writes after all.
|
||||
amount = self.bucket.add(len(data))
|
||||
return pcp.ProducerConsumerProxy._writeSomeData(self, data[:amount])
|
||||
|
||||
def stopProducing(self):
|
||||
pcp.ProducerConsumerProxy.stopProducing(self)
|
||||
self.bucket._refcount -= 1
|
||||
|
||||
|
||||
class ShapedTransport(ShapedConsumer):
|
||||
"""
|
||||
Wraps a C{Transport} and shapes the rate at which it receives data.
|
||||
|
||||
This is a L{ShapedConsumer} with a little bit of magic to provide for
|
||||
the case where the consumer it wraps is also a C{Transport} and people
|
||||
will be attempting to access attributes this does not proxy as a
|
||||
C{Consumer} (e.g. C{loseConnection}).
|
||||
"""
|
||||
# Ugh. We only wanted to filter IConsumer, not ITransport.
|
||||
|
||||
iAmStreaming = False
|
||||
def __getattr__(self, name):
|
||||
# Because people will be doing things like .getPeer and
|
||||
# .loseConnection on me.
|
||||
return getattr(self.consumer, name)
|
||||
|
||||
|
||||
class ShapedProtocolFactory:
|
||||
"""
|
||||
Dispense C{Protocols} with traffic shaping on their transports.
|
||||
|
||||
Usage::
|
||||
|
||||
myserver = SomeFactory()
|
||||
myserver.protocol = ShapedProtocolFactory(myserver.protocol,
|
||||
bucketFilter)
|
||||
|
||||
Where C{SomeServerFactory} is a L{twisted.internet.protocol.Factory}, and
|
||||
C{bucketFilter} is an instance of L{HierarchicalBucketFilter}.
|
||||
"""
|
||||
def __init__(self, protoClass, bucketFilter):
|
||||
"""
|
||||
Tell me what to wrap and where to get buckets.
|
||||
|
||||
@param protoClass: The class of C{Protocol} this will generate
|
||||
wrapped instances of.
|
||||
@type protoClass: L{Protocol<twisted.internet.interfaces.IProtocol>}
|
||||
class
|
||||
@param bucketFilter: The filter which will determine how
|
||||
traffic is shaped.
|
||||
@type bucketFilter: L{HierarchicalBucketFilter}.
|
||||
"""
|
||||
# More precisely, protoClass can be any callable that will return
|
||||
# instances of something that implements IProtocol.
|
||||
self.protocol = protoClass
|
||||
self.bucketFilter = bucketFilter
|
||||
|
||||
def __call__(self, *a, **kw):
|
||||
"""
|
||||
Make a C{Protocol} instance with a shaped transport.
|
||||
|
||||
Any parameters will be passed on to the protocol's initializer.
|
||||
|
||||
@returns: A C{Protocol} instance with a L{ShapedTransport}.
|
||||
"""
|
||||
proto = self.protocol(*a, **kw)
|
||||
origMakeConnection = proto.makeConnection
|
||||
def makeConnection(transport):
|
||||
bucket = self.bucketFilter.getBucketFor(transport)
|
||||
shapedTransport = ShapedTransport(transport, bucket)
|
||||
return origMakeConnection(shapedTransport)
|
||||
proto.makeConnection = makeConnection
|
||||
return proto
|
||||
255
venv/lib/python3.9/site-packages/twisted/protocols/ident.py
Normal file
255
venv/lib/python3.9/site-packages/twisted/protocols/ident.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
# -*- test-case-name: twisted.test.test_ident -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Ident protocol implementation.
|
||||
"""
|
||||
|
||||
import struct
|
||||
|
||||
from twisted.internet import defer
|
||||
from twisted.protocols import basic
|
||||
from twisted.python import log, failure
|
||||
|
||||
_MIN_PORT = 1
|
||||
_MAX_PORT = 2 ** 16 - 1
|
||||
|
||||
class IdentError(Exception):
|
||||
"""
|
||||
Can't determine connection owner; reason unknown.
|
||||
"""
|
||||
|
||||
identDescription = 'UNKNOWN-ERROR'
|
||||
|
||||
def __str__(self):
|
||||
return self.identDescription
|
||||
|
||||
|
||||
|
||||
class NoUser(IdentError):
|
||||
"""
|
||||
The connection specified by the port pair is not currently in use or
|
||||
currently not owned by an identifiable entity.
|
||||
"""
|
||||
identDescription = 'NO-USER'
|
||||
|
||||
|
||||
|
||||
class InvalidPort(IdentError):
|
||||
"""
|
||||
Either the local or foreign port was improperly specified. This should
|
||||
be returned if either or both of the port ids were out of range (TCP
|
||||
port numbers are from 1-65535), negative integers, reals or in any
|
||||
fashion not recognized as a non-negative integer.
|
||||
"""
|
||||
identDescription = 'INVALID-PORT'
|
||||
|
||||
|
||||
|
||||
class HiddenUser(IdentError):
|
||||
"""
|
||||
The server was able to identify the user of this port, but the
|
||||
information was not returned at the request of the user.
|
||||
"""
|
||||
identDescription = 'HIDDEN-USER'
|
||||
|
||||
|
||||
|
||||
class IdentServer(basic.LineOnlyReceiver):
|
||||
"""
|
||||
The Identification Protocol (a.k.a., "ident", a.k.a., "the Ident
|
||||
Protocol") provides a means to determine the identity of a user of a
|
||||
particular TCP connection. Given a TCP port number pair, it returns a
|
||||
character string which identifies the owner of that connection on the
|
||||
server's system.
|
||||
|
||||
Server authors should subclass this class and override the lookup method.
|
||||
The default implementation returns an UNKNOWN-ERROR response for every
|
||||
query.
|
||||
"""
|
||||
|
||||
def lineReceived(self, line):
|
||||
parts = line.split(',')
|
||||
if len(parts) != 2:
|
||||
self.invalidQuery()
|
||||
else:
|
||||
try:
|
||||
portOnServer, portOnClient = map(int, parts)
|
||||
except ValueError:
|
||||
self.invalidQuery()
|
||||
else:
|
||||
if _MIN_PORT <= portOnServer <= _MAX_PORT and _MIN_PORT <= portOnClient <= _MAX_PORT:
|
||||
self.validQuery(portOnServer, portOnClient)
|
||||
else:
|
||||
self._ebLookup(failure.Failure(InvalidPort()), portOnServer, portOnClient)
|
||||
|
||||
|
||||
def invalidQuery(self):
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def validQuery(self, portOnServer, portOnClient):
|
||||
"""
|
||||
Called when a valid query is received to look up and deliver the
|
||||
response.
|
||||
|
||||
@param portOnServer: The server port from the query.
|
||||
@param portOnClient: The client port from the query.
|
||||
"""
|
||||
serverAddr = self.transport.getHost().host, portOnServer
|
||||
clientAddr = self.transport.getPeer().host, portOnClient
|
||||
defer.maybeDeferred(self.lookup, serverAddr, clientAddr
|
||||
).addCallback(self._cbLookup, portOnServer, portOnClient
|
||||
).addErrback(self._ebLookup, portOnServer, portOnClient
|
||||
)
|
||||
|
||||
|
||||
def _cbLookup(self, result, sport, cport):
|
||||
(sysName, userId) = result
|
||||
self.sendLine('%d, %d : USERID : %s : %s' % (sport, cport, sysName, userId))
|
||||
|
||||
|
||||
def _ebLookup(self, failure, sport, cport):
|
||||
if failure.check(IdentError):
|
||||
self.sendLine('%d, %d : ERROR : %s' % (sport, cport, failure.value))
|
||||
else:
|
||||
log.err(failure)
|
||||
self.sendLine('%d, %d : ERROR : %s' % (sport, cport, IdentError(failure.value)))
|
||||
|
||||
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
"""
|
||||
Lookup user information about the specified address pair.
|
||||
|
||||
Return value should be a two-tuple of system name and username.
|
||||
Acceptable values for the system name may be found online at::
|
||||
|
||||
U{http://www.iana.org/assignments/operating-system-names}
|
||||
|
||||
This method may also raise any IdentError subclass (or IdentError
|
||||
itself) to indicate user information will not be provided for the
|
||||
given query.
|
||||
|
||||
A Deferred may also be returned.
|
||||
|
||||
@param serverAddress: A two-tuple representing the server endpoint
|
||||
of the address being queried. The first element is a string holding
|
||||
a dotted-quad IP address. The second element is an integer
|
||||
representing the port.
|
||||
|
||||
@param clientAddress: Like I{serverAddress}, but represents the
|
||||
client endpoint of the address being queried.
|
||||
"""
|
||||
raise IdentError()
|
||||
|
||||
|
||||
|
||||
class ProcServerMixin:
|
||||
"""Implements lookup() to grab entries for responses from /proc/net/tcp
|
||||
"""
|
||||
|
||||
SYSTEM_NAME = 'LINUX'
|
||||
|
||||
try:
|
||||
from pwd import getpwuid
|
||||
def getUsername(self, uid, getpwuid=getpwuid):
|
||||
return getpwuid(uid)[0]
|
||||
del getpwuid
|
||||
except ImportError:
|
||||
def getUsername(self, uid):
|
||||
raise IdentError()
|
||||
|
||||
|
||||
def entries(self):
|
||||
with open('/proc/net/tcp') as f:
|
||||
f.readline()
|
||||
for L in f:
|
||||
yield L.strip()
|
||||
|
||||
|
||||
def dottedQuadFromHexString(self, hexstr):
|
||||
return '.'.join(map(str, struct.unpack('4B', struct.pack('=L', int(hexstr, 16)))))
|
||||
|
||||
|
||||
def unpackAddress(self, packed):
|
||||
addr, port = packed.split(':')
|
||||
addr = self.dottedQuadFromHexString(addr)
|
||||
port = int(port, 16)
|
||||
return addr, port
|
||||
|
||||
|
||||
def parseLine(self, line):
|
||||
parts = line.strip().split()
|
||||
localAddr, localPort = self.unpackAddress(parts[1])
|
||||
remoteAddr, remotePort = self.unpackAddress(parts[2])
|
||||
uid = int(parts[7])
|
||||
return (localAddr, localPort), (remoteAddr, remotePort), uid
|
||||
|
||||
|
||||
def lookup(self, serverAddress, clientAddress):
|
||||
for ent in self.entries():
|
||||
localAddr, remoteAddr, uid = self.parseLine(ent)
|
||||
if remoteAddr == clientAddress and localAddr[1] == serverAddress[1]:
|
||||
return (self.SYSTEM_NAME, self.getUsername(uid))
|
||||
|
||||
raise NoUser()
|
||||
|
||||
|
||||
|
||||
class IdentClient(basic.LineOnlyReceiver):
|
||||
|
||||
errorTypes = (IdentError, NoUser, InvalidPort, HiddenUser)
|
||||
|
||||
def __init__(self):
|
||||
self.queries = []
|
||||
|
||||
|
||||
def lookup(self, portOnServer, portOnClient):
|
||||
"""
|
||||
Lookup user information about the specified address pair.
|
||||
"""
|
||||
self.queries.append((defer.Deferred(), portOnServer, portOnClient))
|
||||
if len(self.queries) > 1:
|
||||
return self.queries[-1][0]
|
||||
|
||||
self.sendLine('%d, %d' % (portOnServer, portOnClient))
|
||||
return self.queries[-1][0]
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
if not self.queries:
|
||||
log.msg("Unexpected server response: %r" % (line,))
|
||||
else:
|
||||
d, _, _ = self.queries.pop(0)
|
||||
self.parseResponse(d, line)
|
||||
if self.queries:
|
||||
self.sendLine('%d, %d' % (self.queries[0][1], self.queries[0][2]))
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
for q in self.queries:
|
||||
q[0].errback(IdentError(reason))
|
||||
self.queries = []
|
||||
|
||||
|
||||
def parseResponse(self, deferred, line):
|
||||
parts = line.split(':', 2)
|
||||
if len(parts) != 3:
|
||||
deferred.errback(IdentError(line))
|
||||
else:
|
||||
ports, type, addInfo = map(str.strip, parts)
|
||||
if type == 'ERROR':
|
||||
for et in self.errorTypes:
|
||||
if et.identDescription == addInfo:
|
||||
deferred.errback(et(line))
|
||||
return
|
||||
deferred.errback(IdentError(line))
|
||||
else:
|
||||
deferred.callback((type, addInfo))
|
||||
|
||||
|
||||
|
||||
__all__ = ['IdentError', 'NoUser', 'InvalidPort', 'HiddenUser',
|
||||
'IdentServer', 'IdentClient',
|
||||
'ProcServerMixin']
|
||||
385
venv/lib/python3.9/site-packages/twisted/protocols/loopback.py
Normal file
385
venv/lib/python3.9/site-packages/twisted/protocols/loopback.py
Normal file
|
|
@ -0,0 +1,385 @@
|
|||
# -*- test-case-name: twisted.test.test_loopback -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Testing support for protocols -- loopback between client and server.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
# system imports
|
||||
import tempfile
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
# Twisted Imports
|
||||
from twisted.protocols import policies
|
||||
from twisted.internet import interfaces, protocol, main, defer
|
||||
from twisted.internet.task import deferLater
|
||||
from twisted.python import failure
|
||||
from twisted.internet.interfaces import IAddress
|
||||
|
||||
|
||||
class _LoopbackQueue(object):
|
||||
"""
|
||||
Trivial wrapper around a list to give it an interface like a queue, which
|
||||
the addition of also sending notifications by way of a Deferred whenever
|
||||
the list has something added to it.
|
||||
"""
|
||||
|
||||
_notificationDeferred = None
|
||||
disconnect = False
|
||||
|
||||
def __init__(self):
|
||||
self._queue = []
|
||||
|
||||
|
||||
def put(self, v):
|
||||
self._queue.append(v)
|
||||
if self._notificationDeferred is not None:
|
||||
d, self._notificationDeferred = self._notificationDeferred, None
|
||||
d.callback(None)
|
||||
|
||||
|
||||
def __nonzero__(self):
|
||||
return bool(self._queue)
|
||||
__bool__ = __nonzero__
|
||||
|
||||
|
||||
def get(self):
|
||||
return self._queue.pop(0)
|
||||
|
||||
|
||||
|
||||
@implementer(IAddress)
|
||||
class _LoopbackAddress(object):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport, interfaces.IConsumer)
|
||||
class _LoopbackTransport(object):
|
||||
disconnecting = False
|
||||
producer = None
|
||||
|
||||
# ITransport
|
||||
def __init__(self, q):
|
||||
self.q = q
|
||||
|
||||
def write(self, data):
|
||||
if not isinstance(data, bytes):
|
||||
raise TypeError("Can only write bytes to ITransport")
|
||||
self.q.put(data)
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.q.put(b''.join(iovec))
|
||||
|
||||
def loseConnection(self):
|
||||
self.q.disconnect = True
|
||||
self.q.put(None)
|
||||
|
||||
|
||||
def abortConnection(self):
|
||||
"""
|
||||
Abort the connection. Same as L{loseConnection}.
|
||||
"""
|
||||
self.loseConnection()
|
||||
|
||||
|
||||
def getPeer(self):
|
||||
return _LoopbackAddress()
|
||||
|
||||
def getHost(self):
|
||||
return _LoopbackAddress()
|
||||
|
||||
# IConsumer
|
||||
def registerProducer(self, producer, streaming):
|
||||
assert self.producer is None
|
||||
self.producer = producer
|
||||
self.streamingProducer = streaming
|
||||
self._pollProducer()
|
||||
|
||||
def unregisterProducer(self):
|
||||
assert self.producer is not None
|
||||
self.producer = None
|
||||
|
||||
def _pollProducer(self):
|
||||
if self.producer is not None and not self.streamingProducer:
|
||||
self.producer.resumeProducing()
|
||||
|
||||
|
||||
|
||||
def identityPumpPolicy(queue, target):
|
||||
"""
|
||||
L{identityPumpPolicy} is a policy which delivers each chunk of data written
|
||||
to the given queue as-is to the target.
|
||||
|
||||
This isn't a particularly realistic policy.
|
||||
|
||||
@see: L{loopbackAsync}
|
||||
"""
|
||||
while queue:
|
||||
bytes = queue.get()
|
||||
if bytes is None:
|
||||
break
|
||||
target.dataReceived(bytes)
|
||||
|
||||
|
||||
|
||||
def collapsingPumpPolicy(queue, target):
|
||||
"""
|
||||
L{collapsingPumpPolicy} is a policy which collapses all outstanding chunks
|
||||
into a single string and delivers it to the target.
|
||||
|
||||
@see: L{loopbackAsync}
|
||||
"""
|
||||
bytes = []
|
||||
while queue:
|
||||
chunk = queue.get()
|
||||
if chunk is None:
|
||||
break
|
||||
bytes.append(chunk)
|
||||
if bytes:
|
||||
target.dataReceived(b''.join(bytes))
|
||||
|
||||
|
||||
|
||||
def loopbackAsync(server, client, pumpPolicy=identityPumpPolicy):
|
||||
"""
|
||||
Establish a connection between C{server} and C{client} then transfer data
|
||||
between them until the connection is closed. This is often useful for
|
||||
testing a protocol.
|
||||
|
||||
@param server: The protocol instance representing the server-side of this
|
||||
connection.
|
||||
|
||||
@param client: The protocol instance representing the client-side of this
|
||||
connection.
|
||||
|
||||
@param pumpPolicy: When either C{server} or C{client} writes to its
|
||||
transport, the string passed in is added to a queue of data for the
|
||||
other protocol. Eventually, C{pumpPolicy} will be called with one such
|
||||
queue and the corresponding protocol object. The pump policy callable
|
||||
is responsible for emptying the queue and passing the strings it
|
||||
contains to the given protocol's C{dataReceived} method. The signature
|
||||
of C{pumpPolicy} is C{(queue, protocol)}. C{queue} is an object with a
|
||||
C{get} method which will return the next string written to the
|
||||
transport, or L{None} if the transport has been disconnected, and which
|
||||
evaluates to C{True} if and only if there are more items to be
|
||||
retrieved via C{get}.
|
||||
|
||||
@return: A L{Deferred} which fires when the connection has been closed and
|
||||
both sides have received notification of this.
|
||||
"""
|
||||
serverToClient = _LoopbackQueue()
|
||||
clientToServer = _LoopbackQueue()
|
||||
|
||||
server.makeConnection(_LoopbackTransport(serverToClient))
|
||||
client.makeConnection(_LoopbackTransport(clientToServer))
|
||||
|
||||
return _loopbackAsyncBody(
|
||||
server, serverToClient, client, clientToServer, pumpPolicy)
|
||||
|
||||
|
||||
|
||||
def _loopbackAsyncBody(server, serverToClient, client, clientToServer,
|
||||
pumpPolicy):
|
||||
"""
|
||||
Transfer bytes from the output queue of each protocol to the input of the other.
|
||||
|
||||
@param server: The protocol instance representing the server-side of this
|
||||
connection.
|
||||
|
||||
@param serverToClient: The L{_LoopbackQueue} holding the server's output.
|
||||
|
||||
@param client: The protocol instance representing the client-side of this
|
||||
connection.
|
||||
|
||||
@param clientToServer: The L{_LoopbackQueue} holding the client's output.
|
||||
|
||||
@param pumpPolicy: See L{loopbackAsync}.
|
||||
|
||||
@return: A L{Deferred} which fires when the connection has been closed and
|
||||
both sides have received notification of this.
|
||||
"""
|
||||
def pump(source, q, target):
|
||||
sent = False
|
||||
if q:
|
||||
pumpPolicy(q, target)
|
||||
sent = True
|
||||
if sent and not q:
|
||||
# A write buffer has now been emptied. Give any producer on that
|
||||
# side an opportunity to produce more data.
|
||||
source.transport._pollProducer()
|
||||
|
||||
return sent
|
||||
|
||||
while 1:
|
||||
disconnect = clientSent = serverSent = False
|
||||
|
||||
# Deliver the data which has been written.
|
||||
serverSent = pump(server, serverToClient, client)
|
||||
clientSent = pump(client, clientToServer, server)
|
||||
|
||||
if not clientSent and not serverSent:
|
||||
# Neither side wrote any data. Wait for some new data to be added
|
||||
# before trying to do anything further.
|
||||
d = defer.Deferred()
|
||||
clientToServer._notificationDeferred = d
|
||||
serverToClient._notificationDeferred = d
|
||||
d.addCallback(
|
||||
_loopbackAsyncContinue,
|
||||
server, serverToClient, client, clientToServer, pumpPolicy)
|
||||
return d
|
||||
if serverToClient.disconnect:
|
||||
# The server wants to drop the connection. Flush any remaining
|
||||
# data it has.
|
||||
disconnect = True
|
||||
pump(server, serverToClient, client)
|
||||
elif clientToServer.disconnect:
|
||||
# The client wants to drop the connection. Flush any remaining
|
||||
# data it has.
|
||||
disconnect = True
|
||||
pump(client, clientToServer, server)
|
||||
if disconnect:
|
||||
# Someone wanted to disconnect, so okay, the connection is gone.
|
||||
server.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
client.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
return defer.succeed(None)
|
||||
|
||||
|
||||
|
||||
def _loopbackAsyncContinue(ignored, server, serverToClient, client,
|
||||
clientToServer, pumpPolicy):
|
||||
# Clear the Deferred from each message queue, since it has already fired
|
||||
# and cannot be used again.
|
||||
clientToServer._notificationDeferred = None
|
||||
serverToClient._notificationDeferred = None
|
||||
|
||||
# Schedule some more byte-pushing to happen. This isn't done
|
||||
# synchronously because no actual transport can re-enter dataReceived as
|
||||
# a result of calling write, and doing this synchronously could result
|
||||
# in that.
|
||||
from twisted.internet import reactor
|
||||
return deferLater(
|
||||
reactor, 0,
|
||||
_loopbackAsyncBody,
|
||||
server, serverToClient, client, clientToServer, pumpPolicy)
|
||||
|
||||
|
||||
|
||||
@implementer(interfaces.ITransport, interfaces.IConsumer)
|
||||
class LoopbackRelay:
|
||||
buffer = b''
|
||||
shouldLose = 0
|
||||
disconnecting = 0
|
||||
producer = None
|
||||
|
||||
def __init__(self, target, logFile=None):
|
||||
self.target = target
|
||||
self.logFile = logFile
|
||||
|
||||
def write(self, data):
|
||||
self.buffer = self.buffer + data
|
||||
if self.logFile:
|
||||
self.logFile.write("loopback writing %s\n" % repr(data))
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
def clearBuffer(self):
|
||||
if self.shouldLose == -1:
|
||||
return
|
||||
|
||||
if self.producer:
|
||||
self.producer.resumeProducing()
|
||||
if self.buffer:
|
||||
if self.logFile:
|
||||
self.logFile.write("loopback receiving %s\n" % repr(self.buffer))
|
||||
buffer = self.buffer
|
||||
self.buffer = b''
|
||||
self.target.dataReceived(buffer)
|
||||
if self.shouldLose == 1:
|
||||
self.shouldLose = -1
|
||||
self.target.connectionLost(failure.Failure(main.CONNECTION_DONE))
|
||||
|
||||
def loseConnection(self):
|
||||
if self.shouldLose != -1:
|
||||
self.shouldLose = 1
|
||||
|
||||
def getHost(self):
|
||||
return 'loopback'
|
||||
|
||||
def getPeer(self):
|
||||
return 'loopback'
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = producer
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.producer = None
|
||||
|
||||
def logPrefix(self):
|
||||
return 'Loopback(%r)' % (self.target.__class__.__name__,)
|
||||
|
||||
|
||||
|
||||
class LoopbackClientFactory(protocol.ClientFactory):
|
||||
|
||||
def __init__(self, protocol):
|
||||
self.disconnected = 0
|
||||
self.deferred = defer.Deferred()
|
||||
self.protocol = protocol
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return self.protocol
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
self.disconnected = 1
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
class _FireOnClose(policies.ProtocolWrapper):
|
||||
def __init__(self, protocol, factory):
|
||||
policies.ProtocolWrapper.__init__(self, protocol, factory)
|
||||
self.deferred = defer.Deferred()
|
||||
|
||||
def connectionLost(self, reason):
|
||||
policies.ProtocolWrapper.connectionLost(self, reason)
|
||||
self.deferred.callback(None)
|
||||
|
||||
|
||||
def loopbackTCP(server, client, port=0, noisy=True):
|
||||
"""Run session between server and client protocol instances over TCP."""
|
||||
from twisted.internet import reactor
|
||||
f = policies.WrappingFactory(protocol.Factory())
|
||||
serverWrapper = _FireOnClose(f, server)
|
||||
f.noisy = noisy
|
||||
f.buildProtocol = lambda addr: serverWrapper
|
||||
serverPort = reactor.listenTCP(port, f, interface='127.0.0.1')
|
||||
clientF = LoopbackClientFactory(client)
|
||||
clientF.noisy = noisy
|
||||
reactor.connectTCP('127.0.0.1', serverPort.getHost().port, clientF)
|
||||
d = clientF.deferred
|
||||
d.addCallback(lambda x: serverWrapper.deferred)
|
||||
d.addCallback(lambda x: serverPort.stopListening())
|
||||
return d
|
||||
|
||||
|
||||
def loopbackUNIX(server, client, noisy=True):
|
||||
"""Run session between server and client protocol instances over UNIX socket."""
|
||||
path = tempfile.mktemp()
|
||||
from twisted.internet import reactor
|
||||
f = policies.WrappingFactory(protocol.Factory())
|
||||
serverWrapper = _FireOnClose(f, server)
|
||||
f.noisy = noisy
|
||||
f.buildProtocol = lambda addr: serverWrapper
|
||||
serverPort = reactor.listenUNIX(path, f)
|
||||
clientF = LoopbackClientFactory(client)
|
||||
clientF.noisy = noisy
|
||||
reactor.connectUNIX(path, clientF)
|
||||
d = clientF.deferred
|
||||
d.addCallback(lambda x: serverWrapper.deferred)
|
||||
d.addCallback(lambda x: serverPort.stopListening())
|
||||
return d
|
||||
766
venv/lib/python3.9/site-packages/twisted/protocols/memcache.py
Normal file
766
venv/lib/python3.9/site-packages/twisted/protocols/memcache.py
Normal file
|
|
@ -0,0 +1,766 @@
|
|||
# -*- test-case-name: twisted.test.test_memcache -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Memcache client protocol. Memcached is a caching server, storing data in the
|
||||
form of pairs key/value, and memcache is the protocol to talk with it.
|
||||
|
||||
To connect to a server, create a factory for L{MemCacheProtocol}::
|
||||
|
||||
from twisted.internet import reactor, protocol
|
||||
from twisted.protocols.memcache import MemCacheProtocol, DEFAULT_PORT
|
||||
d = protocol.ClientCreator(reactor, MemCacheProtocol
|
||||
).connectTCP("localhost", DEFAULT_PORT)
|
||||
def doSomething(proto):
|
||||
# Here you call the memcache operations
|
||||
return proto.set("mykey", "a lot of data")
|
||||
d.addCallback(doSomething)
|
||||
reactor.run()
|
||||
|
||||
All the operations of the memcache protocol are present, but
|
||||
L{MemCacheProtocol.set} and L{MemCacheProtocol.get} are the more important.
|
||||
|
||||
See U{http://code.sixapart.com/svn/memcached/trunk/server/doc/protocol.txt} for
|
||||
more information about the protocol.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division
|
||||
|
||||
from collections import deque
|
||||
|
||||
from twisted.protocols.basic import LineReceiver
|
||||
from twisted.protocols.policies import TimeoutMixin
|
||||
from twisted.internet.defer import Deferred, fail, TimeoutError
|
||||
from twisted.python import log
|
||||
from twisted.python.compat import (
|
||||
intToBytes, iteritems, nativeString, networkString)
|
||||
|
||||
|
||||
|
||||
DEFAULT_PORT = 11211
|
||||
|
||||
|
||||
|
||||
class NoSuchCommand(Exception):
|
||||
"""
|
||||
Exception raised when a non existent command is called.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class ClientError(Exception):
|
||||
"""
|
||||
Error caused by an invalid client call.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class ServerError(Exception):
|
||||
"""
|
||||
Problem happening on the server.
|
||||
"""
|
||||
|
||||
|
||||
|
||||
class Command(object):
|
||||
"""
|
||||
Wrap a client action into an object, that holds the values used in the
|
||||
protocol.
|
||||
|
||||
@ivar _deferred: the L{Deferred} object that will be fired when the result
|
||||
arrives.
|
||||
@type _deferred: L{Deferred}
|
||||
|
||||
@ivar command: name of the command sent to the server.
|
||||
@type command: L{bytes}
|
||||
"""
|
||||
|
||||
def __init__(self, command, **kwargs):
|
||||
"""
|
||||
Create a command.
|
||||
|
||||
@param command: the name of the command.
|
||||
@type command: L{bytes}
|
||||
|
||||
@param kwargs: this values will be stored as attributes of the object
|
||||
for future use
|
||||
"""
|
||||
self.command = command
|
||||
self._deferred = Deferred()
|
||||
for k, v in kwargs.items():
|
||||
setattr(self, k, v)
|
||||
|
||||
|
||||
def success(self, value):
|
||||
"""
|
||||
Shortcut method to fire the underlying deferred.
|
||||
"""
|
||||
self._deferred.callback(value)
|
||||
|
||||
|
||||
def fail(self, error):
|
||||
"""
|
||||
Make the underlying deferred fails.
|
||||
"""
|
||||
self._deferred.errback(error)
|
||||
|
||||
|
||||
|
||||
class MemCacheProtocol(LineReceiver, TimeoutMixin):
|
||||
"""
|
||||
MemCache protocol: connect to a memcached server to store/retrieve values.
|
||||
|
||||
@ivar persistentTimeOut: the timeout period used to wait for a response.
|
||||
@type persistentTimeOut: L{int}
|
||||
|
||||
@ivar _current: current list of requests waiting for an answer from the
|
||||
server.
|
||||
@type _current: L{deque} of L{Command}
|
||||
|
||||
@ivar _lenExpected: amount of data expected in raw mode, when reading for
|
||||
a value.
|
||||
@type _lenExpected: L{int}
|
||||
|
||||
@ivar _getBuffer: current buffer of data, used to store temporary data
|
||||
when reading in raw mode.
|
||||
@type _getBuffer: L{list}
|
||||
|
||||
@ivar _bufferLength: the total amount of bytes in C{_getBuffer}.
|
||||
@type _bufferLength: L{int}
|
||||
|
||||
@ivar _disconnected: indicate if the connectionLost has been called or not.
|
||||
@type _disconnected: L{bool}
|
||||
"""
|
||||
MAX_KEY_LENGTH = 250
|
||||
_disconnected = False
|
||||
|
||||
def __init__(self, timeOut=60):
|
||||
"""
|
||||
Create the protocol.
|
||||
|
||||
@param timeOut: the timeout to wait before detecting that the
|
||||
connection is dead and close it. It's expressed in seconds.
|
||||
@type timeOut: L{int}
|
||||
"""
|
||||
self._current = deque()
|
||||
self._lenExpected = None
|
||||
self._getBuffer = None
|
||||
self._bufferLength = None
|
||||
self.persistentTimeOut = self.timeOut = timeOut
|
||||
|
||||
|
||||
def _cancelCommands(self, reason):
|
||||
"""
|
||||
Cancel all the outstanding commands, making them fail with C{reason}.
|
||||
"""
|
||||
while self._current:
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(reason)
|
||||
|
||||
|
||||
def timeoutConnection(self):
|
||||
"""
|
||||
Close the connection in case of timeout.
|
||||
"""
|
||||
self._cancelCommands(TimeoutError("Connection timeout"))
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Cause any outstanding commands to fail.
|
||||
"""
|
||||
self._disconnected = True
|
||||
self._cancelCommands(reason)
|
||||
LineReceiver.connectionLost(self, reason)
|
||||
|
||||
|
||||
def sendLine(self, line):
|
||||
"""
|
||||
Override sendLine to add a timeout to response.
|
||||
"""
|
||||
if not self._current:
|
||||
self.setTimeout(self.persistentTimeOut)
|
||||
LineReceiver.sendLine(self, line)
|
||||
|
||||
|
||||
def rawDataReceived(self, data):
|
||||
"""
|
||||
Collect data for a get.
|
||||
"""
|
||||
self.resetTimeout()
|
||||
self._getBuffer.append(data)
|
||||
self._bufferLength += len(data)
|
||||
if self._bufferLength >= self._lenExpected + 2:
|
||||
data = b"".join(self._getBuffer)
|
||||
buf = data[:self._lenExpected]
|
||||
rem = data[self._lenExpected + 2:]
|
||||
val = buf
|
||||
self._lenExpected = None
|
||||
self._getBuffer = None
|
||||
self._bufferLength = None
|
||||
cmd = self._current[0]
|
||||
if cmd.multiple:
|
||||
flags, cas = cmd.values[cmd.currentKey]
|
||||
cmd.values[cmd.currentKey] = (flags, cas, val)
|
||||
else:
|
||||
cmd.value = val
|
||||
self.setLineMode(rem)
|
||||
|
||||
|
||||
def cmd_STORED(self):
|
||||
"""
|
||||
Manage a success response to a set operation.
|
||||
"""
|
||||
self._current.popleft().success(True)
|
||||
|
||||
|
||||
def cmd_NOT_STORED(self):
|
||||
"""
|
||||
Manage a specific 'not stored' response to a set operation: this is not
|
||||
an error, but some condition wasn't met.
|
||||
"""
|
||||
self._current.popleft().success(False)
|
||||
|
||||
|
||||
def cmd_END(self):
|
||||
"""
|
||||
This the end token to a get or a stat operation.
|
||||
"""
|
||||
cmd = self._current.popleft()
|
||||
if cmd.command == b"get":
|
||||
if cmd.multiple:
|
||||
values = {key: val[::2] for key, val in iteritems(cmd.values)}
|
||||
cmd.success(values)
|
||||
else:
|
||||
cmd.success((cmd.flags, cmd.value))
|
||||
elif cmd.command == b"gets":
|
||||
if cmd.multiple:
|
||||
cmd.success(cmd.values)
|
||||
else:
|
||||
cmd.success((cmd.flags, cmd.cas, cmd.value))
|
||||
elif cmd.command == b"stats":
|
||||
cmd.success(cmd.values)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
"Unexpected END response to %s command" %
|
||||
(nativeString(cmd.command),))
|
||||
|
||||
|
||||
def cmd_NOT_FOUND(self):
|
||||
"""
|
||||
Manage error response for incr/decr/delete.
|
||||
"""
|
||||
self._current.popleft().success(False)
|
||||
|
||||
|
||||
def cmd_VALUE(self, line):
|
||||
"""
|
||||
Prepare the reading a value after a get.
|
||||
"""
|
||||
cmd = self._current[0]
|
||||
if cmd.command == b"get":
|
||||
key, flags, length = line.split()
|
||||
cas = b""
|
||||
else:
|
||||
key, flags, length, cas = line.split()
|
||||
self._lenExpected = int(length)
|
||||
self._getBuffer = []
|
||||
self._bufferLength = 0
|
||||
if cmd.multiple:
|
||||
if key not in cmd.keys:
|
||||
raise RuntimeError("Unexpected commands answer.")
|
||||
cmd.currentKey = key
|
||||
cmd.values[key] = [int(flags), cas]
|
||||
else:
|
||||
if cmd.key != key:
|
||||
raise RuntimeError("Unexpected commands answer.")
|
||||
cmd.flags = int(flags)
|
||||
cmd.cas = cas
|
||||
self.setRawMode()
|
||||
|
||||
|
||||
def cmd_STAT(self, line):
|
||||
"""
|
||||
Reception of one stat line.
|
||||
"""
|
||||
cmd = self._current[0]
|
||||
key, val = line.split(b" ", 1)
|
||||
cmd.values[key] = val
|
||||
|
||||
|
||||
def cmd_VERSION(self, versionData):
|
||||
"""
|
||||
Read version token.
|
||||
"""
|
||||
self._current.popleft().success(versionData)
|
||||
|
||||
|
||||
def cmd_ERROR(self):
|
||||
"""
|
||||
A non-existent command has been sent.
|
||||
"""
|
||||
log.err("Non-existent command sent.")
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(NoSuchCommand())
|
||||
|
||||
|
||||
def cmd_CLIENT_ERROR(self, errText):
|
||||
"""
|
||||
An invalid input as been sent.
|
||||
"""
|
||||
errText = repr(errText)
|
||||
log.err("Invalid input: " + errText)
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(ClientError(errText))
|
||||
|
||||
|
||||
def cmd_SERVER_ERROR(self, errText):
|
||||
"""
|
||||
An error has happened server-side.
|
||||
"""
|
||||
errText = repr(errText)
|
||||
log.err("Server error: " + errText)
|
||||
cmd = self._current.popleft()
|
||||
cmd.fail(ServerError(errText))
|
||||
|
||||
|
||||
def cmd_DELETED(self):
|
||||
"""
|
||||
A delete command has completed successfully.
|
||||
"""
|
||||
self._current.popleft().success(True)
|
||||
|
||||
|
||||
def cmd_OK(self):
|
||||
"""
|
||||
The last command has been completed.
|
||||
"""
|
||||
self._current.popleft().success(True)
|
||||
|
||||
|
||||
def cmd_EXISTS(self):
|
||||
"""
|
||||
A C{checkAndSet} update has failed.
|
||||
"""
|
||||
self._current.popleft().success(False)
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
"""
|
||||
Receive line commands from the server.
|
||||
"""
|
||||
self.resetTimeout()
|
||||
token = line.split(b" ", 1)[0]
|
||||
# First manage standard commands without space
|
||||
cmd = getattr(self, "cmd_" + nativeString(token), None)
|
||||
if cmd is not None:
|
||||
args = line.split(b" ", 1)[1:]
|
||||
if args:
|
||||
cmd(args[0])
|
||||
else:
|
||||
cmd()
|
||||
else:
|
||||
# Then manage commands with space in it
|
||||
line = line.replace(b" ", b"_")
|
||||
cmd = getattr(self, "cmd_" + nativeString(line), None)
|
||||
if cmd is not None:
|
||||
cmd()
|
||||
else:
|
||||
# Increment/Decrement response
|
||||
cmd = self._current.popleft()
|
||||
val = int(line)
|
||||
cmd.success(val)
|
||||
if not self._current:
|
||||
# No pending request, remove timeout
|
||||
self.setTimeout(None)
|
||||
|
||||
|
||||
def increment(self, key, val=1):
|
||||
"""
|
||||
Increment the value of C{key} by given value (default to 1).
|
||||
C{key} must be consistent with an int. Return the new value.
|
||||
|
||||
@param key: the key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value to increment.
|
||||
@type val: L{int}
|
||||
|
||||
@return: a deferred with will be called back with the new value
|
||||
associated with the key (after the increment).
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._incrdecr(b"incr", key, val)
|
||||
|
||||
|
||||
def decrement(self, key, val=1):
|
||||
"""
|
||||
Decrement the value of C{key} by given value (default to 1).
|
||||
C{key} must be consistent with an int. Return the new value, coerced to
|
||||
0 if negative.
|
||||
|
||||
@param key: the key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value to decrement.
|
||||
@type val: L{int}
|
||||
|
||||
@return: a deferred with will be called back with the new value
|
||||
associated with the key (after the decrement).
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._incrdecr(b"decr", key, val)
|
||||
|
||||
|
||||
def _incrdecr(self, cmd, key, val):
|
||||
"""
|
||||
Internal wrapper for incr/decr.
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
if not isinstance(key, bytes):
|
||||
return fail(ClientError(
|
||||
"Invalid type for key: %s, expecting bytes" % (type(key),)))
|
||||
if len(key) > self.MAX_KEY_LENGTH:
|
||||
return fail(ClientError("Key too long"))
|
||||
fullcmd = b" ".join([cmd, key, intToBytes(int(val))])
|
||||
self.sendLine(fullcmd)
|
||||
cmdObj = Command(cmd, key=key)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
def replace(self, key, val, flags=0, expireTime=0):
|
||||
"""
|
||||
Replace the given C{key}. It must already exist in the server.
|
||||
|
||||
@param key: the key to replace.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the new value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param flags: the flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: if different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: a deferred that will fire with C{True} if the operation has
|
||||
succeeded, and C{False} with the key didn't previously exist.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"replace", key, val, flags, expireTime, b"")
|
||||
|
||||
|
||||
def add(self, key, val, flags=0, expireTime=0):
|
||||
"""
|
||||
Add the given C{key}. It must not exist in the server.
|
||||
|
||||
@param key: the key to add.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param flags: the flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: if different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: a deferred that will fire with C{True} if the operation has
|
||||
succeeded, and C{False} with the key already exists.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"add", key, val, flags, expireTime, b"")
|
||||
|
||||
|
||||
def set(self, key, val, flags=0, expireTime=0):
|
||||
"""
|
||||
Set the given C{key}.
|
||||
|
||||
@param key: the key to set.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: the value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param flags: the flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: if different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: a deferred that will fire with C{True} if the operation has
|
||||
succeeded.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"set", key, val, flags, expireTime, b"")
|
||||
|
||||
|
||||
def checkAndSet(self, key, val, cas, flags=0, expireTime=0):
|
||||
"""
|
||||
Change the content of C{key} only if the C{cas} value matches the
|
||||
current one associated with the key. Use this to store a value which
|
||||
hasn't been modified since last time you fetched it.
|
||||
|
||||
@param key: The key to set.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: The value associated with the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@param cas: Unique 64-bit value returned by previous call of C{get}.
|
||||
@type cas: L{bytes}
|
||||
|
||||
@param flags: The flags to store with the key.
|
||||
@type flags: L{int}
|
||||
|
||||
@param expireTime: If different from 0, the relative time in seconds
|
||||
when the key will be deleted from the store.
|
||||
@type expireTime: L{int}
|
||||
|
||||
@return: A deferred that will fire with C{True} if the operation has
|
||||
succeeded, C{False} otherwise.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._set(b"cas", key, val, flags, expireTime, cas)
|
||||
|
||||
|
||||
def _set(self, cmd, key, val, flags, expireTime, cas):
|
||||
"""
|
||||
Internal wrapper for setting values.
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
if not isinstance(key, bytes):
|
||||
return fail(ClientError(
|
||||
"Invalid type for key: %s, expecting bytes" % (type(key),)))
|
||||
if len(key) > self.MAX_KEY_LENGTH:
|
||||
return fail(ClientError("Key too long"))
|
||||
if not isinstance(val, bytes):
|
||||
return fail(ClientError(
|
||||
"Invalid type for value: %s, expecting bytes" %
|
||||
(type(val),)))
|
||||
if cas:
|
||||
cas = b" " + cas
|
||||
length = len(val)
|
||||
fullcmd = b" ".join([
|
||||
cmd, key,
|
||||
networkString("%d %d %d" % (flags, expireTime, length))]) + cas
|
||||
self.sendLine(fullcmd)
|
||||
self.sendLine(val)
|
||||
cmdObj = Command(cmd, key=key, flags=flags, length=length)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
def append(self, key, val):
|
||||
"""
|
||||
Append given data to the value of an existing key.
|
||||
|
||||
@param key: The key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: The value to append to the current value associated with
|
||||
the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@return: A deferred that will fire with C{True} if the operation has
|
||||
succeeded, C{False} otherwise.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
# Even if flags and expTime values are ignored, we have to pass them
|
||||
return self._set(b"append", key, val, 0, 0, b"")
|
||||
|
||||
|
||||
def prepend(self, key, val):
|
||||
"""
|
||||
Prepend given data to the value of an existing key.
|
||||
|
||||
@param key: The key to modify.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param val: The value to prepend to the current value associated with
|
||||
the key.
|
||||
@type val: L{bytes}
|
||||
|
||||
@return: A deferred that will fire with C{True} if the operation has
|
||||
succeeded, C{False} otherwise.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
# Even if flags and expTime values are ignored, we have to pass them
|
||||
return self._set(b"prepend", key, val, 0, 0, b"")
|
||||
|
||||
|
||||
def get(self, key, withIdentifier=False):
|
||||
"""
|
||||
Get the given C{key}. It doesn't support multiple keys. If
|
||||
C{withIdentifier} is set to C{True}, the command issued is a C{gets},
|
||||
that will return the current identifier associated with the value. This
|
||||
identifier has to be used when issuing C{checkAndSet} update later,
|
||||
using the corresponding method.
|
||||
|
||||
@param key: The key to retrieve.
|
||||
@type key: L{bytes}
|
||||
|
||||
@param withIdentifier: If set to C{True}, retrieve the current
|
||||
identifier along with the value and the flags.
|
||||
@type withIdentifier: L{bool}
|
||||
|
||||
@return: A deferred that will fire with the tuple (flags, value) if
|
||||
C{withIdentifier} is C{False}, or (flags, cas identifier, value)
|
||||
if C{True}. If the server indicates there is no value
|
||||
associated with C{key}, the returned value will be L{None} and
|
||||
the returned flags will be C{0}.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
return self._get([key], withIdentifier, False)
|
||||
|
||||
|
||||
def getMultiple(self, keys, withIdentifier=False):
|
||||
"""
|
||||
Get the given list of C{keys}. If C{withIdentifier} is set to C{True},
|
||||
the command issued is a C{gets}, that will return the identifiers
|
||||
associated with each values. This identifier has to be used when
|
||||
issuing C{checkAndSet} update later, using the corresponding method.
|
||||
|
||||
@param keys: The keys to retrieve.
|
||||
@type keys: L{list} of L{bytes}
|
||||
|
||||
@param withIdentifier: If set to C{True}, retrieve the identifiers
|
||||
along with the values and the flags.
|
||||
@type withIdentifier: L{bool}
|
||||
|
||||
@return: A deferred that will fire with a dictionary with the elements
|
||||
of C{keys} as keys and the tuples (flags, value) as values if
|
||||
C{withIdentifier} is C{False}, or (flags, cas identifier, value) if
|
||||
C{True}. If the server indicates there is no value associated with
|
||||
C{key}, the returned values will be L{None} and the returned flags
|
||||
will be C{0}.
|
||||
@rtype: L{Deferred}
|
||||
|
||||
@since: 9.0
|
||||
"""
|
||||
return self._get(keys, withIdentifier, True)
|
||||
|
||||
|
||||
def _get(self, keys, withIdentifier, multiple):
|
||||
"""
|
||||
Helper method for C{get} and C{getMultiple}.
|
||||
"""
|
||||
keys = list(keys)
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
for key in keys:
|
||||
if not isinstance(key, bytes):
|
||||
return fail(ClientError(
|
||||
"Invalid type for key: %s, expecting bytes" %
|
||||
(type(key),)))
|
||||
if len(key) > self.MAX_KEY_LENGTH:
|
||||
return fail(ClientError("Key too long"))
|
||||
if withIdentifier:
|
||||
cmd = b"gets"
|
||||
else:
|
||||
cmd = b"get"
|
||||
fullcmd = b" ".join([cmd] + keys)
|
||||
self.sendLine(fullcmd)
|
||||
if multiple:
|
||||
values = dict([(key, (0, b"", None)) for key in keys])
|
||||
cmdObj = Command(cmd, keys=keys, values=values, multiple=True)
|
||||
else:
|
||||
cmdObj = Command(cmd, key=keys[0], value=None, flags=0, cas=b"",
|
||||
multiple=False)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
def stats(self, arg=None):
|
||||
"""
|
||||
Get some stats from the server. It will be available as a dict.
|
||||
|
||||
@param arg: An optional additional string which will be sent along
|
||||
with the I{stats} command. The interpretation of this value by
|
||||
the server is left undefined by the memcache protocol
|
||||
specification.
|
||||
@type arg: L{None} or L{bytes}
|
||||
|
||||
@return: a deferred that will fire with a L{dict} of the available
|
||||
statistics.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if arg:
|
||||
cmd = b"stats " + arg
|
||||
else:
|
||||
cmd = b"stats"
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
self.sendLine(cmd)
|
||||
cmdObj = Command(b"stats", values={})
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
def version(self):
|
||||
"""
|
||||
Get the version of the server.
|
||||
|
||||
@return: a deferred that will fire with the string value of the
|
||||
version.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
self.sendLine(b"version")
|
||||
cmdObj = Command(b"version")
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
def delete(self, key):
|
||||
"""
|
||||
Delete an existing C{key}.
|
||||
|
||||
@param key: the key to delete.
|
||||
@type key: L{bytes}
|
||||
|
||||
@return: a deferred that will be called back with C{True} if the key
|
||||
was successfully deleted, or C{False} if not.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
if not isinstance(key, bytes):
|
||||
return fail(ClientError(
|
||||
"Invalid type for key: %s, expecting bytes" % (type(key),)))
|
||||
self.sendLine(b"delete " + key)
|
||||
cmdObj = Command(b"delete", key=key)
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
def flushAll(self):
|
||||
"""
|
||||
Flush all cached values.
|
||||
|
||||
@return: a deferred that will be called back with C{True} when the
|
||||
operation has succeeded.
|
||||
@rtype: L{Deferred}
|
||||
"""
|
||||
if self._disconnected:
|
||||
return fail(RuntimeError("not connected"))
|
||||
self.sendLine(b"flush_all")
|
||||
cmdObj = Command(b"flush_all")
|
||||
self._current.append(cmdObj)
|
||||
return cmdObj._deferred
|
||||
|
||||
|
||||
|
||||
__all__ = ["MemCacheProtocol", "DEFAULT_PORT", "NoSuchCommand", "ClientError",
|
||||
"ServerError"]
|
||||
203
venv/lib/python3.9/site-packages/twisted/protocols/pcp.py
Normal file
203
venv/lib/python3.9/site-packages/twisted/protocols/pcp.py
Normal file
|
|
@ -0,0 +1,203 @@
|
|||
# -*- test-case-name: twisted.test.test_pcp -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Producer-Consumer Proxy.
|
||||
"""
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import interfaces
|
||||
|
||||
|
||||
@implementer(interfaces.IProducer, interfaces.IConsumer)
|
||||
class BasicProducerConsumerProxy:
|
||||
"""
|
||||
I can act as a man in the middle between any Producer and Consumer.
|
||||
|
||||
@ivar producer: the Producer I subscribe to.
|
||||
@type producer: L{IProducer<interfaces.IProducer>}
|
||||
@ivar consumer: the Consumer I publish to.
|
||||
@type consumer: L{IConsumer<interfaces.IConsumer>}
|
||||
@ivar paused: As a Producer, am I paused?
|
||||
@type paused: bool
|
||||
"""
|
||||
consumer = None
|
||||
producer = None
|
||||
producerIsStreaming = None
|
||||
iAmStreaming = True
|
||||
outstandingPull = False
|
||||
paused = False
|
||||
stopped = False
|
||||
|
||||
def __init__(self, consumer):
|
||||
self._buffer = []
|
||||
if consumer is not None:
|
||||
self.consumer = consumer
|
||||
consumer.registerProducer(self, self.iAmStreaming)
|
||||
|
||||
# Producer methods:
|
||||
|
||||
def pauseProducing(self):
|
||||
self.paused = True
|
||||
if self.producer:
|
||||
self.producer.pauseProducing()
|
||||
|
||||
def resumeProducing(self):
|
||||
self.paused = False
|
||||
if self._buffer:
|
||||
# TODO: Check to see if consumer supports writeSeq.
|
||||
self.consumer.write(''.join(self._buffer))
|
||||
self._buffer[:] = []
|
||||
else:
|
||||
if not self.iAmStreaming:
|
||||
self.outstandingPull = True
|
||||
|
||||
if self.producer is not None:
|
||||
self.producer.resumeProducing()
|
||||
|
||||
def stopProducing(self):
|
||||
if self.producer is not None:
|
||||
self.producer.stopProducing()
|
||||
if self.consumer is not None:
|
||||
del self.consumer
|
||||
|
||||
# Consumer methods:
|
||||
|
||||
def write(self, data):
|
||||
if self.paused or (not self.iAmStreaming and not self.outstandingPull):
|
||||
# We could use that fifo queue here.
|
||||
self._buffer.append(data)
|
||||
|
||||
elif self.consumer is not None:
|
||||
self.consumer.write(data)
|
||||
self.outstandingPull = False
|
||||
|
||||
def finish(self):
|
||||
if self.consumer is not None:
|
||||
self.consumer.finish()
|
||||
self.unregisterProducer()
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = producer
|
||||
self.producerIsStreaming = streaming
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self.producer is not None:
|
||||
del self.producer
|
||||
del self.producerIsStreaming
|
||||
if self.consumer:
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
def __repr__(self):
|
||||
return '<%s@%x around %s>' % (self.__class__, id(self), self.consumer)
|
||||
|
||||
|
||||
class ProducerConsumerProxy(BasicProducerConsumerProxy):
|
||||
"""ProducerConsumerProxy with a finite buffer.
|
||||
|
||||
When my buffer fills up, I have my parent Producer pause until my buffer
|
||||
has room in it again.
|
||||
"""
|
||||
# Copies much from abstract.FileDescriptor
|
||||
bufferSize = 2**2**2**2
|
||||
|
||||
producerPaused = False
|
||||
unregistered = False
|
||||
|
||||
def pauseProducing(self):
|
||||
# Does *not* call up to ProducerConsumerProxy to relay the pause
|
||||
# message through to my parent Producer.
|
||||
self.paused = True
|
||||
|
||||
def resumeProducing(self):
|
||||
self.paused = False
|
||||
if self._buffer:
|
||||
data = ''.join(self._buffer)
|
||||
bytesSent = self._writeSomeData(data)
|
||||
if bytesSent < len(data):
|
||||
unsent = data[bytesSent:]
|
||||
assert not self.iAmStreaming, (
|
||||
"Streaming producer did not write all its data.")
|
||||
self._buffer[:] = [unsent]
|
||||
else:
|
||||
self._buffer[:] = []
|
||||
else:
|
||||
bytesSent = 0
|
||||
|
||||
if (self.unregistered and bytesSent and not self._buffer and
|
||||
self.consumer is not None):
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
if not self.iAmStreaming:
|
||||
self.outstandingPull = not bytesSent
|
||||
|
||||
if self.producer is not None:
|
||||
bytesBuffered = sum([len(s) for s in self._buffer])
|
||||
# TODO: You can see here the potential for high and low
|
||||
# watermarks, where bufferSize would be the high mark when we
|
||||
# ask the upstream producer to pause, and we wouldn't have
|
||||
# it resume again until it hit the low mark. Or if producer
|
||||
# is Pull, maybe we'd like to pull from it as much as necessary
|
||||
# to keep our buffer full to the low mark, so we're never caught
|
||||
# without something to send.
|
||||
if self.producerPaused and (bytesBuffered < self.bufferSize):
|
||||
# Now that our buffer is empty,
|
||||
self.producerPaused = False
|
||||
self.producer.resumeProducing()
|
||||
elif self.outstandingPull:
|
||||
# I did not have any data to write in response to a pull,
|
||||
# so I'd better pull some myself.
|
||||
self.producer.resumeProducing()
|
||||
|
||||
def write(self, data):
|
||||
if self.paused or (not self.iAmStreaming and not self.outstandingPull):
|
||||
# We could use that fifo queue here.
|
||||
self._buffer.append(data)
|
||||
|
||||
elif self.consumer is not None:
|
||||
assert not self._buffer, (
|
||||
"Writing fresh data to consumer before my buffer is empty!")
|
||||
# I'm going to use _writeSomeData here so that there is only one
|
||||
# path to self.consumer.write. But it doesn't actually make sense,
|
||||
# if I am streaming, for some data to not be all data. But maybe I
|
||||
# am not streaming, but I am writing here anyway, because there was
|
||||
# an earlier request for data which was not answered.
|
||||
bytesSent = self._writeSomeData(data)
|
||||
self.outstandingPull = False
|
||||
if not bytesSent == len(data):
|
||||
assert not self.iAmStreaming, (
|
||||
"Streaming producer did not write all its data.")
|
||||
self._buffer.append(data[bytesSent:])
|
||||
|
||||
if (self.producer is not None) and self.producerIsStreaming:
|
||||
bytesBuffered = sum([len(s) for s in self._buffer])
|
||||
if bytesBuffered >= self.bufferSize:
|
||||
|
||||
self.producer.pauseProducing()
|
||||
self.producerPaused = True
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.unregistered = False
|
||||
BasicProducerConsumerProxy.registerProducer(self, producer, streaming)
|
||||
if not streaming:
|
||||
producer.resumeProducing()
|
||||
|
||||
def unregisterProducer(self):
|
||||
if self.producer is not None:
|
||||
del self.producer
|
||||
del self.producerIsStreaming
|
||||
self.unregistered = True
|
||||
if self.consumer and not self._buffer:
|
||||
self.consumer.unregisterProducer()
|
||||
|
||||
def _writeSomeData(self, data):
|
||||
"""Write as much of this data as possible.
|
||||
|
||||
@returns: The number of bytes written.
|
||||
"""
|
||||
if self.consumer is None:
|
||||
return 0
|
||||
self.consumer.write(data)
|
||||
return len(data)
|
||||
751
venv/lib/python3.9/site-packages/twisted/protocols/policies.py
Normal file
751
venv/lib/python3.9/site-packages/twisted/protocols/policies.py
Normal file
|
|
@ -0,0 +1,751 @@
|
|||
# -*- test-case-name: twisted.test.test_policies -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Resource limiting policies.
|
||||
|
||||
@seealso: See also L{twisted.protocols.htb} for rate limiting.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
# system imports
|
||||
import sys
|
||||
|
||||
from zope.interface import directlyProvides, providedBy
|
||||
|
||||
# twisted imports
|
||||
from twisted.internet.protocol import ServerFactory, Protocol, ClientFactory
|
||||
from twisted.internet import error
|
||||
from twisted.internet.interfaces import ILoggingContext
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
def _wrappedLogPrefix(wrapper, wrapped):
|
||||
"""
|
||||
Compute a log prefix for a wrapper and the object it wraps.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
if ILoggingContext.providedBy(wrapped):
|
||||
logPrefix = wrapped.logPrefix()
|
||||
else:
|
||||
logPrefix = wrapped.__class__.__name__
|
||||
return "%s (%s)" % (logPrefix, wrapper.__class__.__name__)
|
||||
|
||||
|
||||
|
||||
class ProtocolWrapper(Protocol):
|
||||
"""
|
||||
Wraps protocol instances and acts as their transport as well.
|
||||
|
||||
@ivar wrappedProtocol: An L{IProtocol<twisted.internet.interfaces.IProtocol>}
|
||||
provider to which L{IProtocol<twisted.internet.interfaces.IProtocol>}
|
||||
method calls onto this L{ProtocolWrapper} will be proxied.
|
||||
|
||||
@ivar factory: The L{WrappingFactory} which created this
|
||||
L{ProtocolWrapper}.
|
||||
"""
|
||||
|
||||
disconnecting = 0
|
||||
|
||||
def __init__(self, factory, wrappedProtocol):
|
||||
self.wrappedProtocol = wrappedProtocol
|
||||
self.factory = factory
|
||||
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Use a customized log prefix mentioning both the wrapped protocol and
|
||||
the current one.
|
||||
"""
|
||||
return _wrappedLogPrefix(self, self.wrappedProtocol)
|
||||
|
||||
|
||||
def makeConnection(self, transport):
|
||||
"""
|
||||
When a connection is made, register this wrapper with its factory,
|
||||
save the real transport, and connect the wrapped protocol to this
|
||||
L{ProtocolWrapper} to intercept any transport calls it makes.
|
||||
"""
|
||||
directlyProvides(self, providedBy(transport))
|
||||
Protocol.makeConnection(self, transport)
|
||||
self.factory.registerProtocol(self)
|
||||
self.wrappedProtocol.makeConnection(self)
|
||||
|
||||
|
||||
# Transport relaying
|
||||
|
||||
def write(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
def writeSequence(self, data):
|
||||
self.transport.writeSequence(data)
|
||||
|
||||
|
||||
def loseConnection(self):
|
||||
self.disconnecting = 1
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def getPeer(self):
|
||||
return self.transport.getPeer()
|
||||
|
||||
|
||||
def getHost(self):
|
||||
return self.transport.getHost()
|
||||
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.transport.registerProducer(producer, streaming)
|
||||
|
||||
|
||||
def unregisterProducer(self):
|
||||
self.transport.unregisterProducer()
|
||||
|
||||
|
||||
def stopConsuming(self):
|
||||
self.transport.stopConsuming()
|
||||
|
||||
|
||||
def __getattr__(self, name):
|
||||
return getattr(self.transport, name)
|
||||
|
||||
|
||||
# Protocol relaying
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.wrappedProtocol.dataReceived(data)
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.factory.unregisterProtocol(self)
|
||||
self.wrappedProtocol.connectionLost(reason)
|
||||
|
||||
# Breaking reference cycle between self and wrappedProtocol.
|
||||
self.wrappedProtocol = None
|
||||
|
||||
|
||||
class WrappingFactory(ClientFactory):
|
||||
"""
|
||||
Wraps a factory and its protocols, and keeps track of them.
|
||||
"""
|
||||
|
||||
protocol = ProtocolWrapper
|
||||
|
||||
def __init__(self, wrappedFactory):
|
||||
self.wrappedFactory = wrappedFactory
|
||||
self.protocols = {}
|
||||
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Generate a log prefix mentioning both the wrapped factory and this one.
|
||||
"""
|
||||
return _wrappedLogPrefix(self, self.wrappedFactory)
|
||||
|
||||
|
||||
def doStart(self):
|
||||
self.wrappedFactory.doStart()
|
||||
ClientFactory.doStart(self)
|
||||
|
||||
|
||||
def doStop(self):
|
||||
self.wrappedFactory.doStop()
|
||||
ClientFactory.doStop(self)
|
||||
|
||||
|
||||
def startedConnecting(self, connector):
|
||||
self.wrappedFactory.startedConnecting(connector)
|
||||
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
self.wrappedFactory.clientConnectionFailed(connector, reason)
|
||||
|
||||
|
||||
def clientConnectionLost(self, connector, reason):
|
||||
self.wrappedFactory.clientConnectionLost(connector, reason)
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return self.protocol(self, self.wrappedFactory.buildProtocol(addr))
|
||||
|
||||
|
||||
def registerProtocol(self, p):
|
||||
"""
|
||||
Called by protocol to register itself.
|
||||
"""
|
||||
self.protocols[p] = 1
|
||||
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
"""
|
||||
Called by protocols when they go away.
|
||||
"""
|
||||
del self.protocols[p]
|
||||
|
||||
|
||||
|
||||
class ThrottlingProtocol(ProtocolWrapper):
|
||||
"""
|
||||
Protocol for L{ThrottlingFactory}.
|
||||
"""
|
||||
|
||||
# wrap API for tracking bandwidth
|
||||
|
||||
def write(self, data):
|
||||
self.factory.registerWritten(len(data))
|
||||
ProtocolWrapper.write(self, data)
|
||||
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.factory.registerWritten(sum(map(len, seq)))
|
||||
ProtocolWrapper.writeSequence(self, seq)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.factory.registerRead(len(data))
|
||||
ProtocolWrapper.dataReceived(self, data)
|
||||
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
self.producer = producer
|
||||
ProtocolWrapper.registerProducer(self, producer, streaming)
|
||||
|
||||
|
||||
def unregisterProducer(self):
|
||||
del self.producer
|
||||
ProtocolWrapper.unregisterProducer(self)
|
||||
|
||||
|
||||
def throttleReads(self):
|
||||
self.transport.pauseProducing()
|
||||
|
||||
|
||||
def unthrottleReads(self):
|
||||
self.transport.resumeProducing()
|
||||
|
||||
|
||||
def throttleWrites(self):
|
||||
if hasattr(self, "producer"):
|
||||
self.producer.pauseProducing()
|
||||
|
||||
|
||||
def unthrottleWrites(self):
|
||||
if hasattr(self, "producer"):
|
||||
self.producer.resumeProducing()
|
||||
|
||||
|
||||
|
||||
class ThrottlingFactory(WrappingFactory):
|
||||
"""
|
||||
Throttles bandwidth and number of connections.
|
||||
|
||||
Write bandwidth will only be throttled if there is a producer
|
||||
registered.
|
||||
"""
|
||||
|
||||
protocol = ThrottlingProtocol
|
||||
|
||||
def __init__(self, wrappedFactory, maxConnectionCount=sys.maxsize,
|
||||
readLimit=None, writeLimit=None):
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
self.connectionCount = 0
|
||||
self.maxConnectionCount = maxConnectionCount
|
||||
self.readLimit = readLimit # max bytes we should read per second
|
||||
self.writeLimit = writeLimit # max bytes we should write per second
|
||||
self.readThisSecond = 0
|
||||
self.writtenThisSecond = 0
|
||||
self.unthrottleReadsID = None
|
||||
self.checkReadBandwidthID = None
|
||||
self.unthrottleWritesID = None
|
||||
self.checkWriteBandwidthID = None
|
||||
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Wrapper around
|
||||
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
|
||||
for test purpose.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
return reactor.callLater(period, func)
|
||||
|
||||
|
||||
def registerWritten(self, length):
|
||||
"""
|
||||
Called by protocol to tell us more bytes were written.
|
||||
"""
|
||||
self.writtenThisSecond += length
|
||||
|
||||
|
||||
def registerRead(self, length):
|
||||
"""
|
||||
Called by protocol to tell us more bytes were read.
|
||||
"""
|
||||
self.readThisSecond += length
|
||||
|
||||
|
||||
def checkReadBandwidth(self):
|
||||
"""
|
||||
Checks if we've passed bandwidth limits.
|
||||
"""
|
||||
if self.readThisSecond > self.readLimit:
|
||||
self.throttleReads()
|
||||
throttleTime = (float(self.readThisSecond) / self.readLimit) - 1.0
|
||||
self.unthrottleReadsID = self.callLater(throttleTime,
|
||||
self.unthrottleReads)
|
||||
self.readThisSecond = 0
|
||||
self.checkReadBandwidthID = self.callLater(1, self.checkReadBandwidth)
|
||||
|
||||
|
||||
def checkWriteBandwidth(self):
|
||||
if self.writtenThisSecond > self.writeLimit:
|
||||
self.throttleWrites()
|
||||
throttleTime = (float(self.writtenThisSecond) / self.writeLimit) - 1.0
|
||||
self.unthrottleWritesID = self.callLater(throttleTime,
|
||||
self.unthrottleWrites)
|
||||
# reset for next round
|
||||
self.writtenThisSecond = 0
|
||||
self.checkWriteBandwidthID = self.callLater(1, self.checkWriteBandwidth)
|
||||
|
||||
|
||||
def throttleReads(self):
|
||||
"""
|
||||
Throttle reads on all protocols.
|
||||
"""
|
||||
log.msg("Throttling reads on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.throttleReads()
|
||||
|
||||
|
||||
def unthrottleReads(self):
|
||||
"""
|
||||
Stop throttling reads on all protocols.
|
||||
"""
|
||||
self.unthrottleReadsID = None
|
||||
log.msg("Stopped throttling reads on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.unthrottleReads()
|
||||
|
||||
|
||||
def throttleWrites(self):
|
||||
"""
|
||||
Throttle writes on all protocols.
|
||||
"""
|
||||
log.msg("Throttling writes on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.throttleWrites()
|
||||
|
||||
|
||||
def unthrottleWrites(self):
|
||||
"""
|
||||
Stop throttling writes on all protocols.
|
||||
"""
|
||||
self.unthrottleWritesID = None
|
||||
log.msg("Stopped throttling writes on %s" % self)
|
||||
for p in self.protocols.keys():
|
||||
p.unthrottleWrites()
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
if self.connectionCount == 0:
|
||||
if self.readLimit is not None:
|
||||
self.checkReadBandwidth()
|
||||
if self.writeLimit is not None:
|
||||
self.checkWriteBandwidth()
|
||||
|
||||
if self.connectionCount < self.maxConnectionCount:
|
||||
self.connectionCount += 1
|
||||
return WrappingFactory.buildProtocol(self, addr)
|
||||
else:
|
||||
log.msg("Max connection count reached!")
|
||||
return None
|
||||
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
WrappingFactory.unregisterProtocol(self, p)
|
||||
self.connectionCount -= 1
|
||||
if self.connectionCount == 0:
|
||||
if self.unthrottleReadsID is not None:
|
||||
self.unthrottleReadsID.cancel()
|
||||
if self.checkReadBandwidthID is not None:
|
||||
self.checkReadBandwidthID.cancel()
|
||||
if self.unthrottleWritesID is not None:
|
||||
self.unthrottleWritesID.cancel()
|
||||
if self.checkWriteBandwidthID is not None:
|
||||
self.checkWriteBandwidthID.cancel()
|
||||
|
||||
|
||||
|
||||
class SpewingProtocol(ProtocolWrapper):
|
||||
def dataReceived(self, data):
|
||||
log.msg("Received: %r" % data)
|
||||
ProtocolWrapper.dataReceived(self,data)
|
||||
|
||||
def write(self, data):
|
||||
log.msg("Sending: %r" % data)
|
||||
ProtocolWrapper.write(self,data)
|
||||
|
||||
|
||||
|
||||
class SpewingFactory(WrappingFactory):
|
||||
protocol = SpewingProtocol
|
||||
|
||||
|
||||
|
||||
class LimitConnectionsByPeer(WrappingFactory):
|
||||
|
||||
maxConnectionsPerPeer = 5
|
||||
|
||||
def startFactory(self):
|
||||
self.peerConnections = {}
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
peerHost = addr[0]
|
||||
connectionCount = self.peerConnections.get(peerHost, 0)
|
||||
if connectionCount >= self.maxConnectionsPerPeer:
|
||||
return None
|
||||
self.peerConnections[peerHost] = connectionCount + 1
|
||||
return WrappingFactory.buildProtocol(self, addr)
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
peerHost = p.getPeer()[1]
|
||||
self.peerConnections[peerHost] -= 1
|
||||
if self.peerConnections[peerHost] == 0:
|
||||
del self.peerConnections[peerHost]
|
||||
|
||||
|
||||
class LimitTotalConnectionsFactory(ServerFactory):
|
||||
"""
|
||||
Factory that limits the number of simultaneous connections.
|
||||
|
||||
@type connectionCount: C{int}
|
||||
@ivar connectionCount: number of current connections.
|
||||
@type connectionLimit: C{int} or L{None}
|
||||
@cvar connectionLimit: maximum number of connections.
|
||||
@type overflowProtocol: L{Protocol} or L{None}
|
||||
@cvar overflowProtocol: Protocol to use for new connections when
|
||||
connectionLimit is exceeded. If L{None} (the default value), excess
|
||||
connections will be closed immediately.
|
||||
"""
|
||||
connectionCount = 0
|
||||
connectionLimit = None
|
||||
overflowProtocol = None
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
if (self.connectionLimit is None or
|
||||
self.connectionCount < self.connectionLimit):
|
||||
# Build the normal protocol
|
||||
wrappedProtocol = self.protocol()
|
||||
elif self.overflowProtocol is None:
|
||||
# Just drop the connection
|
||||
return None
|
||||
else:
|
||||
# Too many connections, so build the overflow protocol
|
||||
wrappedProtocol = self.overflowProtocol()
|
||||
|
||||
wrappedProtocol.factory = self
|
||||
protocol = ProtocolWrapper(self, wrappedProtocol)
|
||||
self.connectionCount += 1
|
||||
return protocol
|
||||
|
||||
def registerProtocol(self, p):
|
||||
pass
|
||||
|
||||
def unregisterProtocol(self, p):
|
||||
self.connectionCount -= 1
|
||||
|
||||
|
||||
|
||||
class TimeoutProtocol(ProtocolWrapper):
|
||||
"""
|
||||
Protocol that automatically disconnects when the connection is idle.
|
||||
"""
|
||||
|
||||
def __init__(self, factory, wrappedProtocol, timeoutPeriod):
|
||||
"""
|
||||
Constructor.
|
||||
|
||||
@param factory: An L{TimeoutFactory}.
|
||||
@param wrappedProtocol: A L{Protocol} to wrapp.
|
||||
@param timeoutPeriod: Number of seconds to wait for activity before
|
||||
timing out.
|
||||
"""
|
||||
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
|
||||
self.timeoutCall = None
|
||||
self.timeoutPeriod = None
|
||||
self.setTimeout(timeoutPeriod)
|
||||
|
||||
|
||||
def setTimeout(self, timeoutPeriod=None):
|
||||
"""
|
||||
Set a timeout.
|
||||
|
||||
This will cancel any existing timeouts.
|
||||
|
||||
@param timeoutPeriod: If not L{None}, change the timeout period.
|
||||
Otherwise, use the existing value.
|
||||
"""
|
||||
self.cancelTimeout()
|
||||
self.timeoutPeriod = timeoutPeriod
|
||||
if timeoutPeriod is not None:
|
||||
self.timeoutCall = self.factory.callLater(self.timeoutPeriod, self.timeoutFunc)
|
||||
|
||||
|
||||
def cancelTimeout(self):
|
||||
"""
|
||||
Cancel the timeout.
|
||||
|
||||
If the timeout was already cancelled, this does nothing.
|
||||
"""
|
||||
self.timeoutPeriod = None
|
||||
if self.timeoutCall:
|
||||
try:
|
||||
self.timeoutCall.cancel()
|
||||
except (error.AlreadyCalled, error.AlreadyCancelled):
|
||||
pass
|
||||
self.timeoutCall = None
|
||||
|
||||
|
||||
def resetTimeout(self):
|
||||
"""
|
||||
Reset the timeout, usually because some activity just happened.
|
||||
"""
|
||||
if self.timeoutCall:
|
||||
self.timeoutCall.reset(self.timeoutPeriod)
|
||||
|
||||
|
||||
def write(self, data):
|
||||
self.resetTimeout()
|
||||
ProtocolWrapper.write(self, data)
|
||||
|
||||
|
||||
def writeSequence(self, seq):
|
||||
self.resetTimeout()
|
||||
ProtocolWrapper.writeSequence(self, seq)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.resetTimeout()
|
||||
ProtocolWrapper.dataReceived(self, data)
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.cancelTimeout()
|
||||
ProtocolWrapper.connectionLost(self, reason)
|
||||
|
||||
|
||||
def timeoutFunc(self):
|
||||
"""
|
||||
This method is called when the timeout is triggered.
|
||||
|
||||
By default it calls I{loseConnection}. Override this if you want
|
||||
something else to happen.
|
||||
"""
|
||||
self.loseConnection()
|
||||
|
||||
|
||||
|
||||
class TimeoutFactory(WrappingFactory):
|
||||
"""
|
||||
Factory for TimeoutWrapper.
|
||||
"""
|
||||
protocol = TimeoutProtocol
|
||||
|
||||
|
||||
def __init__(self, wrappedFactory, timeoutPeriod=30*60):
|
||||
self.timeoutPeriod = timeoutPeriod
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
|
||||
timeoutPeriod=self.timeoutPeriod)
|
||||
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Wrapper around
|
||||
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
|
||||
for test purpose.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
return reactor.callLater(period, func)
|
||||
|
||||
|
||||
|
||||
class TrafficLoggingProtocol(ProtocolWrapper):
|
||||
|
||||
def __init__(self, factory, wrappedProtocol, logfile, lengthLimit=None,
|
||||
number=0):
|
||||
"""
|
||||
@param factory: factory which created this protocol.
|
||||
@type factory: L{protocol.Factory}.
|
||||
@param wrappedProtocol: the underlying protocol.
|
||||
@type wrappedProtocol: C{protocol.Protocol}.
|
||||
@param logfile: file opened for writing used to write log messages.
|
||||
@type logfile: C{file}
|
||||
@param lengthLimit: maximum size of the datareceived logged.
|
||||
@type lengthLimit: C{int}
|
||||
@param number: identifier of the connection.
|
||||
@type number: C{int}.
|
||||
"""
|
||||
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
|
||||
self.logfile = logfile
|
||||
self.lengthLimit = lengthLimit
|
||||
self._number = number
|
||||
|
||||
|
||||
def _log(self, line):
|
||||
self.logfile.write(line + '\n')
|
||||
self.logfile.flush()
|
||||
|
||||
|
||||
def _mungeData(self, data):
|
||||
if self.lengthLimit and len(data) > self.lengthLimit:
|
||||
data = data[:self.lengthLimit - 12] + '<... elided>'
|
||||
return data
|
||||
|
||||
|
||||
# IProtocol
|
||||
def connectionMade(self):
|
||||
self._log('*')
|
||||
return ProtocolWrapper.connectionMade(self)
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self._log('C %d: %r' % (self._number, self._mungeData(data)))
|
||||
return ProtocolWrapper.dataReceived(self, data)
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self._log('C %d: %r' % (self._number, reason))
|
||||
return ProtocolWrapper.connectionLost(self, reason)
|
||||
|
||||
|
||||
# ITransport
|
||||
def write(self, data):
|
||||
self._log('S %d: %r' % (self._number, self._mungeData(data)))
|
||||
return ProtocolWrapper.write(self, data)
|
||||
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
self._log('SV %d: %r' % (self._number, [self._mungeData(d) for d in iovec]))
|
||||
return ProtocolWrapper.writeSequence(self, iovec)
|
||||
|
||||
|
||||
def loseConnection(self):
|
||||
self._log('S %d: *' % (self._number,))
|
||||
return ProtocolWrapper.loseConnection(self)
|
||||
|
||||
|
||||
|
||||
class TrafficLoggingFactory(WrappingFactory):
|
||||
protocol = TrafficLoggingProtocol
|
||||
|
||||
_counter = 0
|
||||
|
||||
def __init__(self, wrappedFactory, logfilePrefix, lengthLimit=None):
|
||||
self.logfilePrefix = logfilePrefix
|
||||
self.lengthLimit = lengthLimit
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
|
||||
|
||||
def open(self, name):
|
||||
return open(name, 'w')
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
self._counter += 1
|
||||
logfile = self.open(self.logfilePrefix + '-' + str(self._counter))
|
||||
return self.protocol(self, self.wrappedFactory.buildProtocol(addr),
|
||||
logfile, self.lengthLimit, self._counter)
|
||||
|
||||
|
||||
def resetCounter(self):
|
||||
"""
|
||||
Reset the value of the counter used to identify connections.
|
||||
"""
|
||||
self._counter = 0
|
||||
|
||||
|
||||
|
||||
class TimeoutMixin:
|
||||
"""
|
||||
Mixin for protocols which wish to timeout connections.
|
||||
|
||||
Protocols that mix this in have a single timeout, set using L{setTimeout}.
|
||||
When the timeout is hit, L{timeoutConnection} is called, which, by
|
||||
default, closes the connection.
|
||||
|
||||
@cvar timeOut: The number of seconds after which to timeout the connection.
|
||||
"""
|
||||
timeOut = None
|
||||
|
||||
__timeoutCall = None
|
||||
|
||||
def callLater(self, period, func):
|
||||
"""
|
||||
Wrapper around
|
||||
L{reactor.callLater<twisted.internet.interfaces.IReactorTime.callLater>}
|
||||
for test purpose.
|
||||
"""
|
||||
from twisted.internet import reactor
|
||||
return reactor.callLater(period, func)
|
||||
|
||||
|
||||
def resetTimeout(self):
|
||||
"""
|
||||
Reset the timeout count down.
|
||||
|
||||
If the connection has already timed out, then do nothing. If the
|
||||
timeout has been cancelled (probably using C{setTimeout(None)}), also
|
||||
do nothing.
|
||||
|
||||
It's often a good idea to call this when the protocol has received
|
||||
some meaningful input from the other end of the connection. "I've got
|
||||
some data, they're still there, reset the timeout".
|
||||
"""
|
||||
if self.__timeoutCall is not None and self.timeOut is not None:
|
||||
self.__timeoutCall.reset(self.timeOut)
|
||||
|
||||
def setTimeout(self, period):
|
||||
"""
|
||||
Change the timeout period
|
||||
|
||||
@type period: C{int} or L{None}
|
||||
@param period: The period, in seconds, to change the timeout to, or
|
||||
L{None} to disable the timeout.
|
||||
"""
|
||||
prev = self.timeOut
|
||||
self.timeOut = period
|
||||
|
||||
if self.__timeoutCall is not None:
|
||||
if period is None:
|
||||
try:
|
||||
self.__timeoutCall.cancel()
|
||||
except (error.AlreadyCancelled, error.AlreadyCalled):
|
||||
# Do nothing if the call was already consumed.
|
||||
pass
|
||||
self.__timeoutCall = None
|
||||
else:
|
||||
self.__timeoutCall.reset(period)
|
||||
elif period is not None:
|
||||
self.__timeoutCall = self.callLater(period, self.__timedOut)
|
||||
|
||||
return prev
|
||||
|
||||
def __timedOut(self):
|
||||
self.__timeoutCall = None
|
||||
self.timeoutConnection()
|
||||
|
||||
def timeoutConnection(self):
|
||||
"""
|
||||
Called when the connection times out.
|
||||
|
||||
Override to define behavior other than dropping the connection.
|
||||
"""
|
||||
self.transport.loseConnection()
|
||||
|
|
@ -0,0 +1,99 @@
|
|||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
A simple port forwarder.
|
||||
"""
|
||||
|
||||
# Twisted imports
|
||||
from twisted.internet import protocol
|
||||
from twisted.python import log
|
||||
|
||||
class Proxy(protocol.Protocol):
|
||||
noisy = True
|
||||
|
||||
peer = None
|
||||
|
||||
def setPeer(self, peer):
|
||||
self.peer = peer
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.peer is not None:
|
||||
self.peer.transport.loseConnection()
|
||||
self.peer = None
|
||||
elif self.noisy:
|
||||
log.msg("Unable to connect to peer: %s" % (reason,))
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.peer.transport.write(data)
|
||||
|
||||
|
||||
|
||||
class ProxyClient(Proxy):
|
||||
def connectionMade(self):
|
||||
self.peer.setPeer(self)
|
||||
|
||||
# Wire this and the peer transport together to enable
|
||||
# flow control (this stops connections from filling
|
||||
# this proxy memory when one side produces data at a
|
||||
# higher rate than the other can consume).
|
||||
self.transport.registerProducer(self.peer.transport, True)
|
||||
self.peer.transport.registerProducer(self.transport, True)
|
||||
|
||||
# We're connected, everybody can read to their hearts content.
|
||||
self.peer.transport.resumeProducing()
|
||||
|
||||
|
||||
|
||||
class ProxyClientFactory(protocol.ClientFactory):
|
||||
|
||||
protocol = ProxyClient
|
||||
|
||||
def setServer(self, server):
|
||||
self.server = server
|
||||
|
||||
|
||||
def buildProtocol(self, *args, **kw):
|
||||
prot = protocol.ClientFactory.buildProtocol(self, *args, **kw)
|
||||
prot.setPeer(self.server)
|
||||
return prot
|
||||
|
||||
|
||||
def clientConnectionFailed(self, connector, reason):
|
||||
self.server.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
class ProxyServer(Proxy):
|
||||
|
||||
clientProtocolFactory = ProxyClientFactory
|
||||
reactor = None
|
||||
|
||||
def connectionMade(self):
|
||||
# Don't read anything from the connecting client until we have
|
||||
# somewhere to send it to.
|
||||
self.transport.pauseProducing()
|
||||
|
||||
client = self.clientProtocolFactory()
|
||||
client.setServer(self)
|
||||
|
||||
if self.reactor is None:
|
||||
from twisted.internet import reactor
|
||||
self.reactor = reactor
|
||||
self.reactor.connectTCP(self.factory.host, self.factory.port, client)
|
||||
|
||||
|
||||
|
||||
class ProxyFactory(protocol.Factory):
|
||||
"""
|
||||
Factory for port forwarder.
|
||||
"""
|
||||
|
||||
protocol = ProxyServer
|
||||
|
||||
def __init__(self, host, port):
|
||||
self.host = host
|
||||
self.port = port
|
||||
158
venv/lib/python3.9/site-packages/twisted/protocols/postfix.py
Normal file
158
venv/lib/python3.9/site-packages/twisted/protocols/postfix.py
Normal file
|
|
@ -0,0 +1,158 @@
|
|||
# -*- test-case-name: twisted.test.test_postfix -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Postfix mail transport agent related protocols.
|
||||
"""
|
||||
|
||||
import sys
|
||||
try:
|
||||
# Python 2
|
||||
from UserDict import UserDict
|
||||
except ImportError:
|
||||
# Python 3
|
||||
from collections import UserDict
|
||||
|
||||
try:
|
||||
# Python 2
|
||||
from urllib import quote as _quote, unquote as _unquote
|
||||
except ImportError:
|
||||
# Python 3
|
||||
from urllib.parse import quote as _quote, unquote as _unquote
|
||||
|
||||
from twisted.protocols import basic
|
||||
from twisted.protocols import policies
|
||||
from twisted.internet import protocol, defer
|
||||
from twisted.python import log
|
||||
from twisted.python.compat import unicode
|
||||
|
||||
# urllib's quote functions just happen to match
|
||||
# the postfix semantics.
|
||||
def quote(s):
|
||||
quoted = _quote(s)
|
||||
if isinstance(quoted, unicode):
|
||||
quoted = quoted.encode("ascii")
|
||||
return quoted
|
||||
|
||||
|
||||
|
||||
def unquote(s):
|
||||
if isinstance(s, bytes):
|
||||
s = s.decode("ascii")
|
||||
quoted = _unquote(s)
|
||||
return quoted.encode("ascii")
|
||||
|
||||
|
||||
|
||||
class PostfixTCPMapServer(basic.LineReceiver, policies.TimeoutMixin):
|
||||
"""
|
||||
Postfix mail transport agent TCP map protocol implementation.
|
||||
|
||||
Receive requests for data matching given key via lineReceived,
|
||||
asks it's factory for the data with self.factory.get(key), and
|
||||
returns the data to the requester. None means no entry found.
|
||||
|
||||
You can use postfix's postmap to test the map service::
|
||||
|
||||
/usr/sbin/postmap -q KEY tcp:localhost:4242
|
||||
|
||||
"""
|
||||
|
||||
timeout = 600
|
||||
delimiter = b'\n'
|
||||
|
||||
def connectionMade(self):
|
||||
self.setTimeout(self.timeout)
|
||||
|
||||
|
||||
|
||||
def sendCode(self, code, message=b''):
|
||||
"""
|
||||
Send an SMTP-like code with a message.
|
||||
"""
|
||||
self.sendLine(str(code).encode("ascii") + b' ' + message)
|
||||
|
||||
|
||||
|
||||
def lineReceived(self, line):
|
||||
self.resetTimeout()
|
||||
try:
|
||||
request, params = line.split(None, 1)
|
||||
except ValueError:
|
||||
request = line
|
||||
params = None
|
||||
try:
|
||||
f = getattr(self, u'do_' + request.decode("ascii"))
|
||||
except AttributeError:
|
||||
self.sendCode(400, b'unknown command')
|
||||
else:
|
||||
try:
|
||||
f(params)
|
||||
except:
|
||||
excInfo = str(sys.exc_info()[1]).encode("ascii")
|
||||
self.sendCode(400, b'Command ' + request + b' failed: ' +
|
||||
excInfo)
|
||||
|
||||
|
||||
|
||||
def do_get(self, key):
|
||||
if key is None:
|
||||
self.sendCode(400, b"Command 'get' takes 1 parameters.")
|
||||
else:
|
||||
d = defer.maybeDeferred(self.factory.get, key)
|
||||
d.addCallbacks(self._cbGot, self._cbNot)
|
||||
d.addErrback(log.err)
|
||||
|
||||
|
||||
|
||||
def _cbNot(self, fail):
|
||||
msg = fail.getErrorMessage().encode("ascii")
|
||||
self.sendCode(400, msg)
|
||||
|
||||
|
||||
|
||||
def _cbGot(self, value):
|
||||
if value is None:
|
||||
self.sendCode(500)
|
||||
else:
|
||||
self.sendCode(200, quote(value))
|
||||
|
||||
|
||||
|
||||
def do_put(self, keyAndValue):
|
||||
if keyAndValue is None:
|
||||
self.sendCode(400, b"Command 'put' takes 2 parameters.")
|
||||
else:
|
||||
try:
|
||||
key, value = keyAndValue.split(None, 1)
|
||||
except ValueError:
|
||||
self.sendCode(400, b"Command 'put' takes 2 parameters.")
|
||||
else:
|
||||
self.sendCode(500, b'put is not implemented yet.')
|
||||
|
||||
|
||||
|
||||
class PostfixTCPMapDictServerFactory(UserDict, protocol.ServerFactory):
|
||||
"""
|
||||
An in-memory dictionary factory for PostfixTCPMapServer.
|
||||
"""
|
||||
|
||||
protocol = PostfixTCPMapServer
|
||||
|
||||
|
||||
|
||||
class PostfixTCPMapDeferringDictServerFactory(protocol.ServerFactory):
|
||||
"""
|
||||
An in-memory dictionary factory for PostfixTCPMapServer.
|
||||
"""
|
||||
|
||||
protocol = PostfixTCPMapServer
|
||||
|
||||
def __init__(self, data=None):
|
||||
self.data = {}
|
||||
if data is not None:
|
||||
self.data.update(data)
|
||||
|
||||
def get(self, key):
|
||||
return defer.succeed(self.data.get(key))
|
||||
1294
venv/lib/python3.9/site-packages/twisted/protocols/sip.py
Normal file
1294
venv/lib/python3.9/site-packages/twisted/protocols/sip.py
Normal file
File diff suppressed because it is too large
Load diff
255
venv/lib/python3.9/site-packages/twisted/protocols/socks.py
Normal file
255
venv/lib/python3.9/site-packages/twisted/protocols/socks.py
Normal file
|
|
@ -0,0 +1,255 @@
|
|||
# -*- test-case-name: twisted.test.test_socks -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of the SOCKSv4 protocol.
|
||||
"""
|
||||
|
||||
# python imports
|
||||
import struct
|
||||
import string
|
||||
import socket
|
||||
import time
|
||||
|
||||
# twisted imports
|
||||
from twisted.internet import reactor, protocol, defer
|
||||
from twisted.python import log
|
||||
|
||||
|
||||
class SOCKSv4Outgoing(protocol.Protocol):
|
||||
def __init__(self, socks):
|
||||
self.socks=socks
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
peer = self.transport.getPeer()
|
||||
self.socks.makeReply(90, 0, port=peer.port, ip=peer.host)
|
||||
self.socks.otherConn=self
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.socks.transport.loseConnection()
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.socks.write(data)
|
||||
|
||||
|
||||
def write(self,data):
|
||||
self.socks.log(self,data)
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
|
||||
class SOCKSv4Incoming(protocol.Protocol):
|
||||
def __init__(self,socks):
|
||||
self.socks=socks
|
||||
self.socks.otherConn=self
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
self.socks.transport.loseConnection()
|
||||
|
||||
|
||||
def dataReceived(self,data):
|
||||
self.socks.write(data)
|
||||
|
||||
|
||||
def write(self, data):
|
||||
self.socks.log(self,data)
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
|
||||
class SOCKSv4(protocol.Protocol):
|
||||
"""
|
||||
An implementation of the SOCKSv4 protocol.
|
||||
|
||||
@type logging: L{str} or L{None}
|
||||
@ivar logging: If not L{None}, the name of the logfile to which connection
|
||||
information will be written.
|
||||
|
||||
@type reactor: object providing L{twisted.internet.interfaces.IReactorTCP}
|
||||
@ivar reactor: The reactor used to create connections.
|
||||
|
||||
@type buf: L{str}
|
||||
@ivar buf: Part of a SOCKSv4 connection request.
|
||||
|
||||
@type otherConn: C{SOCKSv4Incoming}, C{SOCKSv4Outgoing} or L{None}
|
||||
@ivar otherConn: Until the connection has been established, C{otherConn} is
|
||||
L{None}. After that, it is the proxy-to-destination protocol instance
|
||||
along which the client's connection is being forwarded.
|
||||
"""
|
||||
def __init__(self, logging=None, reactor=reactor):
|
||||
self.logging = logging
|
||||
self.reactor = reactor
|
||||
|
||||
|
||||
def connectionMade(self):
|
||||
self.buf = b""
|
||||
self.otherConn = None
|
||||
|
||||
|
||||
def dataReceived(self, data):
|
||||
"""
|
||||
Called whenever data is received.
|
||||
|
||||
@type data: L{bytes}
|
||||
@param data: Part or all of a SOCKSv4 packet.
|
||||
"""
|
||||
if self.otherConn:
|
||||
self.otherConn.write(data)
|
||||
return
|
||||
self.buf = self.buf + data
|
||||
completeBuffer = self.buf
|
||||
if b"\000" in self.buf[8:]:
|
||||
head, self.buf = self.buf[:8], self.buf[8:]
|
||||
version, code, port = struct.unpack("!BBH", head[:4])
|
||||
user, self.buf = self.buf.split(b"\000", 1)
|
||||
if head[4:7] == b"\000\000\000" and head[7:8] != b"\000":
|
||||
# An IP address of the form 0.0.0.X, where X is non-zero,
|
||||
# signifies that this is a SOCKSv4a packet.
|
||||
# If the complete packet hasn't been received, restore the
|
||||
# buffer and wait for it.
|
||||
if b"\000" not in self.buf:
|
||||
self.buf = completeBuffer
|
||||
return
|
||||
server, self.buf = self.buf.split(b"\000", 1)
|
||||
d = self.reactor.resolve(server)
|
||||
d.addCallback(self._dataReceived2, user,
|
||||
version, code, port)
|
||||
d.addErrback(lambda result, self = self: self.makeReply(91))
|
||||
return
|
||||
else:
|
||||
server = socket.inet_ntoa(head[4:8])
|
||||
|
||||
self._dataReceived2(server, user, version, code, port)
|
||||
|
||||
|
||||
def _dataReceived2(self, server, user, version, code, port):
|
||||
"""
|
||||
The second half of the SOCKS connection setup. For a SOCKSv4 packet this
|
||||
is after the server address has been extracted from the header. For a
|
||||
SOCKSv4a packet this is after the host name has been resolved.
|
||||
|
||||
@type server: L{str}
|
||||
@param server: The IP address of the destination, represented as a
|
||||
dotted quad.
|
||||
|
||||
@type user: L{str}
|
||||
@param user: The username associated with the connection.
|
||||
|
||||
@type version: L{int}
|
||||
@param version: The SOCKS protocol version number.
|
||||
|
||||
@type code: L{int}
|
||||
@param code: The comand code. 1 means establish a TCP/IP stream
|
||||
connection, and 2 means establish a TCP/IP port binding.
|
||||
|
||||
@type port: L{int}
|
||||
@param port: The port number associated with the connection.
|
||||
"""
|
||||
assert version == 4, "Bad version code: %s" % version
|
||||
if not self.authorize(code, server, port, user):
|
||||
self.makeReply(91)
|
||||
return
|
||||
if code == 1: # CONNECT
|
||||
d = self.connectClass(server, port, SOCKSv4Outgoing, self)
|
||||
d.addErrback(lambda result, self = self: self.makeReply(91))
|
||||
elif code == 2: # BIND
|
||||
d = self.listenClass(0, SOCKSv4IncomingFactory, self, server)
|
||||
d.addCallback(lambda x,
|
||||
self = self: self.makeReply(90, 0, x[1], x[0]))
|
||||
else:
|
||||
raise RuntimeError("Bad Connect Code: %s" % (code,))
|
||||
assert self.buf == b"", "hmm, still stuff in buffer... %s" % repr(
|
||||
self.buf)
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
if self.otherConn:
|
||||
self.otherConn.transport.loseConnection()
|
||||
|
||||
|
||||
def authorize(self,code,server,port,user):
|
||||
log.msg("code %s connection to %s:%s (user %s) authorized" % (code,server,port,user))
|
||||
return 1
|
||||
|
||||
|
||||
def connectClass(self, host, port, klass, *args):
|
||||
return protocol.ClientCreator(reactor, klass, *args).connectTCP(host,port)
|
||||
|
||||
|
||||
def listenClass(self, port, klass, *args):
|
||||
serv = reactor.listenTCP(port, klass(*args))
|
||||
return defer.succeed(serv.getHost()[1:])
|
||||
|
||||
|
||||
def makeReply(self,reply,version=0,port=0,ip="0.0.0.0"):
|
||||
self.transport.write(struct.pack("!BBH",version,reply,port)+socket.inet_aton(ip))
|
||||
if reply!=90: self.transport.loseConnection()
|
||||
|
||||
|
||||
def write(self,data):
|
||||
self.log(self,data)
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
def log(self,proto,data):
|
||||
if not self.logging: return
|
||||
peer = self.transport.getPeer()
|
||||
their_peer = self.otherConn.transport.getPeer()
|
||||
f=open(self.logging,"a")
|
||||
f.write("%s\t%s:%d %s %s:%d\n"%(time.ctime(),
|
||||
peer.host,peer.port,
|
||||
((proto==self and '<') or '>'),
|
||||
their_peer.host,their_peer.port))
|
||||
while data:
|
||||
p,data=data[:16],data[16:]
|
||||
f.write(string.join(map(lambda x:'%02X'%ord(x),p),' ')+' ')
|
||||
f.write((16-len(p))*3*' ')
|
||||
for c in p:
|
||||
if len(repr(c))>3: f.write('.')
|
||||
else: f.write(c)
|
||||
f.write('\n')
|
||||
f.write('\n')
|
||||
f.close()
|
||||
|
||||
|
||||
|
||||
class SOCKSv4Factory(protocol.Factory):
|
||||
"""
|
||||
A factory for a SOCKSv4 proxy.
|
||||
|
||||
Constructor accepts one argument, a log file name.
|
||||
"""
|
||||
def __init__(self, log):
|
||||
self.logging = log
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
return SOCKSv4(self.logging, reactor)
|
||||
|
||||
|
||||
|
||||
class SOCKSv4IncomingFactory(protocol.Factory):
|
||||
"""
|
||||
A utility class for building protocols for incoming connections.
|
||||
"""
|
||||
def __init__(self, socks, ip):
|
||||
self.socks = socks
|
||||
self.ip = ip
|
||||
|
||||
|
||||
def buildProtocol(self, addr):
|
||||
if addr[0] == self.ip:
|
||||
self.ip = ""
|
||||
self.socks.makeReply(90, 0)
|
||||
return SOCKSv4Incoming(self.socks)
|
||||
elif self.ip == "":
|
||||
return None
|
||||
else:
|
||||
self.socks.makeReply(91, 0)
|
||||
self.ip = ""
|
||||
return None
|
||||
|
|
@ -0,0 +1,49 @@
|
|||
# -*- test-case-name: twisted.test.test_stateful -*-
|
||||
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
|
||||
from twisted.internet import protocol
|
||||
|
||||
from io import BytesIO
|
||||
|
||||
class StatefulProtocol(protocol.Protocol):
|
||||
"""A Protocol that stores state for you.
|
||||
|
||||
state is a pair (function, num_bytes). When num_bytes bytes of data arrives
|
||||
from the network, function is called. It is expected to return the next
|
||||
state or None to keep same state. Initial state is returned by
|
||||
getInitialState (override it).
|
||||
"""
|
||||
_sful_data = None, None, 0
|
||||
|
||||
def makeConnection(self, transport):
|
||||
protocol.Protocol.makeConnection(self, transport)
|
||||
self._sful_data = self.getInitialState(), BytesIO(), 0
|
||||
|
||||
def getInitialState(self):
|
||||
raise NotImplementedError
|
||||
|
||||
def dataReceived(self, data):
|
||||
state, buffer, offset = self._sful_data
|
||||
buffer.seek(0, 2)
|
||||
buffer.write(data)
|
||||
blen = buffer.tell() # how many bytes total is in the buffer
|
||||
buffer.seek(offset)
|
||||
while blen - offset >= state[1]:
|
||||
d = buffer.read(state[1])
|
||||
offset += state[1]
|
||||
next = state[0](d)
|
||||
if self.transport.disconnecting: # XXX: argh stupid hack borrowed right from LineReceiver
|
||||
return # dataReceived won't be called again, so who cares about consistent state
|
||||
if next:
|
||||
state = next
|
||||
if offset != 0:
|
||||
b = buffer.read()
|
||||
buffer.seek(0)
|
||||
buffer.truncate()
|
||||
buffer.write(b)
|
||||
offset = 0
|
||||
self._sful_data = state, buffer, offset
|
||||
|
||||
|
|
@ -0,0 +1,6 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Unit tests for L{twisted.protocols}.
|
||||
"""
|
||||
File diff suppressed because it is too large
Load diff
1987
venv/lib/python3.9/site-packages/twisted/protocols/test/test_tls.py
Normal file
1987
venv/lib/python3.9/site-packages/twisted/protocols/test/test_tls.py
Normal file
File diff suppressed because it is too large
Load diff
830
venv/lib/python3.9/site-packages/twisted/protocols/tls.py
Normal file
830
venv/lib/python3.9/site-packages/twisted/protocols/tls.py
Normal file
|
|
@ -0,0 +1,830 @@
|
|||
# -*- test-case-name: twisted.protocols.test.test_tls,twisted.internet.test.test_tls,twisted.test.test_sslverify -*-
|
||||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""
|
||||
Implementation of a TLS transport (L{ISSLTransport}) as an
|
||||
L{IProtocol<twisted.internet.interfaces.IProtocol>} layered on top of any
|
||||
L{ITransport<twisted.internet.interfaces.ITransport>} implementation, based on
|
||||
U{OpenSSL<http://www.openssl.org>}'s memory BIO features.
|
||||
|
||||
L{TLSMemoryBIOFactory} is a L{WrappingFactory} which wraps protocols created by
|
||||
the factory it wraps with L{TLSMemoryBIOProtocol}. L{TLSMemoryBIOProtocol}
|
||||
intercedes between the underlying transport and the wrapped protocol to
|
||||
implement SSL and TLS. Typical usage of this module looks like this::
|
||||
|
||||
from twisted.protocols.tls import TLSMemoryBIOFactory
|
||||
from twisted.internet.protocol import ServerFactory
|
||||
from twisted.internet.ssl import PrivateCertificate
|
||||
from twisted.internet import reactor
|
||||
|
||||
from someapplication import ApplicationProtocol
|
||||
|
||||
serverFactory = ServerFactory()
|
||||
serverFactory.protocol = ApplicationProtocol
|
||||
certificate = PrivateCertificate.loadPEM(certPEMData)
|
||||
contextFactory = certificate.options()
|
||||
tlsFactory = TLSMemoryBIOFactory(contextFactory, False, serverFactory)
|
||||
reactor.listenTCP(12345, tlsFactory)
|
||||
reactor.run()
|
||||
|
||||
This API offers somewhat more flexibility than
|
||||
L{twisted.internet.interfaces.IReactorSSL}; for example, a
|
||||
L{TLSMemoryBIOProtocol} instance can use another instance of
|
||||
L{TLSMemoryBIOProtocol} as its transport, yielding TLS over TLS - useful to
|
||||
implement onion routing. It can also be used to run TLS over unusual
|
||||
transports, such as UNIX sockets and stdio.
|
||||
"""
|
||||
|
||||
from __future__ import division, absolute_import
|
||||
|
||||
from OpenSSL.SSL import Error, ZeroReturnError, WantReadError
|
||||
from OpenSSL.SSL import TLSv1_METHOD, Context, Connection
|
||||
|
||||
try:
|
||||
Connection(Context(TLSv1_METHOD), None)
|
||||
except TypeError as e:
|
||||
if str(e) != "argument must be an int, or have a fileno() method.":
|
||||
raise
|
||||
raise ImportError("twisted.protocols.tls requires pyOpenSSL 0.10 or newer.")
|
||||
|
||||
from zope.interface import implementer, providedBy, directlyProvides
|
||||
|
||||
from twisted.python.compat import unicode
|
||||
from twisted.python.failure import Failure
|
||||
from twisted.internet.interfaces import (
|
||||
ISystemHandle, INegotiated, IPushProducer, ILoggingContext,
|
||||
IOpenSSLServerConnectionCreator, IOpenSSLClientConnectionCreator,
|
||||
IProtocolNegotiationFactory, IHandshakeListener
|
||||
)
|
||||
from twisted.internet.main import CONNECTION_LOST
|
||||
from twisted.internet._producer_helpers import _PullToPush
|
||||
from twisted.internet.protocol import Protocol
|
||||
from twisted.internet._sslverify import _setAcceptableProtocols
|
||||
from twisted.protocols.policies import ProtocolWrapper, WrappingFactory
|
||||
|
||||
|
||||
@implementer(IPushProducer)
|
||||
class _ProducerMembrane(object):
|
||||
"""
|
||||
Stand-in for producer registered with a L{TLSMemoryBIOProtocol} transport.
|
||||
|
||||
Ensures that producer pause/resume events from the undelying transport are
|
||||
coordinated with pause/resume events from the TLS layer.
|
||||
|
||||
@ivar _producer: The application-layer producer.
|
||||
"""
|
||||
|
||||
_producerPaused = False
|
||||
|
||||
def __init__(self, producer):
|
||||
self._producer = producer
|
||||
|
||||
|
||||
def pauseProducing(self):
|
||||
"""
|
||||
C{pauseProducing} the underlying producer, if it's not paused.
|
||||
"""
|
||||
if self._producerPaused:
|
||||
return
|
||||
self._producerPaused = True
|
||||
self._producer.pauseProducing()
|
||||
|
||||
|
||||
def resumeProducing(self):
|
||||
"""
|
||||
C{resumeProducing} the underlying producer, if it's paused.
|
||||
"""
|
||||
if not self._producerPaused:
|
||||
return
|
||||
self._producerPaused = False
|
||||
self._producer.resumeProducing()
|
||||
|
||||
|
||||
def stopProducing(self):
|
||||
"""
|
||||
C{stopProducing} the underlying producer.
|
||||
|
||||
There is only a single source for this event, so it's simply passed
|
||||
on.
|
||||
"""
|
||||
self._producer.stopProducing()
|
||||
|
||||
|
||||
@implementer(ISystemHandle, INegotiated)
|
||||
class TLSMemoryBIOProtocol(ProtocolWrapper):
|
||||
"""
|
||||
L{TLSMemoryBIOProtocol} is a protocol wrapper which uses OpenSSL via a
|
||||
memory BIO to encrypt bytes written to it before sending them on to the
|
||||
underlying transport and decrypts bytes received from the underlying
|
||||
transport before delivering them to the wrapped protocol.
|
||||
|
||||
In addition to producer events from the underlying transport, the need to
|
||||
wait for reads before a write can proceed means the L{TLSMemoryBIOProtocol}
|
||||
may also want to pause a producer. Pause/resume events are therefore
|
||||
merged using the L{_ProducerMembrane} wrapper. Non-streaming (pull)
|
||||
producers are supported by wrapping them with L{_PullToPush}.
|
||||
|
||||
@ivar _tlsConnection: The L{OpenSSL.SSL.Connection} instance which is
|
||||
encrypted and decrypting this connection.
|
||||
|
||||
@ivar _lostTLSConnection: A flag indicating whether connection loss has
|
||||
already been dealt with (C{True}) or not (C{False}). TLS disconnection
|
||||
is distinct from the underlying connection being lost.
|
||||
|
||||
@ivar _appSendBuffer: application-level (cleartext) data that is waiting to
|
||||
be transferred to the TLS buffer, but can't be because the TLS
|
||||
connection is handshaking.
|
||||
@type _appSendBuffer: L{list} of L{bytes}
|
||||
|
||||
@ivar _connectWrapped: A flag indicating whether or not to call
|
||||
C{makeConnection} on the wrapped protocol. This is for the reactor's
|
||||
L{twisted.internet.interfaces.ITLSTransport.startTLS} implementation,
|
||||
since it has a protocol which it has already called C{makeConnection}
|
||||
on, and which has no interest in a new transport. See #3821.
|
||||
|
||||
@ivar _handshakeDone: A flag indicating whether or not the handshake is
|
||||
known to have completed successfully (C{True}) or not (C{False}). This
|
||||
is used to control error reporting behavior. If the handshake has not
|
||||
completed, the underlying L{OpenSSL.SSL.Error} will be passed to the
|
||||
application's C{connectionLost} method. If it has completed, any
|
||||
unexpected L{OpenSSL.SSL.Error} will be turned into a
|
||||
L{ConnectionLost}. This is weird; however, it is simply an attempt at
|
||||
a faithful re-implementation of the behavior provided by
|
||||
L{twisted.internet.ssl}.
|
||||
|
||||
@ivar _reason: If an unexpected L{OpenSSL.SSL.Error} occurs which causes
|
||||
the connection to be lost, it is saved here. If appropriate, this may
|
||||
be used as the reason passed to the application protocol's
|
||||
C{connectionLost} method.
|
||||
|
||||
@ivar _producer: The current producer registered via C{registerProducer},
|
||||
or L{None} if no producer has been registered or a previous one was
|
||||
unregistered.
|
||||
|
||||
@ivar _aborted: C{abortConnection} has been called. No further data will
|
||||
be received to the wrapped protocol's C{dataReceived}.
|
||||
@type _aborted: L{bool}
|
||||
"""
|
||||
|
||||
_reason = None
|
||||
_handshakeDone = False
|
||||
_lostTLSConnection = False
|
||||
_producer = None
|
||||
_aborted = False
|
||||
|
||||
def __init__(self, factory, wrappedProtocol, _connectWrapped=True):
|
||||
ProtocolWrapper.__init__(self, factory, wrappedProtocol)
|
||||
self._connectWrapped = _connectWrapped
|
||||
|
||||
|
||||
def getHandle(self):
|
||||
"""
|
||||
Return the L{OpenSSL.SSL.Connection} object being used to encrypt and
|
||||
decrypt this connection.
|
||||
|
||||
This is done for the benefit of L{twisted.internet.ssl.Certificate}'s
|
||||
C{peerFromTransport} and C{hostFromTransport} methods only. A
|
||||
different system handle may be returned by future versions of this
|
||||
method.
|
||||
"""
|
||||
return self._tlsConnection
|
||||
|
||||
|
||||
def makeConnection(self, transport):
|
||||
"""
|
||||
Connect this wrapper to the given transport and initialize the
|
||||
necessary L{OpenSSL.SSL.Connection} with a memory BIO.
|
||||
"""
|
||||
self._tlsConnection = self.factory._createConnection(self)
|
||||
self._appSendBuffer = []
|
||||
|
||||
# Add interfaces provided by the transport we are wrapping:
|
||||
for interface in providedBy(transport):
|
||||
directlyProvides(self, interface)
|
||||
|
||||
# Intentionally skip ProtocolWrapper.makeConnection - it might call
|
||||
# wrappedProtocol.makeConnection, which we want to make conditional.
|
||||
Protocol.makeConnection(self, transport)
|
||||
self.factory.registerProtocol(self)
|
||||
if self._connectWrapped:
|
||||
# Now that the TLS layer is initialized, notify the application of
|
||||
# the connection.
|
||||
ProtocolWrapper.makeConnection(self, transport)
|
||||
|
||||
# Now that we ourselves have a transport (initialized by the
|
||||
# ProtocolWrapper.makeConnection call above), kick off the TLS
|
||||
# handshake.
|
||||
self._checkHandshakeStatus()
|
||||
|
||||
|
||||
def _checkHandshakeStatus(self):
|
||||
"""
|
||||
Ask OpenSSL to proceed with a handshake in progress.
|
||||
|
||||
Initially, this just sends the ClientHello; after some bytes have been
|
||||
stuffed in to the C{Connection} object by C{dataReceived}, it will then
|
||||
respond to any C{Certificate} or C{KeyExchange} messages.
|
||||
"""
|
||||
# The connection might already be aborted (eg. by a callback during
|
||||
# connection setup), so don't even bother trying to handshake in that
|
||||
# case.
|
||||
if self._aborted:
|
||||
return
|
||||
try:
|
||||
self._tlsConnection.do_handshake()
|
||||
except WantReadError:
|
||||
self._flushSendBIO()
|
||||
except Error:
|
||||
self._tlsShutdownFinished(Failure())
|
||||
else:
|
||||
self._handshakeDone = True
|
||||
if IHandshakeListener.providedBy(self.wrappedProtocol):
|
||||
self.wrappedProtocol.handshakeCompleted()
|
||||
|
||||
|
||||
def _flushSendBIO(self):
|
||||
"""
|
||||
Read any bytes out of the send BIO and write them to the underlying
|
||||
transport.
|
||||
"""
|
||||
try:
|
||||
bytes = self._tlsConnection.bio_read(2 ** 15)
|
||||
except WantReadError:
|
||||
# There may be nothing in the send BIO right now.
|
||||
pass
|
||||
else:
|
||||
self.transport.write(bytes)
|
||||
|
||||
|
||||
def _flushReceiveBIO(self):
|
||||
"""
|
||||
Try to receive any application-level bytes which are now available
|
||||
because of a previous write into the receive BIO. This will take
|
||||
care of delivering any application-level bytes which are received to
|
||||
the protocol, as well as handling of the various exceptions which
|
||||
can come from trying to get such bytes.
|
||||
"""
|
||||
# Keep trying this until an error indicates we should stop or we
|
||||
# close the connection. Looping is necessary to make sure we
|
||||
# process all of the data which was put into the receive BIO, as
|
||||
# there is no guarantee that a single recv call will do it all.
|
||||
while not self._lostTLSConnection:
|
||||
try:
|
||||
bytes = self._tlsConnection.recv(2 ** 15)
|
||||
except WantReadError:
|
||||
# The newly received bytes might not have been enough to produce
|
||||
# any application data.
|
||||
break
|
||||
except ZeroReturnError:
|
||||
# TLS has shut down and no more TLS data will be received over
|
||||
# this connection.
|
||||
self._shutdownTLS()
|
||||
# Passing in None means the user protocol's connnectionLost
|
||||
# will get called with reason from underlying transport:
|
||||
self._tlsShutdownFinished(None)
|
||||
except Error:
|
||||
# Something went pretty wrong. For example, this might be a
|
||||
# handshake failure during renegotiation (because there were no
|
||||
# shared ciphers, because a certificate failed to verify, etc).
|
||||
# TLS can no longer proceed.
|
||||
failure = Failure()
|
||||
self._tlsShutdownFinished(failure)
|
||||
else:
|
||||
if not self._aborted:
|
||||
ProtocolWrapper.dataReceived(self, bytes)
|
||||
|
||||
# The received bytes might have generated a response which needs to be
|
||||
# sent now. For example, the handshake involves several round-trip
|
||||
# exchanges without ever producing application-bytes.
|
||||
self._flushSendBIO()
|
||||
|
||||
|
||||
def dataReceived(self, bytes):
|
||||
"""
|
||||
Deliver any received bytes to the receive BIO and then read and deliver
|
||||
to the application any application-level data which becomes available
|
||||
as a result of this.
|
||||
"""
|
||||
# Let OpenSSL know some bytes were just received.
|
||||
self._tlsConnection.bio_write(bytes)
|
||||
|
||||
# If we are still waiting for the handshake to complete, try to
|
||||
# complete the handshake with the bytes we just received.
|
||||
if not self._handshakeDone:
|
||||
self._checkHandshakeStatus()
|
||||
|
||||
# If the handshake still isn't finished, then we've nothing left to
|
||||
# do.
|
||||
if not self._handshakeDone:
|
||||
return
|
||||
|
||||
# If we've any pending writes, this read may have un-blocked them, so
|
||||
# attempt to unbuffer them into the OpenSSL layer.
|
||||
if self._appSendBuffer:
|
||||
self._unbufferPendingWrites()
|
||||
|
||||
# Since the handshake is complete, the wire-level bytes we just
|
||||
# processed might turn into some application-level bytes; try to pull
|
||||
# those out.
|
||||
self._flushReceiveBIO()
|
||||
|
||||
|
||||
def _shutdownTLS(self):
|
||||
"""
|
||||
Initiate, or reply to, the shutdown handshake of the TLS layer.
|
||||
"""
|
||||
try:
|
||||
shutdownSuccess = self._tlsConnection.shutdown()
|
||||
except Error:
|
||||
# Mid-handshake, a call to shutdown() can result in a
|
||||
# WantWantReadError, or rather an SSL_ERR_WANT_READ; but pyOpenSSL
|
||||
# doesn't allow us to get at the error. See:
|
||||
# https://github.com/pyca/pyopenssl/issues/91
|
||||
shutdownSuccess = False
|
||||
self._flushSendBIO()
|
||||
if shutdownSuccess:
|
||||
# Both sides have shutdown, so we can start closing lower-level
|
||||
# transport. This will also happen if we haven't started
|
||||
# negotiation at all yet, in which case shutdown succeeds
|
||||
# immediately.
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def _tlsShutdownFinished(self, reason):
|
||||
"""
|
||||
Called when TLS connection has gone away; tell underlying transport to
|
||||
disconnect.
|
||||
|
||||
@param reason: a L{Failure} whose value is an L{Exception} if we want to
|
||||
report that failure through to the wrapped protocol's
|
||||
C{connectionLost}, or L{None} if the C{reason} that
|
||||
C{connectionLost} should receive should be coming from the
|
||||
underlying transport.
|
||||
@type reason: L{Failure} or L{None}
|
||||
"""
|
||||
if reason is not None:
|
||||
# Squash an EOF in violation of the TLS protocol into
|
||||
# ConnectionLost, so that applications which might run over
|
||||
# multiple protocols can recognize its type.
|
||||
if tuple(reason.value.args[:2]) == (-1, 'Unexpected EOF'):
|
||||
reason = Failure(CONNECTION_LOST)
|
||||
if self._reason is None:
|
||||
self._reason = reason
|
||||
self._lostTLSConnection = True
|
||||
# We may need to send a TLS alert regarding the nature of the shutdown
|
||||
# here (for example, why a handshake failed), so always flush our send
|
||||
# buffer before telling our lower-level transport to go away.
|
||||
self._flushSendBIO()
|
||||
# Using loseConnection causes the application protocol's
|
||||
# connectionLost method to be invoked non-reentrantly, which is always
|
||||
# a nice feature. However, for error cases (reason != None) we might
|
||||
# want to use abortConnection when it becomes available. The
|
||||
# loseConnection call is basically tested by test_handshakeFailure.
|
||||
# At least one side will need to do it or the test never finishes.
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def connectionLost(self, reason):
|
||||
"""
|
||||
Handle the possible repetition of calls to this method (due to either
|
||||
the underlying transport going away or due to an error at the TLS
|
||||
layer) and make sure the base implementation only gets invoked once.
|
||||
"""
|
||||
if not self._lostTLSConnection:
|
||||
# Tell the TLS connection that it's not going to get any more data
|
||||
# and give it a chance to finish reading.
|
||||
self._tlsConnection.bio_shutdown()
|
||||
self._flushReceiveBIO()
|
||||
self._lostTLSConnection = True
|
||||
reason = self._reason or reason
|
||||
self._reason = None
|
||||
self.connected = False
|
||||
ProtocolWrapper.connectionLost(self, reason)
|
||||
|
||||
# Breaking reference cycle between self._tlsConnection and self.
|
||||
self._tlsConnection = None
|
||||
|
||||
|
||||
def loseConnection(self):
|
||||
"""
|
||||
Send a TLS close alert and close the underlying connection.
|
||||
"""
|
||||
if self.disconnecting or not self.connected:
|
||||
return
|
||||
# If connection setup has not finished, OpenSSL 1.0.2f+ will not shut
|
||||
# down the connection until we write some data to the connection which
|
||||
# allows the handshake to complete. However, since no data should be
|
||||
# written after loseConnection, this means we'll be stuck forever
|
||||
# waiting for shutdown to complete. Instead, we simply abort the
|
||||
# connection without trying to shut down cleanly:
|
||||
if not self._handshakeDone and not self._appSendBuffer:
|
||||
self.abortConnection()
|
||||
self.disconnecting = True
|
||||
if not self._appSendBuffer and self._producer is None:
|
||||
self._shutdownTLS()
|
||||
|
||||
|
||||
def abortConnection(self):
|
||||
"""
|
||||
Tear down TLS state so that if the connection is aborted mid-handshake
|
||||
we don't deliver any further data from the application.
|
||||
"""
|
||||
self._aborted = True
|
||||
self.disconnecting = True
|
||||
self._shutdownTLS()
|
||||
self.transport.abortConnection()
|
||||
|
||||
|
||||
def failVerification(self, reason):
|
||||
"""
|
||||
Abort the connection during connection setup, giving a reason that
|
||||
certificate verification failed.
|
||||
|
||||
@param reason: The reason that the verification failed; reported to the
|
||||
application protocol's C{connectionLost} method.
|
||||
@type reason: L{Failure}
|
||||
"""
|
||||
self._reason = reason
|
||||
self.abortConnection()
|
||||
|
||||
|
||||
def write(self, bytes):
|
||||
"""
|
||||
Process the given application bytes and send any resulting TLS traffic
|
||||
which arrives in the send BIO.
|
||||
|
||||
If C{loseConnection} was called, subsequent calls to C{write} will
|
||||
drop the bytes on the floor.
|
||||
"""
|
||||
if isinstance(bytes, unicode):
|
||||
raise TypeError("Must write bytes to a TLS transport, not unicode.")
|
||||
# Writes after loseConnection are not supported, unless a producer has
|
||||
# been registered, in which case writes can happen until the producer
|
||||
# is unregistered:
|
||||
if self.disconnecting and self._producer is None:
|
||||
return
|
||||
self._write(bytes)
|
||||
|
||||
|
||||
def _bufferedWrite(self, octets):
|
||||
"""
|
||||
Put the given octets into L{TLSMemoryBIOProtocol._appSendBuffer}, and
|
||||
tell any listening producer that it should pause because we are now
|
||||
buffering.
|
||||
"""
|
||||
self._appSendBuffer.append(octets)
|
||||
if self._producer is not None:
|
||||
self._producer.pauseProducing()
|
||||
|
||||
|
||||
def _unbufferPendingWrites(self):
|
||||
"""
|
||||
Un-buffer all waiting writes in L{TLSMemoryBIOProtocol._appSendBuffer}.
|
||||
"""
|
||||
pendingWrites, self._appSendBuffer = self._appSendBuffer, []
|
||||
for eachWrite in pendingWrites:
|
||||
self._write(eachWrite)
|
||||
|
||||
if self._appSendBuffer:
|
||||
# If OpenSSL ran out of buffer space in the Connection on our way
|
||||
# through the loop earlier and re-buffered any of our outgoing
|
||||
# writes, then we're done; don't consider any future work.
|
||||
return
|
||||
|
||||
if self._producer is not None:
|
||||
# If we have a registered producer, let it know that we have some
|
||||
# more buffer space.
|
||||
self._producer.resumeProducing()
|
||||
return
|
||||
|
||||
if self.disconnecting:
|
||||
# Finally, if we have no further buffered data, no producer wants
|
||||
# to send us more data in the future, and the application told us
|
||||
# to end the stream, initiate a TLS shutdown.
|
||||
self._shutdownTLS()
|
||||
|
||||
|
||||
def _write(self, bytes):
|
||||
"""
|
||||
Process the given application bytes and send any resulting TLS traffic
|
||||
which arrives in the send BIO.
|
||||
|
||||
This may be called by C{dataReceived} with bytes that were buffered
|
||||
before C{loseConnection} was called, which is why this function
|
||||
doesn't check for disconnection but accepts the bytes regardless.
|
||||
"""
|
||||
if self._lostTLSConnection:
|
||||
return
|
||||
|
||||
# A TLS payload is 16kB max
|
||||
bufferSize = 2 ** 14
|
||||
|
||||
# How far into the input we've gotten so far
|
||||
alreadySent = 0
|
||||
|
||||
while alreadySent < len(bytes):
|
||||
toSend = bytes[alreadySent:alreadySent + bufferSize]
|
||||
try:
|
||||
sent = self._tlsConnection.send(toSend)
|
||||
except WantReadError:
|
||||
self._bufferedWrite(bytes[alreadySent:])
|
||||
break
|
||||
except Error:
|
||||
# Pretend TLS connection disconnected, which will trigger
|
||||
# disconnect of underlying transport. The error will be passed
|
||||
# to the application protocol's connectionLost method. The
|
||||
# other SSL implementation doesn't, but losing helpful
|
||||
# debugging information is a bad idea.
|
||||
self._tlsShutdownFinished(Failure())
|
||||
break
|
||||
else:
|
||||
# We've successfully handed off the bytes to the OpenSSL
|
||||
# Connection object.
|
||||
alreadySent += sent
|
||||
# See if OpenSSL wants to hand any bytes off to the underlying
|
||||
# transport as a result.
|
||||
self._flushSendBIO()
|
||||
|
||||
|
||||
def writeSequence(self, iovec):
|
||||
"""
|
||||
Write a sequence of application bytes by joining them into one string
|
||||
and passing them to L{write}.
|
||||
"""
|
||||
self.write(b"".join(iovec))
|
||||
|
||||
|
||||
def getPeerCertificate(self):
|
||||
return self._tlsConnection.get_peer_certificate()
|
||||
|
||||
|
||||
@property
|
||||
def negotiatedProtocol(self):
|
||||
"""
|
||||
@see: L{INegotiated.negotiatedProtocol}
|
||||
"""
|
||||
protocolName = None
|
||||
|
||||
try:
|
||||
# If ALPN is not implemented that's ok, NPN might be.
|
||||
protocolName = self._tlsConnection.get_alpn_proto_negotiated()
|
||||
except (NotImplementedError, AttributeError):
|
||||
pass
|
||||
|
||||
if protocolName not in (b'', None):
|
||||
# A protocol was selected using ALPN.
|
||||
return protocolName
|
||||
|
||||
try:
|
||||
protocolName = self._tlsConnection.get_next_proto_negotiated()
|
||||
except (NotImplementedError, AttributeError):
|
||||
pass
|
||||
|
||||
if protocolName != b'':
|
||||
return protocolName
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def registerProducer(self, producer, streaming):
|
||||
# If we've already disconnected, nothing to do here:
|
||||
if self._lostTLSConnection:
|
||||
producer.stopProducing()
|
||||
return
|
||||
|
||||
# If we received a non-streaming producer, wrap it so it becomes a
|
||||
# streaming producer:
|
||||
if not streaming:
|
||||
producer = streamingProducer = _PullToPush(producer, self)
|
||||
producer = _ProducerMembrane(producer)
|
||||
# This will raise an exception if a producer is already registered:
|
||||
self.transport.registerProducer(producer, True)
|
||||
self._producer = producer
|
||||
# If we received a non-streaming producer, we need to start the
|
||||
# streaming wrapper:
|
||||
if not streaming:
|
||||
streamingProducer.startStreaming()
|
||||
|
||||
|
||||
def unregisterProducer(self):
|
||||
# If we have no producer, we don't need to do anything here.
|
||||
if self._producer is None:
|
||||
return
|
||||
|
||||
# If we received a non-streaming producer, we need to stop the
|
||||
# streaming wrapper:
|
||||
if isinstance(self._producer._producer, _PullToPush):
|
||||
self._producer._producer.stopStreaming()
|
||||
self._producer = None
|
||||
self._producerPaused = False
|
||||
self.transport.unregisterProducer()
|
||||
if self.disconnecting and not self._appSendBuffer:
|
||||
self._shutdownTLS()
|
||||
|
||||
|
||||
|
||||
@implementer(IOpenSSLClientConnectionCreator, IOpenSSLServerConnectionCreator)
|
||||
class _ContextFactoryToConnectionFactory(object):
|
||||
"""
|
||||
Adapter wrapping a L{twisted.internet.interfaces.IOpenSSLContextFactory}
|
||||
into a L{IOpenSSLClientConnectionCreator} or
|
||||
L{IOpenSSLServerConnectionCreator}.
|
||||
|
||||
See U{https://twistedmatrix.com/trac/ticket/7215} for work that should make
|
||||
this unnecessary.
|
||||
"""
|
||||
|
||||
def __init__(self, oldStyleContextFactory):
|
||||
"""
|
||||
Construct a L{_ContextFactoryToConnectionFactory} with a
|
||||
L{twisted.internet.interfaces.IOpenSSLContextFactory}.
|
||||
|
||||
Immediately call C{getContext} on C{oldStyleContextFactory} in order to
|
||||
force advance parameter checking, since old-style context factories
|
||||
don't actually check that their arguments to L{OpenSSL} are correct.
|
||||
|
||||
@param oldStyleContextFactory: A factory that can produce contexts.
|
||||
@type oldStyleContextFactory:
|
||||
L{twisted.internet.interfaces.IOpenSSLContextFactory}
|
||||
"""
|
||||
oldStyleContextFactory.getContext()
|
||||
self._oldStyleContextFactory = oldStyleContextFactory
|
||||
|
||||
|
||||
def _connectionForTLS(self, protocol):
|
||||
"""
|
||||
Create an L{OpenSSL.SSL.Connection} object.
|
||||
|
||||
@param protocol: The protocol initiating a TLS connection.
|
||||
@type protocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: a connection
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
context = self._oldStyleContextFactory.getContext()
|
||||
return Connection(context, None)
|
||||
|
||||
|
||||
def serverConnectionForTLS(self, protocol):
|
||||
"""
|
||||
Construct an OpenSSL server connection from the wrapped old-style
|
||||
context factory.
|
||||
|
||||
@note: Since old-style context factories don't distinguish between
|
||||
clients and servers, this is exactly the same as
|
||||
L{_ContextFactoryToConnectionFactory.clientConnectionForTLS}.
|
||||
|
||||
@param protocol: The protocol initiating a TLS connection.
|
||||
@type protocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: a connection
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
return self._connectionForTLS(protocol)
|
||||
|
||||
|
||||
def clientConnectionForTLS(self, protocol):
|
||||
"""
|
||||
Construct an OpenSSL server connection from the wrapped old-style
|
||||
context factory.
|
||||
|
||||
@note: Since old-style context factories don't distinguish between
|
||||
clients and servers, this is exactly the same as
|
||||
L{_ContextFactoryToConnectionFactory.serverConnectionForTLS}.
|
||||
|
||||
@param protocol: The protocol initiating a TLS connection.
|
||||
@type protocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: a connection
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
return self._connectionForTLS(protocol)
|
||||
|
||||
|
||||
|
||||
class TLSMemoryBIOFactory(WrappingFactory):
|
||||
"""
|
||||
L{TLSMemoryBIOFactory} adds TLS to connections.
|
||||
|
||||
@ivar _creatorInterface: the interface which L{_connectionCreator} is
|
||||
expected to implement.
|
||||
@type _creatorInterface: L{zope.interface.interfaces.IInterface}
|
||||
|
||||
@ivar _connectionCreator: a callable which creates an OpenSSL Connection
|
||||
object.
|
||||
@type _connectionCreator: 1-argument callable taking
|
||||
L{TLSMemoryBIOProtocol} and returning L{OpenSSL.SSL.Connection}.
|
||||
"""
|
||||
protocol = TLSMemoryBIOProtocol
|
||||
|
||||
noisy = False # disable unnecessary logging.
|
||||
|
||||
def __init__(self, contextFactory, isClient, wrappedFactory):
|
||||
"""
|
||||
Create a L{TLSMemoryBIOFactory}.
|
||||
|
||||
@param contextFactory: Configuration parameters used to create an
|
||||
OpenSSL connection. In order of preference, what you should pass
|
||||
here should be:
|
||||
|
||||
1. L{twisted.internet.ssl.CertificateOptions} (if you're
|
||||
writing a server) or the result of
|
||||
L{twisted.internet.ssl.optionsForClientTLS} (if you're
|
||||
writing a client). If you want security you should really
|
||||
use one of these.
|
||||
|
||||
2. If you really want to implement something yourself, supply a
|
||||
provider of L{IOpenSSLClientConnectionCreator} or
|
||||
L{IOpenSSLServerConnectionCreator}.
|
||||
|
||||
3. If you really have to, supply a
|
||||
L{twisted.internet.ssl.ContextFactory}. This will likely be
|
||||
deprecated at some point so please upgrade to the new
|
||||
interfaces.
|
||||
|
||||
@type contextFactory: L{IOpenSSLClientConnectionCreator} or
|
||||
L{IOpenSSLServerConnectionCreator}, or, for compatibility with
|
||||
older code, anything implementing
|
||||
L{twisted.internet.interfaces.IOpenSSLContextFactory}. See
|
||||
U{https://twistedmatrix.com/trac/ticket/7215} for information on
|
||||
the upcoming deprecation of passing a
|
||||
L{twisted.internet.ssl.ContextFactory} here.
|
||||
|
||||
@param isClient: Is this a factory for TLS client connections; in other
|
||||
words, those that will send a C{ClientHello} greeting? L{True} if
|
||||
so, L{False} otherwise. This flag determines what interface is
|
||||
expected of C{contextFactory}. If L{True}, C{contextFactory}
|
||||
should provide L{IOpenSSLClientConnectionCreator}; otherwise it
|
||||
should provide L{IOpenSSLServerConnectionCreator}.
|
||||
@type isClient: L{bool}
|
||||
|
||||
@param wrappedFactory: A factory which will create the
|
||||
application-level protocol.
|
||||
@type wrappedFactory: L{twisted.internet.interfaces.IProtocolFactory}
|
||||
"""
|
||||
WrappingFactory.__init__(self, wrappedFactory)
|
||||
if isClient:
|
||||
creatorInterface = IOpenSSLClientConnectionCreator
|
||||
else:
|
||||
creatorInterface = IOpenSSLServerConnectionCreator
|
||||
self._creatorInterface = creatorInterface
|
||||
if not creatorInterface.providedBy(contextFactory):
|
||||
contextFactory = _ContextFactoryToConnectionFactory(contextFactory)
|
||||
self._connectionCreator = contextFactory
|
||||
|
||||
|
||||
def logPrefix(self):
|
||||
"""
|
||||
Annotate the wrapped factory's log prefix with some text indicating TLS
|
||||
is in use.
|
||||
|
||||
@rtype: C{str}
|
||||
"""
|
||||
if ILoggingContext.providedBy(self.wrappedFactory):
|
||||
logPrefix = self.wrappedFactory.logPrefix()
|
||||
else:
|
||||
logPrefix = self.wrappedFactory.__class__.__name__
|
||||
return "%s (TLS)" % (logPrefix,)
|
||||
|
||||
|
||||
def _applyProtocolNegotiation(self, connection):
|
||||
"""
|
||||
Applies ALPN/NPN protocol neogitation to the connection, if the factory
|
||||
supports it.
|
||||
|
||||
@param connection: The OpenSSL connection object to have ALPN/NPN added
|
||||
to it.
|
||||
@type connection: L{OpenSSL.SSL.Connection}
|
||||
|
||||
@return: Nothing
|
||||
@rtype: L{None}
|
||||
"""
|
||||
if IProtocolNegotiationFactory.providedBy(self.wrappedFactory):
|
||||
protocols = self.wrappedFactory.acceptableProtocols()
|
||||
context = connection.get_context()
|
||||
_setAcceptableProtocols(context, protocols)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def _createConnection(self, tlsProtocol):
|
||||
"""
|
||||
Create an OpenSSL connection and set it up good.
|
||||
|
||||
@param tlsProtocol: The protocol which is establishing the connection.
|
||||
@type tlsProtocol: L{TLSMemoryBIOProtocol}
|
||||
|
||||
@return: an OpenSSL connection object for C{tlsProtocol} to use
|
||||
@rtype: L{OpenSSL.SSL.Connection}
|
||||
"""
|
||||
connectionCreator = self._connectionCreator
|
||||
if self._creatorInterface is IOpenSSLClientConnectionCreator:
|
||||
connection = connectionCreator.clientConnectionForTLS(tlsProtocol)
|
||||
self._applyProtocolNegotiation(connection)
|
||||
connection.set_connect_state()
|
||||
else:
|
||||
connection = connectionCreator.serverConnectionForTLS(tlsProtocol)
|
||||
self._applyProtocolNegotiation(connection)
|
||||
connection.set_accept_state()
|
||||
return connection
|
||||
124
venv/lib/python3.9/site-packages/twisted/protocols/wire.py
Normal file
124
venv/lib/python3.9/site-packages/twisted/protocols/wire.py
Normal file
|
|
@ -0,0 +1,124 @@
|
|||
# Copyright (c) Twisted Matrix Laboratories.
|
||||
# See LICENSE for details.
|
||||
|
||||
"""Implement standard (and unused) TCP protocols.
|
||||
|
||||
These protocols are either provided by inetd, or are not provided at all.
|
||||
"""
|
||||
|
||||
from __future__ import absolute_import, division
|
||||
|
||||
import time
|
||||
import struct
|
||||
|
||||
from zope.interface import implementer
|
||||
|
||||
from twisted.internet import protocol, interfaces
|
||||
|
||||
|
||||
|
||||
class Echo(protocol.Protocol):
|
||||
"""
|
||||
As soon as any data is received, write it back (RFC 862).
|
||||
"""
|
||||
|
||||
def dataReceived(self, data):
|
||||
self.transport.write(data)
|
||||
|
||||
|
||||
|
||||
class Discard(protocol.Protocol):
|
||||
"""
|
||||
Discard any received data (RFC 863).
|
||||
"""
|
||||
|
||||
def dataReceived(self, data):
|
||||
# I'm ignoring you, nyah-nyah
|
||||
pass
|
||||
|
||||
|
||||
|
||||
@implementer(interfaces.IProducer)
|
||||
class Chargen(protocol.Protocol):
|
||||
"""
|
||||
Generate repeating noise (RFC 864).
|
||||
"""
|
||||
noise = b'@ABCDEFGHIJKLMNOPQRSTUVWXYZ[\]^_`abcdefghijklmnopqrstuvwxyz{|}~ !"#$%&?'
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.registerProducer(self, 0)
|
||||
|
||||
|
||||
def resumeProducing(self):
|
||||
self.transport.write(self.noise)
|
||||
|
||||
|
||||
def pauseProducing(self):
|
||||
pass
|
||||
|
||||
|
||||
def stopProducing(self):
|
||||
pass
|
||||
|
||||
|
||||
|
||||
class QOTD(protocol.Protocol):
|
||||
"""
|
||||
Return a quote of the day (RFC 865).
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(self.getQuote())
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def getQuote(self):
|
||||
"""
|
||||
Return a quote. May be overrriden in subclasses.
|
||||
"""
|
||||
return b"An apple a day keeps the doctor away.\r\n"
|
||||
|
||||
|
||||
|
||||
class Who(protocol.Protocol):
|
||||
"""
|
||||
Return list of active users (RFC 866)
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(self.getUsers())
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
def getUsers(self):
|
||||
"""
|
||||
Return active users. Override in subclasses.
|
||||
"""
|
||||
return b"root\r\n"
|
||||
|
||||
|
||||
|
||||
class Daytime(protocol.Protocol):
|
||||
"""
|
||||
Send back the daytime in ASCII form (RFC 867).
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
self.transport.write(time.asctime(time.gmtime(time.time())) + b'\r\n')
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
|
||||
class Time(protocol.Protocol):
|
||||
"""
|
||||
Send back the time in machine readable form (RFC 868).
|
||||
"""
|
||||
|
||||
def connectionMade(self):
|
||||
# is this correct only for 32-bit machines?
|
||||
result = struct.pack("!i", int(time.time()))
|
||||
self.transport.write(result)
|
||||
self.transport.loseConnection()
|
||||
|
||||
|
||||
__all__ = ["Echo", "Discard", "Chargen", "QOTD", "Who", "Daytime", "Time"]
|
||||
Loading…
Add table
Add a link
Reference in a new issue