Ausgabe der neuen DB Einträge

This commit is contained in:
hubobel 2022-01-02 21:50:48 +01:00
parent bad48e1627
commit cfbbb9ee3d
2399 changed files with 843193 additions and 43 deletions

View file

@ -0,0 +1,8 @@
# -*- test-case-name: twisted.words.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Words: Client and server implementations for IRC, XMPP, and other chat
services.
"""

View file

@ -0,0 +1,34 @@
# -*- test-case-name: twisted.words.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""Exception definitions for Words
"""
class WordsError(Exception):
def __str__(self):
return self.__class__.__name__ + ': ' + Exception.__str__(self)
class NoSuchUser(WordsError):
pass
class DuplicateUser(WordsError):
pass
class NoSuchGroup(WordsError):
pass
class DuplicateGroup(WordsError):
pass
class AlreadyLoggedIn(WordsError):
pass
__all__ = [
'WordsError', 'NoSuchUser', 'DuplicateUser',
'NoSuchGroup', 'DuplicateGroup', 'AlreadyLoggedIn',
]

View file

@ -0,0 +1,8 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Instance Messenger, Pan-protocol chat client.
"""

View file

@ -0,0 +1,62 @@
# -*- Python -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
class AccountManager:
"""I am responsible for managing a user's accounts.
That is, remembering what accounts are available, their settings,
adding and removal of accounts, etc.
@ivar accounts: A collection of available accounts.
@type accounts: mapping of strings to L{Account<interfaces.IAccount>}s.
"""
def __init__(self):
self.accounts = {}
def getSnapShot(self):
"""A snapshot of all the accounts and their status.
@returns: A list of tuples, each of the form
(string:accountName, boolean:isOnline,
boolean:autoLogin, string:gatewayType)
"""
data = []
for account in self.accounts.values():
data.append((account.accountName, account.isOnline(),
account.autoLogin, account.gatewayType))
return data
def isEmpty(self):
return len(self.accounts) == 0
def getConnectionInfo(self):
connectioninfo = []
for account in self.accounts.values():
connectioninfo.append(account.isOnline())
return connectioninfo
def addAccount(self, account):
self.accounts[account.accountName] = account
def delAccount(self, accountName):
del self.accounts[accountName]
def connect(self, accountName, chatui):
"""
@returntype: Deferred L{interfaces.IClient}
"""
return self.accounts[accountName].logOn(chatui)
def disconnect(self, accountName):
pass
#self.accounts[accountName].logOff() - not yet implemented
def quit(self):
pass
#for account in self.accounts.values():
# account.logOff() - not yet implemented

View file

@ -0,0 +1,512 @@
# -*- test-case-name: twisted.words.test.test_basechat -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Base classes for Instance Messenger clients.
"""
from twisted.words.im.locals import OFFLINE, ONLINE, AWAY
class ContactsList:
"""
A GUI object that displays a contacts list.
@ivar chatui: The GUI chat client associated with this contacts list.
@type chatui: L{ChatUI}
@ivar contacts: The contacts.
@type contacts: C{dict} mapping C{str} to a L{IPerson<interfaces.IPerson>}
provider
@ivar onlineContacts: The contacts who are currently online (have a status
that is not C{OFFLINE}).
@type onlineContacts: C{dict} mapping C{str} to a
L{IPerson<interfaces.IPerson>} provider
@ivar clients: The signed-on clients.
@type clients: C{list} of L{IClient<interfaces.IClient>} providers
"""
def __init__(self, chatui):
"""
@param chatui: The GUI chat client associated with this contacts list.
@type chatui: L{ChatUI}
"""
self.chatui = chatui
self.contacts = {}
self.onlineContacts = {}
self.clients = []
def setContactStatus(self, person):
"""
Inform the user that a person's status has changed.
@param person: The person whose status has changed.
@type person: L{IPerson<interfaces.IPerson>} provider
"""
if person.name not in self.contacts:
self.contacts[person.name] = person
if person.name not in self.onlineContacts and \
(person.status == ONLINE or person.status == AWAY):
self.onlineContacts[person.name] = person
if person.name in self.onlineContacts and \
person.status == OFFLINE:
del self.onlineContacts[person.name]
def registerAccountClient(self, client):
"""
Notify the user that an account client has been signed on to.
@param client: The client being added to your list of account clients.
@type client: L{IClient<interfaces.IClient>} provider
"""
if not client in self.clients:
self.clients.append(client)
def unregisterAccountClient(self, client):
"""
Notify the user that an account client has been signed off or
disconnected from.
@param client: The client being removed from the list of account
clients.
@type client: L{IClient<interfaces.IClient>} provider
"""
if client in self.clients:
self.clients.remove(client)
def contactChangedNick(self, person, newnick):
"""
Update your contact information to reflect a change to a contact's
nickname.
@param person: The person in your contacts list whose nickname is
changing.
@type person: L{IPerson<interfaces.IPerson>} provider
@param newnick: The new nickname for this person.
@type newnick: C{str}
"""
oldname = person.name
if oldname in self.contacts:
del self.contacts[oldname]
person.name = newnick
self.contacts[newnick] = person
if oldname in self.onlineContacts:
del self.onlineContacts[oldname]
self.onlineContacts[newnick] = person
class Conversation:
"""
A GUI window of a conversation with a specific person.
@ivar person: The person who you're having this conversation with.
@type person: L{IPerson<interfaces.IPerson>} provider
@ivar chatui: The GUI chat client associated with this conversation.
@type chatui: L{ChatUI}
"""
def __init__(self, person, chatui):
"""
@param person: The person who you're having this conversation with.
@type person: L{IPerson<interfaces.IPerson>} provider
@param chatui: The GUI chat client associated with this conversation.
@type chatui: L{ChatUI}
"""
self.chatui = chatui
self.person = person
def show(self):
"""
Display the ConversationWindow.
"""
raise NotImplementedError("Subclasses must implement this method")
def hide(self):
"""
Hide the ConversationWindow.
"""
raise NotImplementedError("Subclasses must implement this method")
def sendText(self, text):
"""
Send text to the person with whom the user is conversing.
@param text: The text to be sent.
@type text: C{str}
"""
self.person.sendMessage(text, None)
def showMessage(self, text, metadata=None):
"""
Display a message sent from the person with whom the user is conversing.
@param text: The sent message.
@type text: C{str}
@param metadata: Metadata associated with this message.
@type metadata: C{dict}
"""
raise NotImplementedError("Subclasses must implement this method")
def contactChangedNick(self, person, newnick):
"""
Change a person's name.
@param person: The person whose nickname is changing.
@type person: L{IPerson<interfaces.IPerson>} provider
@param newnick: The new nickname for this person.
@type newnick: C{str}
"""
self.person.name = newnick
class GroupConversation:
"""
A GUI window of a conversation with a group of people.
@ivar chatui: The GUI chat client associated with this conversation.
@type chatui: L{ChatUI}
@ivar group: The group of people that are having this conversation.
@type group: L{IGroup<interfaces.IGroup>} provider
@ivar members: The names of the people in this conversation.
@type members: C{list} of C{str}
"""
def __init__(self, group, chatui):
"""
@param chatui: The GUI chat client associated with this conversation.
@type chatui: L{ChatUI}
@param group: The group of people that are having this conversation.
@type group: L{IGroup<interfaces.IGroup>} provider
"""
self.chatui = chatui
self.group = group
self.members = []
def show(self):
"""
Display the GroupConversationWindow.
"""
raise NotImplementedError("Subclasses must implement this method")
def hide(self):
"""
Hide the GroupConversationWindow.
"""
raise NotImplementedError("Subclasses must implement this method")
def sendText(self, text):
"""
Send text to the group.
@param: The text to be sent.
@type text: C{str}
"""
self.group.sendGroupMessage(text, None)
def showGroupMessage(self, sender, text, metadata=None):
"""
Display to the user a message sent to this group from the given sender.
@param sender: The person sending the message.
@type sender: C{str}
@param text: The sent message.
@type text: C{str}
@param metadata: Metadata associated with this message.
@type metadata: C{dict}
"""
raise NotImplementedError("Subclasses must implement this method")
def setGroupMembers(self, members):
"""
Set the list of members in the group.
@param members: The names of the people that will be in this group.
@type members: C{list} of C{str}
"""
self.members = members
def setTopic(self, topic, author):
"""
Change the topic for the group conversation window and display this
change to the user.
@param topic: This group's topic.
@type topic: C{str}
@param author: The person changing the topic.
@type author: C{str}
"""
raise NotImplementedError("Subclasses must implement this method")
def memberJoined(self, member):
"""
Add the given member to the list of members in the group conversation
and displays this to the user.
@param member: The person joining the group conversation.
@type member: C{str}
"""
if not member in self.members:
self.members.append(member)
def memberChangedNick(self, oldnick, newnick):
"""
Change the nickname for a member of the group conversation and displays
this change to the user.
@param oldnick: The old nickname.
@type oldnick: C{str}
@param newnick: The new nickname.
@type newnick: C{str}
"""
if oldnick in self.members:
self.members.remove(oldnick)
self.members.append(newnick)
def memberLeft(self, member):
"""
Delete the given member from the list of members in the group
conversation and displays the change to the user.
@param member: The person leaving the group conversation.
@type member: C{str}
"""
if member in self.members:
self.members.remove(member)
class ChatUI:
"""
A GUI chat client.
@type conversations: C{dict} of L{Conversation}
@ivar conversations: A cache of all the direct windows.
@type groupConversations: C{dict} of L{GroupConversation}
@ivar groupConversations: A cache of all the group windows.
@type persons: C{dict} with keys that are a C{tuple} of (C{str},
L{IAccount<interfaces.IAccount>} provider) and values that are
L{IPerson<interfaces.IPerson>} provider
@ivar persons: A cache of all the users associated with this client.
@type groups: C{dict} with keys that are a C{tuple} of (C{str},
L{IAccount<interfaces.IAccount>} provider) and values that are
L{IGroup<interfaces.IGroup>} provider
@ivar groups: A cache of all the groups associated with this client.
@type onlineClients: C{list} of L{IClient<interfaces.IClient>} providers
@ivar onlineClients: A list of message sources currently online.
@type contactsList: L{ContactsList}
@ivar contactsList: A contacts list.
"""
def __init__(self):
self.conversations = {}
self.groupConversations = {}
self.persons = {}
self.groups = {}
self.onlineClients = []
self.contactsList = ContactsList(self)
def registerAccountClient(self, client):
"""
Notify the user that an account has been signed on to.
@type client: L{IClient<interfaces.IClient>} provider
@param client: The client account for the person who has just signed on.
@rtype client: L{IClient<interfaces.IClient>} provider
@return: The client, so that it may be used in a callback chain.
"""
self.onlineClients.append(client)
self.contactsList.registerAccountClient(client)
return client
def unregisterAccountClient(self, client):
"""
Notify the user that an account has been signed off or disconnected.
@type client: L{IClient<interfaces.IClient>} provider
@param client: The client account for the person who has just signed
off.
"""
self.onlineClients.remove(client)
self.contactsList.unregisterAccountClient(client)
def getContactsList(self):
"""
Get the contacts list associated with this chat window.
@rtype: L{ContactsList}
@return: The contacts list associated with this chat window.
"""
return self.contactsList
def getConversation(self, person, Class=Conversation, stayHidden=False):
"""
For the given person object, return the conversation window or create
and return a new conversation window if one does not exist.
@type person: L{IPerson<interfaces.IPerson>} provider
@param person: The person whose conversation window we want to get.
@type Class: L{IConversation<interfaces.IConversation>} implementor
@param: The kind of conversation window we want. If the conversation
window for this person didn't already exist, create one of this type.
@type stayHidden: C{bool}
@param stayHidden: Whether or not the conversation window should stay
hidden.
@rtype: L{IConversation<interfaces.IConversation>} provider
@return: The conversation window.
"""
conv = self.conversations.get(person)
if not conv:
conv = Class(person, self)
self.conversations[person] = conv
if stayHidden:
conv.hide()
else:
conv.show()
return conv
def getGroupConversation(self, group, Class=GroupConversation,
stayHidden=False):
"""
For the given group object, return the group conversation window or
create and return a new group conversation window if it doesn't exist.
@type group: L{IGroup<interfaces.IGroup>} provider
@param group: The group whose conversation window we want to get.
@type Class: L{IConversation<interfaces.IConversation>} implementor
@param: The kind of conversation window we want. If the conversation
window for this person didn't already exist, create one of this type.
@type stayHidden: C{bool}
@param stayHidden: Whether or not the conversation window should stay
hidden.
@rtype: L{IGroupConversation<interfaces.IGroupConversation>} provider
@return: The group conversation window.
"""
conv = self.groupConversations.get(group)
if not conv:
conv = Class(group, self)
self.groupConversations[group] = conv
if stayHidden:
conv.hide()
else:
conv.show()
return conv
def getPerson(self, name, client):
"""
For the given name and account client, return an instance of a
L{IGroup<interfaces.IPerson>} provider or create and return a new
instance of a L{IGroup<interfaces.IPerson>} provider.
@type name: C{str}
@param name: The name of the person of interest.
@type client: L{IClient<interfaces.IClient>} provider
@param client: The client account of interest.
@rtype: L{IPerson<interfaces.IPerson>} provider
@return: The person with that C{name}.
"""
account = client.account
p = self.persons.get((name, account))
if not p:
p = account.getPerson(name)
self.persons[name, account] = p
return p
def getGroup(self, name, client):
"""
For the given name and account client, return an instance of a
L{IGroup<interfaces.IGroup>} provider or create and return a new instance
of a L{IGroup<interfaces.IGroup>} provider.
@type name: C{str}
@param name: The name of the group of interest.
@type client: L{IClient<interfaces.IClient>} provider
@param client: The client account of interest.
@rtype: L{IGroup<interfaces.IGroup>} provider
@return: The group with that C{name}.
"""
# I accept 'client' instead of 'account' in my signature for
# backwards compatibility. (Groups changed to be Account-oriented
# in CVS revision 1.8.)
account = client.account
g = self.groups.get((name, account))
if not g:
g = account.getGroup(name)
self.groups[name, account] = g
return g
def contactChangedNick(self, person, newnick):
"""
For the given C{person}, change the C{person}'s C{name} to C{newnick}
and tell the contact list and any conversation windows with that
C{person} to change as well.
@type person: L{IPerson<interfaces.IPerson>} provider
@param person: The person whose nickname will get changed.
@type newnick: C{str}
@param newnick: The new C{name} C{person} will take.
"""
oldnick = person.name
if (oldnick, person.account) in self.persons:
conv = self.conversations.get(person)
if conv:
conv.contactChangedNick(person, newnick)
self.contactsList.contactChangedNick(person, newnick)
del self.persons[oldnick, person.account]
person.name = newnick
self.persons[person.name, person.account] = person

View file

@ -0,0 +1,269 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""Instance Messenger base classes for protocol support.
You will find these useful if you're adding a new protocol to IM.
"""
# Abstract representation of chat "model" classes
from twisted.words.im.locals import OFFLINE, OfflineError
from twisted.internet.protocol import Protocol
from twisted.python.reflect import prefixedMethods
from twisted.persisted import styles
from twisted.internet import error
class AbstractGroup:
def __init__(self, name, account):
self.name = name
self.account = account
def getGroupCommands(self):
"""finds group commands
these commands are methods on me that start with imgroup_; they are
called with no arguments
"""
return prefixedMethods(self, "imgroup_")
def getTargetCommands(self, target):
"""finds group commands
these commands are methods on me that start with imgroup_; they are
called with a user present within this room as an argument
you may want to override this in your group in order to filter for
appropriate commands on the given user
"""
return prefixedMethods(self, "imtarget_")
def join(self):
if not self.account.client:
raise OfflineError
self.account.client.joinGroup(self.name)
def leave(self):
if not self.account.client:
raise OfflineError
self.account.client.leaveGroup(self.name)
def __repr__(self):
return '<%s %r>' % (self.__class__, self.name)
def __str__(self):
return '%s@%s' % (self.name, self.account.accountName)
class AbstractPerson:
def __init__(self, name, baseAccount):
self.name = name
self.account = baseAccount
self.status = OFFLINE
def getPersonCommands(self):
"""finds person commands
these commands are methods on me that start with imperson_; they are
called with no arguments
"""
return prefixedMethods(self, "imperson_")
def getIdleTime(self):
"""
Returns a string.
"""
return '--'
def __repr__(self):
return '<%s %r/%s>' % (self.__class__, self.name, self.status)
def __str__(self):
return '%s@%s' % (self.name, self.account.accountName)
class AbstractClientMixin:
"""Designed to be mixed in to a Protocol implementing class.
Inherit from me first.
@ivar _logonDeferred: Fired when I am done logging in.
"""
def __init__(self, account, chatui, logonDeferred):
for base in self.__class__.__bases__:
if issubclass(base, Protocol):
self.__class__._protoBase = base
break
else:
pass
self.account = account
self.chat = chatui
self._logonDeferred = logonDeferred
def connectionMade(self):
self._protoBase.connectionMade(self)
def connectionLost(self, reason):
self.account._clientLost(self, reason)
self.unregisterAsAccountClient()
return self._protoBase.connectionLost(self, reason)
def unregisterAsAccountClient(self):
"""Tell the chat UI that I have `signed off'.
"""
self.chat.unregisterAccountClient(self)
class AbstractAccount(styles.Versioned):
"""Base class for Accounts.
I am the start of an implementation of L{IAccount<interfaces.IAccount>}, I
implement L{isOnline} and most of L{logOn}, though you'll need to implement
L{_startLogOn} in a subclass.
@cvar _groupFactory: A Callable that will return a L{IGroup} appropriate
for this account type.
@cvar _personFactory: A Callable that will return a L{IPerson} appropriate
for this account type.
@type _isConnecting: boolean
@ivar _isConnecting: Whether I am in the process of establishing a
connection to the server.
@type _isOnline: boolean
@ivar _isOnline: Whether I am currently on-line with the server.
@ivar accountName:
@ivar autoLogin:
@ivar username:
@ivar password:
@ivar host:
@ivar port:
"""
_isOnline = 0
_isConnecting = 0
client = None
_groupFactory = AbstractGroup
_personFactory = AbstractPerson
persistanceVersion = 2
def __init__(self, accountName, autoLogin, username, password, host, port):
self.accountName = accountName
self.autoLogin = autoLogin
self.username = username
self.password = password
self.host = host
self.port = port
self._groups = {}
self._persons = {}
def upgrateToVersion2(self):
# Added in CVS revision 1.16.
for k in ('_groups', '_persons'):
if not hasattr(self, k):
setattr(self, k, {})
def __getstate__(self):
state = styles.Versioned.__getstate__(self)
for k in ('client', '_isOnline', '_isConnecting'):
try:
del state[k]
except KeyError:
pass
return state
def isOnline(self):
return self._isOnline
def logOn(self, chatui):
"""Log on to this account.
Takes care to not start a connection if a connection is
already in progress. You will need to implement
L{_startLogOn} for this to work, and it would be a good idea
to override L{_loginFailed} too.
@returntype: Deferred L{interfaces.IClient}
"""
if (not self._isConnecting) and (not self._isOnline):
self._isConnecting = 1
d = self._startLogOn(chatui)
d.addCallback(self._cb_logOn)
# if chatui is not None:
# (I don't particularly like having to pass chatUI to this function,
# but we haven't factored it out yet.)
d.addCallback(chatui.registerAccountClient)
d.addErrback(self._loginFailed)
return d
else:
raise error.ConnectError("Connection in progress")
def getGroup(self, name):
"""Group factory.
@param name: Name of the group on this account.
@type name: string
"""
group = self._groups.get(name)
if group is None:
group = self._groupFactory(name, self)
self._groups[name] = group
return group
def getPerson(self, name):
"""Person factory.
@param name: Name of the person on this account.
@type name: string
"""
person = self._persons.get(name)
if person is None:
person = self._personFactory(name, self)
self._persons[name] = person
return person
def _startLogOn(self, chatui):
"""Start the sign on process.
Factored out of L{logOn}.
@returntype: Deferred L{interfaces.IClient}
"""
raise NotImplementedError()
def _cb_logOn(self, client):
self._isConnecting = 0
self._isOnline = 1
self.client = client
return client
def _loginFailed(self, reason):
"""Errorback for L{logOn}.
@type reason: Failure
@returns: I{reason}, for further processing in the callback chain.
@returntype: Failure
"""
self._isConnecting = 0
self._isOnline = 0 # just in case
return reason
def _clientLost(self, client, reason):
self.client = None
self._isConnecting = 0
self._isOnline = 0
return reason
def __repr__(self):
return "<%s: %s (%s@%s:%s)>" % (self.__class__,
self.accountName,
self.username,
self.host,
self.port)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,398 @@
# -*- Python -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Pan-protocol chat client.
"""
from zope.interface import Interface, Attribute
# (Random musings, may not reflect on current state of code:)
#
# Accounts have Protocol components (clients)
# Persons have Conversation components
# Groups have GroupConversation components
# Persons and Groups are associated with specific Accounts
# At run-time, Clients/Accounts are slaved to a User Interface
# (Note: User may be a bot, so don't assume all UIs are built on gui toolkits)
class IAccount(Interface):
"""
I represent a user's account with a chat service.
"""
client = Attribute('The L{IClient} currently connecting to this account, if any.')
gatewayType = Attribute('A C{str} that identifies the protocol used by this account.')
def __init__(accountName, autoLogin, username, password, host, port):
"""
@type accountName: string
@param accountName: A name to refer to the account by locally.
@type autoLogin: boolean
@type username: string
@type password: string
@type host: string
@type port: integer
"""
def isOnline():
"""
Am I online?
@rtype: boolean
"""
def logOn(chatui):
"""
Go on-line.
@type chatui: Implementor of C{IChatUI}
@rtype: L{Deferred} with an eventual L{IClient} result.
"""
def logOff():
"""
Sign off.
"""
def getGroup(groupName):
"""
@rtype: L{Group<IGroup>}
"""
def getPerson(personName):
"""
@rtype: L{Person<IPerson>}
"""
class IClient(Interface):
account = Attribute('The L{IAccount} I am a Client for')
def __init__(account, chatui, logonDeferred):
"""
@type account: L{IAccount}
@type chatui: L{IChatUI}
@param logonDeferred: Will be called back once I am logged on.
@type logonDeferred: L{Deferred<twisted.internet.defer.Deferred>}
"""
def joinGroup(groupName):
"""
@param groupName: The name of the group to join.
@type groupName: string
"""
def leaveGroup(groupName):
"""
@param groupName: The name of the group to leave.
@type groupName: string
"""
def getGroupConversation(name, hide=0):
pass
def getPerson(name):
pass
class IPerson(Interface):
def __init__(name, account):
"""
Initialize me.
@param name: My name, as the server knows me.
@type name: string
@param account: The account I am accessed through.
@type account: I{Account}
"""
def isOnline():
"""
Am I online right now?
@rtype: boolean
"""
def getStatus():
"""
What is my on-line status?
@return: L{locals.StatusEnum}
"""
def getIdleTime():
"""
@rtype: string (XXX: How about a scalar?)
"""
def sendMessage(text, metadata=None):
"""
Send a message to this person.
@type text: string
@type metadata: dict
"""
class IGroup(Interface):
"""
A group which you may have a conversation with.
Groups generally have a loosely-defined set of members, who may
leave and join at any time.
"""
name = Attribute('My C{str} name, as the server knows me.')
account = Attribute('The L{Account<IAccount>} I am accessed through.')
def __init__(name, account):
"""
Initialize me.
@param name: My name, as the server knows me.
@type name: str
@param account: The account I am accessed through.
@type account: L{Account<IAccount>}
"""
def setTopic(text):
"""
Set this Groups topic on the server.
@type text: string
"""
def sendGroupMessage(text, metadata=None):
"""
Send a message to this group.
@type text: str
@type metadata: dict
@param metadata: Valid keys for this dictionary include:
- C{'style'}: associated with one of:
- C{'emote'}: indicates this is an action
"""
def join():
"""
Join this group.
"""
def leave():
"""
Depart this group.
"""
class IConversation(Interface):
"""
A conversation with a specific person.
"""
def __init__(person, chatui):
"""
@type person: L{IPerson}
"""
def show():
"""
doesn't seem like it belongs in this interface.
"""
def hide():
"""
nor this neither.
"""
def sendText(text, metadata):
pass
def showMessage(text, metadata):
pass
def changedNick(person, newnick):
"""
@param person: XXX Shouldn't this always be Conversation.person?
"""
class IGroupConversation(Interface):
def show():
"""
doesn't seem like it belongs in this interface.
"""
def hide():
"""
nor this neither.
"""
def sendText(text, metadata):
pass
def showGroupMessage(sender, text, metadata):
pass
def setGroupMembers(members):
"""
Sets the list of members in the group and displays it to the user.
"""
def setTopic(topic, author):
"""
Displays the topic (from the server) for the group conversation window.
@type topic: string
@type author: string (XXX: Not Person?)
"""
def memberJoined(member):
"""
Adds the given member to the list of members in the group conversation
and displays this to the user,
@type member: string (XXX: Not Person?)
"""
def memberChangedNick(oldnick, newnick):
"""
Changes the oldnick in the list of members to C{newnick} and displays this
change to the user,
@type oldnick: string (XXX: Not Person?)
@type newnick: string
"""
def memberLeft(member):
"""
Deletes the given member from the list of members in the group
conversation and displays the change to the user.
@type member: string (XXX: Not Person?)
"""
class IChatUI(Interface):
def registerAccountClient(client):
"""
Notifies user that an account has been signed on to.
@type client: L{Client<IClient>}
"""
def unregisterAccountClient(client):
"""
Notifies user that an account has been signed off or disconnected.
@type client: L{Client<IClient>}
"""
def getContactsList():
"""
@rtype: L{ContactsList}
"""
# WARNING: You'll want to be polymorphed into something with
# intrinsic stoning resistance before continuing.
def getConversation(person, Class, stayHidden=0):
"""
For the given person object, returns the conversation window
or creates and returns a new conversation window if one does not exist.
@type person: L{Person<IPerson>}
@type Class: L{Conversation<IConversation>} class
@type stayHidden: boolean
@rtype: L{Conversation<IConversation>}
"""
def getGroupConversation(group, Class, stayHidden=0):
"""
For the given group object, returns the group conversation window or
creates and returns a new group conversation window if it doesn't exist.
@type group: L{Group<interfaces.IGroup>}
@type Class: L{Conversation<interfaces.IConversation>} class
@type stayHidden: boolean
@rtype: L{GroupConversation<interfaces.IGroupConversation>}
"""
def getPerson(name, client):
"""
Get a Person for a client.
Duplicates L{IAccount.getPerson}.
@type name: string
@type client: L{Client<IClient>}
@rtype: L{Person<IPerson>}
"""
def getGroup(name, client):
"""
Get a Group for a client.
Duplicates L{IAccount.getGroup}.
@type name: string
@type client: L{Client<IClient>}
@rtype: L{Group<IGroup>}
"""
def contactChangedNick(oldnick, newnick):
"""
For the given person, changes the person's name to newnick, and
tells the contact list and any conversation windows with that person
to change as well.
@type oldnick: string
@type newnick: string
"""

View file

@ -0,0 +1,293 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
IRC support for Instance Messenger.
"""
from twisted.words.protocols import irc
from twisted.words.im.locals import ONLINE
from twisted.internet import defer, reactor, protocol
from twisted.internet.defer import succeed
from twisted.words.im import basesupport, interfaces, locals
from zope.interface import implementer
class IRCPerson(basesupport.AbstractPerson):
def imperson_whois(self):
if self.account.client is None:
raise locals.OfflineError
self.account.client.sendLine("WHOIS %s" % self.name)
### interface impl
def isOnline(self):
return ONLINE
def getStatus(self):
return ONLINE
def setStatus(self,status):
self.status=status
self.chat.getContactsList().setContactStatus(self)
def sendMessage(self, text, meta=None):
if self.account.client is None:
raise locals.OfflineError
for line in text.split('\n'):
if meta and meta.get("style", None) == "emote":
self.account.client.ctcpMakeQuery(self.name,[('ACTION', line)])
else:
self.account.client.msg(self.name, line)
return succeed(text)
@implementer(interfaces.IGroup)
class IRCGroup(basesupport.AbstractGroup):
def imgroup_testAction(self):
pass
def imtarget_kick(self, target):
if self.account.client is None:
raise locals.OfflineError
reason = "for great justice!"
self.account.client.sendLine("KICK #%s %s :%s" % (
self.name, target.name, reason))
### Interface Implementation
def setTopic(self, topic):
if self.account.client is None:
raise locals.OfflineError
self.account.client.topic(self.name, topic)
def sendGroupMessage(self, text, meta={}):
if self.account.client is None:
raise locals.OfflineError
if meta and meta.get("style", None) == "emote":
self.account.client.ctcpMakeQuery(self.name,[('ACTION', text)])
return succeed(text)
#standard shmandard, clients don't support plain escaped newlines!
for line in text.split('\n'):
self.account.client.say(self.name, line)
return succeed(text)
def leave(self):
if self.account.client is None:
raise locals.OfflineError
self.account.client.leave(self.name)
self.account.client.getGroupConversation(self.name,1)
class IRCProto(basesupport.AbstractClientMixin, irc.IRCClient):
def __init__(self, account, chatui, logonDeferred=None):
basesupport.AbstractClientMixin.__init__(self, account, chatui,
logonDeferred)
self._namreplies={}
self._ingroups={}
self._groups={}
self._topics={}
def getGroupConversation(self, name, hide=0):
name = name.lower()
return self.chat.getGroupConversation(self.chat.getGroup(name, self),
stayHidden=hide)
def getPerson(self,name):
return self.chat.getPerson(name, self)
def connectionMade(self):
# XXX: Why do I duplicate code in IRCClient.register?
try:
self.performLogin = True
self.nickname = self.account.username
self.password = self.account.password
self.realname = "Twisted-IM user"
irc.IRCClient.connectionMade(self)
for channel in self.account.channels:
self.joinGroup(channel)
self.account._isOnline=1
if self._logonDeferred is not None:
self._logonDeferred.callback(self)
self.chat.getContactsList()
except:
import traceback
traceback.print_exc()
def setNick(self,nick):
self.name=nick
self.accountName="%s (IRC)"%nick
irc.IRCClient.setNick(self,nick)
def kickedFrom(self, channel, kicker, message):
"""
Called when I am kicked from a channel.
"""
return self.chat.getGroupConversation(
self.chat.getGroup(channel[1:], self), 1)
def userKicked(self, kickee, channel, kicker, message):
pass
def noticed(self, username, channel, message):
self.privmsg(username, channel, message, {"dontAutoRespond": 1})
def privmsg(self, username, channel, message, metadata=None):
if metadata is None:
metadata = {}
username = username.split('!',1)[0]
if username==self.name: return
if channel[0]=='#':
group=channel[1:]
self.getGroupConversation(group).showGroupMessage(username, message, metadata)
return
self.chat.getConversation(self.getPerson(username)).showMessage(message, metadata)
def action(self,username,channel,emote):
username = username.split('!',1)[0]
if username==self.name: return
meta={'style':'emote'}
if channel[0]=='#':
group=channel[1:]
self.getGroupConversation(group).showGroupMessage(username, emote, meta)
return
self.chat.getConversation(self.getPerson(username)).showMessage(emote,meta)
def irc_RPL_NAMREPLY(self,prefix,params):
"""
RPL_NAMREPLY
>> NAMES #bnl
<< :Arlington.VA.US.Undernet.Org 353 z3p = #bnl :pSwede Dan-- SkOyg AG
"""
group = params[2][1:].lower()
users = params[3].split()
for ui in range(len(users)):
while users[ui][0] in ["@","+"]: # channel modes
users[ui]=users[ui][1:]
if group not in self._namreplies:
self._namreplies[group]=[]
self._namreplies[group].extend(users)
for nickname in users:
try:
self._ingroups[nickname].append(group)
except:
self._ingroups[nickname]=[group]
def irc_RPL_ENDOFNAMES(self,prefix,params):
group=params[1][1:]
self.getGroupConversation(group).setGroupMembers(self._namreplies[group.lower()])
del self._namreplies[group.lower()]
def irc_RPL_TOPIC(self,prefix,params):
self._topics[params[1][1:]]=params[2]
def irc_333(self,prefix,params):
group=params[1][1:]
self.getGroupConversation(group).setTopic(self._topics[group],params[2])
del self._topics[group]
def irc_TOPIC(self,prefix,params):
nickname = prefix.split("!")[0]
group = params[0][1:]
topic = params[1]
self.getGroupConversation(group).setTopic(topic,nickname)
def irc_JOIN(self,prefix,params):
nickname = prefix.split("!")[0]
group = params[0][1:].lower()
if nickname!=self.nickname:
try:
self._ingroups[nickname].append(group)
except:
self._ingroups[nickname]=[group]
self.getGroupConversation(group).memberJoined(nickname)
def irc_PART(self,prefix,params):
nickname = prefix.split("!")[0]
group = params[0][1:].lower()
if nickname!=self.nickname:
if group in self._ingroups[nickname]:
self._ingroups[nickname].remove(group)
self.getGroupConversation(group).memberLeft(nickname)
def irc_QUIT(self,prefix,params):
nickname = prefix.split("!")[0]
if nickname in self._ingroups:
for group in self._ingroups[nickname]:
self.getGroupConversation(group).memberLeft(nickname)
self._ingroups[nickname]=[]
def irc_NICK(self, prefix, params):
fromNick = prefix.split("!")[0]
toNick = params[0]
if fromNick not in self._ingroups:
return
for group in self._ingroups[fromNick]:
self.getGroupConversation(group).memberChangedNick(fromNick, toNick)
self._ingroups[toNick] = self._ingroups[fromNick]
del self._ingroups[fromNick]
def irc_unknown(self, prefix, command, params):
pass
# GTKIM calls
def joinGroup(self,name):
self.join(name)
self.getGroupConversation(name)
@implementer(interfaces.IAccount)
class IRCAccount(basesupport.AbstractAccount):
gatewayType = "IRC"
_groupFactory = IRCGroup
_personFactory = IRCPerson
def __init__(self, accountName, autoLogin, username, password, host, port,
channels=''):
basesupport.AbstractAccount.__init__(self, accountName, autoLogin,
username, password, host, port)
self.channels = [channel.strip() for channel in channels.split(',')]
if self.channels == ['']:
self.channels = []
def _startLogOn(self, chatui):
logonDeferred = defer.Deferred()
cc = protocol.ClientCreator(reactor, IRCProto, self, chatui,
logonDeferred)
d = cc.connectTCP(self.host, self.port)
d.addErrback(logonDeferred.errback)
return logonDeferred

View file

@ -0,0 +1,26 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
class Enum:
group = None
def __init__(self, label):
self.label = label
def __repr__(self):
return '<%s: %s>' % (self.group, self.label)
def __str__(self):
return self.label
class StatusEnum(Enum):
group = 'Status'
OFFLINE = Enum('Offline')
ONLINE = Enum('Online')
AWAY = Enum('Away')
class OfflineError(Exception):
"""The requested action can't happen while offline."""

View file

@ -0,0 +1,262 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
L{twisted.words} support for Instance Messenger.
"""
from __future__ import print_function
from twisted.internet import defer
from twisted.internet import error
from twisted.python import log
from twisted.python.failure import Failure
from twisted.spread import pb
from twisted.words.im.locals import ONLINE, OFFLINE, AWAY
from twisted.words.im import basesupport, interfaces
from zope.interface import implementer
class TwistedWordsPerson(basesupport.AbstractPerson):
"""I a facade for a person you can talk to through a twisted.words service.
"""
def __init__(self, name, wordsAccount):
basesupport.AbstractPerson.__init__(self, name, wordsAccount)
self.status = OFFLINE
def isOnline(self):
return ((self.status == ONLINE) or
(self.status == AWAY))
def getStatus(self):
return self.status
def sendMessage(self, text, metadata):
"""Return a deferred...
"""
if metadata:
d=self.account.client.perspective.directMessage(self.name,
text, metadata)
d.addErrback(self.metadataFailed, "* "+text)
return d
else:
return self.account.client.perspective.callRemote('directMessage',self.name, text)
def metadataFailed(self, result, text):
print("result:",result,"text:",text)
return self.account.client.perspective.directMessage(self.name, text)
def setStatus(self, status):
self.status = status
self.chat.getContactsList().setContactStatus(self)
@implementer(interfaces.IGroup)
class TwistedWordsGroup(basesupport.AbstractGroup):
def __init__(self, name, wordsClient):
basesupport.AbstractGroup.__init__(self, name, wordsClient)
self.joined = 0
def sendGroupMessage(self, text, metadata=None):
"""Return a deferred.
"""
#for backwards compatibility with older twisted.words servers.
if metadata:
d=self.account.client.perspective.callRemote(
'groupMessage', self.name, text, metadata)
d.addErrback(self.metadataFailed, "* "+text)
return d
else:
return self.account.client.perspective.callRemote('groupMessage',
self.name, text)
def setTopic(self, text):
self.account.client.perspective.callRemote(
'setGroupMetadata',
{'topic': text, 'topic_author': self.client.name},
self.name)
def metadataFailed(self, result, text):
print("result:",result,"text:",text)
return self.account.client.perspective.callRemote('groupMessage',
self.name, text)
def joining(self):
self.joined = 1
def leaving(self):
self.joined = 0
def leave(self):
return self.account.client.perspective.callRemote('leaveGroup',
self.name)
class TwistedWordsClient(pb.Referenceable, basesupport.AbstractClientMixin):
"""In some cases, this acts as an Account, since it a source of text
messages (multiple Words instances may be on a single PB connection)
"""
def __init__(self, acct, serviceName, perspectiveName, chatui,
_logonDeferred=None):
self.accountName = "%s (%s:%s)" % (acct.accountName, serviceName, perspectiveName)
self.name = perspectiveName
print("HELLO I AM A PB SERVICE", serviceName, perspectiveName)
self.chat = chatui
self.account = acct
self._logonDeferred = _logonDeferred
def getPerson(self, name):
return self.chat.getPerson(name, self)
def getGroup(self, name):
return self.chat.getGroup(name, self)
def getGroupConversation(self, name):
return self.chat.getGroupConversation(self.getGroup(name))
def addContact(self, name):
self.perspective.callRemote('addContact', name)
def remote_receiveGroupMembers(self, names, group):
print('received group members:', names, group)
self.getGroupConversation(group).setGroupMembers(names)
def remote_receiveGroupMessage(self, sender, group, message, metadata=None):
print('received a group message', sender, group, message, metadata)
self.getGroupConversation(group).showGroupMessage(sender, message, metadata)
def remote_memberJoined(self, member, group):
print('member joined', member, group)
self.getGroupConversation(group).memberJoined(member)
def remote_memberLeft(self, member, group):
print('member left')
self.getGroupConversation(group).memberLeft(member)
def remote_notifyStatusChanged(self, name, status):
self.chat.getPerson(name, self).setStatus(status)
def remote_receiveDirectMessage(self, name, message, metadata=None):
self.chat.getConversation(self.chat.getPerson(name, self)).showMessage(message, metadata)
def remote_receiveContactList(self, clist):
for name, status in clist:
self.chat.getPerson(name, self).setStatus(status)
def remote_setGroupMetadata(self, dict_, groupName):
if "topic" in dict_:
self.getGroupConversation(groupName).setTopic(dict_["topic"], dict_.get("topic_author", None))
def joinGroup(self, name):
self.getGroup(name).joining()
return self.perspective.callRemote('joinGroup', name).addCallback(self._cbGroupJoined, name)
def leaveGroup(self, name):
self.getGroup(name).leaving()
return self.perspective.callRemote('leaveGroup', name).addCallback(self._cbGroupLeft, name)
def _cbGroupJoined(self, result, name):
groupConv = self.chat.getGroupConversation(self.getGroup(name))
groupConv.showGroupMessage("sys", "you joined")
self.perspective.callRemote('getGroupMembers', name)
def _cbGroupLeft(self, result, name):
print('left',name)
groupConv = self.chat.getGroupConversation(self.getGroup(name), 1)
groupConv.showGroupMessage("sys", "you left")
def connected(self, perspective):
print('Connected Words Client!', perspective)
if self._logonDeferred is not None:
self._logonDeferred.callback(self)
self.perspective = perspective
self.chat.getContactsList()
pbFrontEnds = {
"twisted.words": TwistedWordsClient,
"twisted.reality": None
}
@implementer(interfaces.IAccount)
class PBAccount(basesupport.AbstractAccount):
gatewayType = "PB"
_groupFactory = TwistedWordsGroup
_personFactory = TwistedWordsPerson
def __init__(self, accountName, autoLogin, username, password, host, port,
services=None):
"""
@param username: The name of your PB Identity.
@type username: string
"""
basesupport.AbstractAccount.__init__(self, accountName, autoLogin,
username, password, host, port)
self.services = []
if not services:
services = [('twisted.words', 'twisted.words', username)]
for serviceType, serviceName, perspectiveName in services:
self.services.append([pbFrontEnds[serviceType], serviceName,
perspectiveName])
def logOn(self, chatui):
"""
@returns: this breaks with L{interfaces.IAccount}
@returntype: DeferredList of L{interfaces.IClient}s
"""
# Overriding basesupport's implementation on account of the
# fact that _startLogOn tends to return a deferredList rather
# than a simple Deferred, and we need to do registerAccountClient.
if (not self._isConnecting) and (not self._isOnline):
self._isConnecting = 1
d = self._startLogOn(chatui)
d.addErrback(self._loginFailed)
def registerMany(results):
for success, result in results:
if success:
chatui.registerAccountClient(result)
self._cb_logOn(result)
else:
log.err(result)
d.addCallback(registerMany)
return d
else:
raise error.ConnectionError("Connection in progress")
def _startLogOn(self, chatui):
print('Connecting...', end=' ')
d = pb.getObjectAt(self.host, self.port)
d.addCallbacks(self._cbConnected, self._ebConnected,
callbackArgs=(chatui,))
return d
def _cbConnected(self, root, chatui):
print('Connected!')
print('Identifying...', end=' ')
d = pb.authIdentity(root, self.username, self.password)
d.addCallbacks(self._cbIdent, self._ebConnected,
callbackArgs=(chatui,))
return d
def _cbIdent(self, ident, chatui):
if not ident:
print('falsely identified.')
return self._ebConnected(Failure(Exception("username or password incorrect")))
print('Identified!')
dl = []
for handlerClass, sname, pname in self.services:
d = defer.Deferred()
dl.append(d)
handler = handlerClass(self, sname, pname, chatui, d)
ident.callRemote('attach', sname, pname, handler).addCallback(handler.connected)
return defer.DeferredList(dl)
def _ebConnected(self, error):
print('Not connected.')
return error

View file

@ -0,0 +1,267 @@
# -*- test-case-name: twisted.words.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from zope.interface import Interface, Attribute
class IProtocolPlugin(Interface):
"""Interface for plugins providing an interface to a Words service
"""
name = Attribute("A single word describing what kind of interface this is (eg, irc or web)")
def getFactory(realm, portal):
"""Retrieve a C{twisted.internet.interfaces.IServerFactory} provider
@param realm: An object providing C{twisted.cred.portal.IRealm} and
L{IChatService}, with which service information should be looked up.
@param portal: An object providing C{twisted.cred.portal.IPortal},
through which logins should be performed.
"""
class IGroup(Interface):
name = Attribute("A short string, unique among groups.")
def add(user):
"""Include the given user in this group.
@type user: L{IUser}
"""
def remove(user, reason=None):
"""Remove the given user from this group.
@type user: L{IUser}
@type reason: C{unicode}
"""
def size():
"""Return the number of participants in this group.
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires with an C{int} representing the
number of participants in this group.
"""
def receive(sender, recipient, message):
"""
Broadcast the given message from the given sender to other
users in group.
The message is not re-transmitted to the sender.
@param sender: L{IUser}
@type recipient: L{IGroup}
@param recipient: This is probably a wart. Maybe it will be removed
in the future. For now, it should be the group object the message
is being delivered to.
@param message: C{dict}
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires with None when delivery has been
attempted for all users.
"""
def setMetadata(meta):
"""Change the metadata associated with this group.
@type meta: C{dict}
"""
def iterusers():
"""Return an iterator of all users in this group.
"""
class IChatClient(Interface):
"""Interface through which IChatService interacts with clients.
"""
name = Attribute("A short string, unique among users. This will be set by the L{IChatService} at login time.")
def receive(sender, recipient, message):
"""
Callback notifying this user of the given message sent by the
given user.
This will be invoked whenever another user sends a message to a
group this user is participating in, or whenever another user sends
a message directly to this user. In the former case, C{recipient}
will be the group to which the message was sent; in the latter, it
will be the same object as the user who is receiving the message.
@type sender: L{IUser}
@type recipient: L{IUser} or L{IGroup}
@type message: C{dict}
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires when the message has been delivered,
or which fails in some way. If the Deferred fails and the message
was directed at a group, this user will be removed from that group.
"""
def groupMetaUpdate(group, meta):
"""
Callback notifying this user that the metadata for the given
group has changed.
@type group: L{IGroup}
@type meta: C{dict}
@rtype: L{twisted.internet.defer.Deferred}
"""
def userJoined(group, user):
"""
Callback notifying this user that the given user has joined
the given group.
@type group: L{IGroup}
@type user: L{IUser}
@rtype: L{twisted.internet.defer.Deferred}
"""
def userLeft(group, user, reason=None):
"""
Callback notifying this user that the given user has left the
given group for the given reason.
@type group: L{IGroup}
@type user: L{IUser}
@type reason: C{unicode}
@rtype: L{twisted.internet.defer.Deferred}
"""
class IUser(Interface):
"""Interface through which clients interact with IChatService.
"""
realm = Attribute("A reference to the Realm to which this user belongs. Set if and only if the user is logged in.")
mind = Attribute("A reference to the mind which logged in to this user. Set if and only if the user is logged in.")
name = Attribute("A short string, unique among users.")
lastMessage = Attribute("A POSIX timestamp indicating the time of the last message received from this user.")
signOn = Attribute("A POSIX timestamp indicating this user's most recent sign on time.")
def loggedIn(realm, mind):
"""Invoked by the associated L{IChatService} when login occurs.
@param realm: The L{IChatService} through which login is occurring.
@param mind: The mind object used for cred login.
"""
def send(recipient, message):
"""Send the given message to the given user or group.
@type recipient: Either L{IUser} or L{IGroup}
@type message: C{dict}
"""
def join(group):
"""Attempt to join the given group.
@type group: L{IGroup}
@rtype: L{twisted.internet.defer.Deferred}
"""
def leave(group):
"""Discontinue participation in the given group.
@type group: L{IGroup}
@rtype: L{twisted.internet.defer.Deferred}
"""
def itergroups():
"""
Return an iterator of all groups of which this user is a
member.
"""
class IChatService(Interface):
name = Attribute("A short string identifying this chat service (eg, a hostname)")
createGroupOnRequest = Attribute(
"A boolean indicating whether L{getGroup} should implicitly "
"create groups which are requested but which do not yet exist.")
createUserOnRequest = Attribute(
"A boolean indicating whether L{getUser} should implicitly "
"create users which are requested but which do not yet exist.")
def itergroups():
"""Return all groups available on this service.
@rtype: C{twisted.internet.defer.Deferred}
@return: A Deferred which fires with a list of C{IGroup} providers.
"""
def getGroup(name):
"""Retrieve the group by the given name.
@type name: C{str}
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires with the group with the given
name if one exists (or if one is created due to the setting of
L{IChatService.createGroupOnRequest}, or which fails with
L{twisted.words.ewords.NoSuchGroup} if no such group exists.
"""
def createGroup(name):
"""Create a new group with the given name.
@type name: C{str}
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires with the created group, or
with fails with L{twisted.words.ewords.DuplicateGroup} if a
group by that name exists already.
"""
def lookupGroup(name):
"""Retrieve a group by name.
Unlike C{getGroup}, this will never implicitly create a group.
@type name: C{str}
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires with the group by the given
name, or which fails with L{twisted.words.ewords.NoSuchGroup}.
"""
def getUser(name):
"""Retrieve the user by the given name.
@type name: C{str}
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires with the user with the given
name if one exists (or if one is created due to the setting of
L{IChatService.createUserOnRequest}, or which fails with
L{twisted.words.ewords.NoSuchUser} if no such user exists.
"""
def createUser(name):
"""Create a new user with the given name.
@type name: C{str}
@rtype: L{twisted.internet.defer.Deferred}
@return: A Deferred which fires with the created user, or
with fails with L{twisted.words.ewords.DuplicateUser} if a
user by that name exists already.
"""
__all__ = [
'IGroup', 'IChatClient', 'IUser', 'IChatService',
]

View file

@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Chat protocols.
"""

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,8 @@
# -*- test-case-name: twisted.words.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Jabber: Jabber Protocol Helpers
"""

View file

@ -0,0 +1,408 @@
# -*- test-case-name: twisted.words.test.test_jabberclient -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import absolute_import, division
from twisted.python.compat import _coercedUnicode, unicode
from twisted.words.protocols.jabber import xmlstream, sasl, error
from twisted.words.protocols.jabber.jid import JID
from twisted.words.xish import domish, xpath, utility
NS_XMPP_STREAMS = 'urn:ietf:params:xml:ns:xmpp-streams'
NS_XMPP_BIND = 'urn:ietf:params:xml:ns:xmpp-bind'
NS_XMPP_SESSION = 'urn:ietf:params:xml:ns:xmpp-session'
NS_IQ_AUTH_FEATURE = 'http://jabber.org/features/iq-auth'
DigestAuthQry = xpath.internQuery("/iq/query/digest")
PlaintextAuthQry = xpath.internQuery("/iq/query/password")
def basicClientFactory(jid, secret):
a = BasicAuthenticator(jid, secret)
return xmlstream.XmlStreamFactory(a)
class IQ(domish.Element):
"""
Wrapper for a Info/Query packet.
This provides the necessary functionality to send IQs and get notified when
a result comes back. It's a subclass from L{domish.Element}, so you can use
the standard DOM manipulation calls to add data to the outbound request.
@type callbacks: L{utility.CallbackList}
@cvar callbacks: Callback list to be notified when response comes back
"""
def __init__(self, xmlstream, type = "set"):
"""
@type xmlstream: L{xmlstream.XmlStream}
@param xmlstream: XmlStream to use for transmission of this IQ
@type type: C{str}
@param type: IQ type identifier ('get' or 'set')
"""
domish.Element.__init__(self, ("jabber:client", "iq"))
self.addUniqueId()
self["type"] = type
self._xmlstream = xmlstream
self.callbacks = utility.CallbackList()
def addCallback(self, fn, *args, **kwargs):
"""
Register a callback for notification when the IQ result is available.
"""
self.callbacks.addCallback(True, fn, *args, **kwargs)
def send(self, to = None):
"""
Call this method to send this IQ request via the associated XmlStream.
@param to: Jabber ID of the entity to send the request to
@type to: C{str}
@returns: Callback list for this IQ. Any callbacks added to this list
will be fired when the result comes back.
"""
if to != None:
self["to"] = to
self._xmlstream.addOnetimeObserver("/iq[@id='%s']" % self["id"], \
self._resultEvent)
self._xmlstream.send(self)
def _resultEvent(self, iq):
self.callbacks.callback(iq)
self.callbacks = None
class IQAuthInitializer(object):
"""
Non-SASL Authentication initializer for the initiating entity.
This protocol is defined in
U{JEP-0078<http://www.jabber.org/jeps/jep-0078.html>} and mainly serves for
compatibility with pre-XMPP-1.0 server implementations.
@cvar INVALID_USER_EVENT: Token to signal that authentication failed, due
to invalid username.
@type INVALID_USER_EVENT: L{str}
@cvar AUTH_FAILED_EVENT: Token to signal that authentication failed, due to
invalid password.
@type AUTH_FAILED_EVENT: L{str}
"""
INVALID_USER_EVENT = "//event/client/basicauth/invaliduser"
AUTH_FAILED_EVENT = "//event/client/basicauth/authfailed"
def __init__(self, xs):
self.xmlstream = xs
def initialize(self):
# Send request for auth fields
iq = xmlstream.IQ(self.xmlstream, "get")
iq.addElement(("jabber:iq:auth", "query"))
jid = self.xmlstream.authenticator.jid
iq.query.addElement("username", content = jid.user)
d = iq.send()
d.addCallbacks(self._cbAuthQuery, self._ebAuthQuery)
return d
def _cbAuthQuery(self, iq):
jid = self.xmlstream.authenticator.jid
password = _coercedUnicode(self.xmlstream.authenticator.password)
# Construct auth request
reply = xmlstream.IQ(self.xmlstream, "set")
reply.addElement(("jabber:iq:auth", "query"))
reply.query.addElement("username", content = jid.user)
reply.query.addElement("resource", content = jid.resource)
# Prefer digest over plaintext
if DigestAuthQry.matches(iq):
digest = xmlstream.hashPassword(self.xmlstream.sid, password)
reply.query.addElement("digest", content=unicode(digest))
else:
reply.query.addElement("password", content = password)
d = reply.send()
d.addCallbacks(self._cbAuth, self._ebAuth)
return d
def _ebAuthQuery(self, failure):
failure.trap(error.StanzaError)
e = failure.value
if e.condition == 'not-authorized':
self.xmlstream.dispatch(e.stanza, self.INVALID_USER_EVENT)
else:
self.xmlstream.dispatch(e.stanza, self.AUTH_FAILED_EVENT)
return failure
def _cbAuth(self, iq):
pass
def _ebAuth(self, failure):
failure.trap(error.StanzaError)
self.xmlstream.dispatch(failure.value.stanza, self.AUTH_FAILED_EVENT)
return failure
class BasicAuthenticator(xmlstream.ConnectAuthenticator):
"""
Authenticates an XmlStream against a Jabber server as a Client.
This only implements non-SASL authentication, per
U{JEP-0078<http://www.jabber.org/jeps/jep-0078.html>}. Additionally, this
authenticator provides the ability to perform inline registration, per
U{JEP-0077<http://www.jabber.org/jeps/jep-0077.html>}.
Under normal circumstances, the BasicAuthenticator generates the
L{xmlstream.STREAM_AUTHD_EVENT} once the stream has authenticated. However,
it can also generate other events, such as:
- L{INVALID_USER_EVENT} : Authentication failed, due to invalid username
- L{AUTH_FAILED_EVENT} : Authentication failed, due to invalid password
- L{REGISTER_FAILED_EVENT} : Registration failed
If authentication fails for any reason, you can attempt to register by
calling the L{registerAccount} method. If the registration succeeds, a
L{xmlstream.STREAM_AUTHD_EVENT} will be fired. Otherwise, one of the above
errors will be generated (again).
@cvar INVALID_USER_EVENT: See L{IQAuthInitializer.INVALID_USER_EVENT}.
@type INVALID_USER_EVENT: L{str}
@cvar AUTH_FAILED_EVENT: See L{IQAuthInitializer.AUTH_FAILED_EVENT}.
@type AUTH_FAILED_EVENT: L{str}
@cvar REGISTER_FAILED_EVENT: Token to signal that registration failed.
@type REGISTER_FAILED_EVENT: L{str}
"""
namespace = "jabber:client"
INVALID_USER_EVENT = IQAuthInitializer.INVALID_USER_EVENT
AUTH_FAILED_EVENT = IQAuthInitializer.AUTH_FAILED_EVENT
REGISTER_FAILED_EVENT = "//event/client/basicauth/registerfailed"
def __init__(self, jid, password):
xmlstream.ConnectAuthenticator.__init__(self, jid.host)
self.jid = jid
self.password = password
def associateWithStream(self, xs):
xs.version = (0, 0)
xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
xs.initializers = [
xmlstream.TLSInitiatingInitializer(xs, required=False),
IQAuthInitializer(xs),
]
# TODO: move registration into an Initializer?
def registerAccount(self, username = None, password = None):
if username:
self.jid.user = username
if password:
self.password = password
iq = IQ(self.xmlstream, "set")
iq.addElement(("jabber:iq:register", "query"))
iq.query.addElement("username", content = self.jid.user)
iq.query.addElement("password", content = self.password)
iq.addCallback(self._registerResultEvent)
iq.send()
def _registerResultEvent(self, iq):
if iq["type"] == "result":
# Registration succeeded -- go ahead and auth
self.streamStarted()
else:
# Registration failed
self.xmlstream.dispatch(iq, self.REGISTER_FAILED_EVENT)
class CheckVersionInitializer(object):
"""
Initializer that checks if the minimum common stream version number is 1.0.
"""
def __init__(self, xs):
self.xmlstream = xs
def initialize(self):
if self.xmlstream.version < (1, 0):
raise error.StreamError('unsupported-version')
class BindInitializer(xmlstream.BaseFeatureInitiatingInitializer):
"""
Initializer that implements Resource Binding for the initiating entity.
This protocol is documented in U{RFC 3920, section
7<http://www.xmpp.org/specs/rfc3920.html#bind>}.
"""
feature = (NS_XMPP_BIND, 'bind')
def start(self):
iq = xmlstream.IQ(self.xmlstream, 'set')
bind = iq.addElement((NS_XMPP_BIND, 'bind'))
resource = self.xmlstream.authenticator.jid.resource
if resource:
bind.addElement('resource', content=resource)
d = iq.send()
d.addCallback(self.onBind)
return d
def onBind(self, iq):
if iq.bind:
self.xmlstream.authenticator.jid = JID(unicode(iq.bind.jid))
class SessionInitializer(xmlstream.BaseFeatureInitiatingInitializer):
"""
Initializer that implements session establishment for the initiating
entity.
This protocol is defined in U{RFC 3921, section
3<http://www.xmpp.org/specs/rfc3921.html#session>}.
"""
feature = (NS_XMPP_SESSION, 'session')
def start(self):
iq = xmlstream.IQ(self.xmlstream, 'set')
iq.addElement((NS_XMPP_SESSION, 'session'))
return iq.send()
def XMPPClientFactory(jid, password, configurationForTLS=None):
"""
Client factory for XMPP 1.0 (only).
This returns a L{xmlstream.XmlStreamFactory} with an L{XMPPAuthenticator}
object to perform the stream initialization steps (such as authentication).
@see: The notes at L{XMPPAuthenticator} describe how the L{jid} and
L{password} parameters are to be used.
@param jid: Jabber ID to connect with.
@type jid: L{jid.JID}
@param password: password to authenticate with.
@type password: L{unicode}
@param configurationForTLS: An object which creates appropriately
configured TLS connections. This is passed to C{startTLS} on the
transport and is preferably created using
L{twisted.internet.ssl.optionsForClientTLS}. If C{None}, the default is
to verify the server certificate against the trust roots as provided by
the platform. See L{twisted.internet._sslverify.platformTrust}.
@type configurationForTLS: L{IOpenSSLClientConnectionCreator} or C{None}
@return: XML stream factory.
@rtype: L{xmlstream.XmlStreamFactory}
"""
a = XMPPAuthenticator(jid, password,
configurationForTLS=configurationForTLS)
return xmlstream.XmlStreamFactory(a)
class XMPPAuthenticator(xmlstream.ConnectAuthenticator):
"""
Initializes an XmlStream connecting to an XMPP server as a Client.
This authenticator performs the initialization steps needed to start
exchanging XML stanzas with an XMPP server as an XMPP client. It checks if
the server advertises XML stream version 1.0, negotiates TLS (when
available), performs SASL authentication, binds a resource and establishes
a session.
Upon successful stream initialization, the L{xmlstream.STREAM_AUTHD_EVENT}
event will be dispatched through the XML stream object. Otherwise, the
L{xmlstream.INIT_FAILED_EVENT} event will be dispatched with a failure
object.
After inspection of the failure, initialization can then be restarted by
calling L{ConnectAuthenticator.initializeStream}. For example, in case of
authentication failure, a user may be given the opportunity to input the
correct password. By setting the L{password} instance variable and restarting
initialization, the stream authentication step is then retried, and subsequent
steps are performed if successful.
@ivar jid: Jabber ID to authenticate with. This may contain a resource
part, as a suggestion to the server for resource binding. A
server may override this, though. If the resource part is left
off, the server will generate a unique resource identifier.
The server will always return the full Jabber ID in the
resource binding step, and this is stored in this instance
variable.
@type jid: L{jid.JID}
@ivar password: password to be used during SASL authentication.
@type password: L{unicode}
"""
namespace = 'jabber:client'
def __init__(self, jid, password, configurationForTLS=None):
"""
@param configurationForTLS: An object which creates appropriately
configured TLS connections. This is passed to C{startTLS} on the
transport and is preferably created using
L{twisted.internet.ssl.optionsForClientTLS}. If C{None}, the
default is to verify the server certificate against the trust roots
as provided by the platform. See
L{twisted.internet._sslverify.platformTrust}.
@type configurationForTLS: L{IOpenSSLClientConnectionCreator} or
C{None}
"""
xmlstream.ConnectAuthenticator.__init__(self, jid.host)
self.jid = jid
self.password = password
self._configurationForTLS = configurationForTLS
def associateWithStream(self, xs):
"""
Register with the XML stream.
Populates stream's list of initializers, along with their
requiredness. This list is used by
L{ConnectAuthenticator.initializeStream} to perform the initialization
steps.
"""
xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
xs.initializers = [
CheckVersionInitializer(xs),
xmlstream.TLSInitiatingInitializer(
xs, required=True,
configurationForTLS=self._configurationForTLS),
sasl.SASLInitiatingInitializer(xs, required=True),
BindInitializer(xs, required=True),
SessionInitializer(xs, required=False),
]

View file

@ -0,0 +1,475 @@
# -*- test-case-name: twisted.words.test.test_jabbercomponent -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
External server-side components.
Most Jabber server implementations allow for add-on components that act as a
separate entity on the Jabber network, but use the server-to-server
functionality of a regular Jabber IM server. These so-called 'external
components' are connected to the Jabber server using the Jabber Component
Protocol as defined in U{JEP-0114<http://www.jabber.org/jeps/jep-0114.html>}.
This module allows for writing external server-side component by assigning one
or more services implementing L{ijabber.IService} up to L{ServiceManager}. The
ServiceManager connects to the Jabber server and is responsible for the
corresponding XML stream.
"""
from zope.interface import implementer
from twisted.application import service
from twisted.internet import defer
from twisted.python import log
from twisted.python.compat import _coercedUnicode, unicode
from twisted.words.xish import domish
from twisted.words.protocols.jabber import error, ijabber, jstrports, xmlstream
from twisted.words.protocols.jabber.jid import internJID as JID
NS_COMPONENT_ACCEPT = 'jabber:component:accept'
def componentFactory(componentid, password):
"""
XML stream factory for external server-side components.
@param componentid: JID of the component.
@type componentid: L{unicode}
@param password: password used to authenticate to the server.
@type password: C{str}
"""
a = ConnectComponentAuthenticator(componentid, password)
return xmlstream.XmlStreamFactory(a)
class ComponentInitiatingInitializer(object):
"""
External server-side component authentication initializer for the
initiating entity.
@ivar xmlstream: XML stream between server and component.
@type xmlstream: L{xmlstream.XmlStream}
"""
def __init__(self, xs):
self.xmlstream = xs
self._deferred = None
def initialize(self):
xs = self.xmlstream
hs = domish.Element((self.xmlstream.namespace, "handshake"))
digest = xmlstream.hashPassword(
xs.sid,
_coercedUnicode(xs.authenticator.password))
hs.addContent(unicode(digest))
# Setup observer to watch for handshake result
xs.addOnetimeObserver("/handshake", self._cbHandshake)
xs.send(hs)
self._deferred = defer.Deferred()
return self._deferred
def _cbHandshake(self, _):
# we have successfully shaken hands and can now consider this
# entity to represent the component JID.
self.xmlstream.thisEntity = self.xmlstream.otherEntity
self._deferred.callback(None)
class ConnectComponentAuthenticator(xmlstream.ConnectAuthenticator):
"""
Authenticator to permit an XmlStream to authenticate against a Jabber
server as an external component (where the Authenticator is initiating the
stream).
"""
namespace = NS_COMPONENT_ACCEPT
def __init__(self, componentjid, password):
"""
@type componentjid: C{str}
@param componentjid: Jabber ID that this component wishes to bind to.
@type password: C{str}
@param password: Password/secret this component uses to authenticate.
"""
# Note that we are sending 'to' our desired component JID.
xmlstream.ConnectAuthenticator.__init__(self, componentjid)
self.password = password
def associateWithStream(self, xs):
xs.version = (0, 0)
xmlstream.ConnectAuthenticator.associateWithStream(self, xs)
xs.initializers = [ComponentInitiatingInitializer(xs)]
class ListenComponentAuthenticator(xmlstream.ListenAuthenticator):
"""
Authenticator for accepting components.
@since: 8.2
@ivar secret: The shared secret used to authorized incoming component
connections.
@type secret: C{unicode}.
"""
namespace = NS_COMPONENT_ACCEPT
def __init__(self, secret):
self.secret = secret
xmlstream.ListenAuthenticator.__init__(self)
def associateWithStream(self, xs):
"""
Associate the authenticator with a stream.
This sets the stream's version to 0.0, because the XEP-0114 component
protocol was not designed for XMPP 1.0.
"""
xs.version = (0, 0)
xmlstream.ListenAuthenticator.associateWithStream(self, xs)
def streamStarted(self, rootElement):
"""
Called by the stream when it has started.
This examines the default namespace of the incoming stream and whether
there is a requested hostname for the component. Then it generates a
stream identifier, sends a response header and adds an observer for
the first incoming element, triggering L{onElement}.
"""
xmlstream.ListenAuthenticator.streamStarted(self, rootElement)
if rootElement.defaultUri != self.namespace:
exc = error.StreamError('invalid-namespace')
self.xmlstream.sendStreamError(exc)
return
# self.xmlstream.thisEntity is set to the address the component
# wants to assume.
if not self.xmlstream.thisEntity:
exc = error.StreamError('improper-addressing')
self.xmlstream.sendStreamError(exc)
return
self.xmlstream.sendHeader()
self.xmlstream.addOnetimeObserver('/*', self.onElement)
def onElement(self, element):
"""
Called on incoming XML Stanzas.
The very first element received should be a request for handshake.
Otherwise, the stream is dropped with a 'not-authorized' error. If a
handshake request was received, the hash is extracted and passed to
L{onHandshake}.
"""
if (element.uri, element.name) == (self.namespace, 'handshake'):
self.onHandshake(unicode(element))
else:
exc = error.StreamError('not-authorized')
self.xmlstream.sendStreamError(exc)
def onHandshake(self, handshake):
"""
Called upon receiving the handshake request.
This checks that the given hash in C{handshake} is equal to a
calculated hash, responding with a handshake reply or a stream error.
If the handshake was ok, the stream is authorized, and XML Stanzas may
be exchanged.
"""
calculatedHash = xmlstream.hashPassword(self.xmlstream.sid,
unicode(self.secret))
if handshake != calculatedHash:
exc = error.StreamError('not-authorized', text='Invalid hash')
self.xmlstream.sendStreamError(exc)
else:
self.xmlstream.send('<handshake/>')
self.xmlstream.dispatch(self.xmlstream,
xmlstream.STREAM_AUTHD_EVENT)
@implementer(ijabber.IService)
class Service(service.Service):
"""
External server-side component service.
"""
def componentConnected(self, xs):
pass
def componentDisconnected(self):
pass
def transportConnected(self, xs):
pass
def send(self, obj):
"""
Send data over service parent's XML stream.
@note: L{ServiceManager} maintains a queue for data sent using this
method when there is no current established XML stream. This data is
then sent as soon as a new stream has been established and initialized.
Subsequently, L{componentConnected} will be called again. If this
queueing is not desired, use C{send} on the XmlStream object (passed to
L{componentConnected}) directly.
@param obj: data to be sent over the XML stream. This is usually an
object providing L{domish.IElement}, or serialized XML. See
L{xmlstream.XmlStream} for details.
"""
self.parent.send(obj)
class ServiceManager(service.MultiService):
"""
Business logic for a managed component connection to a Jabber router.
This service maintains a single connection to a Jabber router and provides
facilities for packet routing and transmission. Business logic modules are
services implementing L{ijabber.IService} (like subclasses of L{Service}),
and added as sub-service.
"""
def __init__(self, jid, password):
service.MultiService.__init__(self)
# Setup defaults
self.jabberId = jid
self.xmlstream = None
# Internal buffer of packets
self._packetQueue = []
# Setup the xmlstream factory
self._xsFactory = componentFactory(self.jabberId, password)
# Register some lambda functions to keep the self.xmlstream var up to
# date
self._xsFactory.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT,
self._connected)
self._xsFactory.addBootstrap(xmlstream.STREAM_AUTHD_EVENT, self._authd)
self._xsFactory.addBootstrap(xmlstream.STREAM_END_EVENT,
self._disconnected)
# Map addBootstrap and removeBootstrap to the underlying factory -- is
# this right? I have no clue...but it'll work for now, until i can
# think about it more.
self.addBootstrap = self._xsFactory.addBootstrap
self.removeBootstrap = self._xsFactory.removeBootstrap
def getFactory(self):
return self._xsFactory
def _connected(self, xs):
self.xmlstream = xs
for c in self:
if ijabber.IService.providedBy(c):
c.transportConnected(xs)
def _authd(self, xs):
# Flush all pending packets
for p in self._packetQueue:
self.xmlstream.send(p)
self._packetQueue = []
# Notify all child services which implement the IService interface
for c in self:
if ijabber.IService.providedBy(c):
c.componentConnected(xs)
def _disconnected(self, _):
self.xmlstream = None
# Notify all child services which implement
# the IService interface
for c in self:
if ijabber.IService.providedBy(c):
c.componentDisconnected()
def send(self, obj):
"""
Send data over the XML stream.
When there is no established XML stream, the data is queued and sent
out when a new XML stream has been established and initialized.
@param obj: data to be sent over the XML stream. This is usually an
object providing L{domish.IElement}, or serialized XML. See
L{xmlstream.XmlStream} for details.
"""
if self.xmlstream != None:
self.xmlstream.send(obj)
else:
self._packetQueue.append(obj)
def buildServiceManager(jid, password, strport):
"""
Constructs a pre-built L{ServiceManager}, using the specified strport
string.
"""
svc = ServiceManager(jid, password)
client_svc = jstrports.client(strport, svc.getFactory())
client_svc.setServiceParent(svc)
return svc
class Router(object):
"""
XMPP Server's Router.
A router connects the different components of the XMPP service and routes
messages between them based on the given routing table.
Connected components are trusted to have correct addressing in the
stanzas they offer for routing.
A route destination of L{None} adds a default route. Traffic for which no
specific route exists, will be routed to this default route.
@since: 8.2
@ivar routes: Routes based on the host part of JIDs. Maps host names to the
L{EventDispatcher<utility.EventDispatcher>}s that should
receive the traffic. A key of L{None} means the default
route.
@type routes: C{dict}
"""
def __init__(self):
self.routes = {}
def addRoute(self, destination, xs):
"""
Add a new route.
The passed XML Stream C{xs} will have an observer for all stanzas
added to route its outgoing traffic. In turn, traffic for
C{destination} will be passed to this stream.
@param destination: Destination of the route to be added as a host name
or L{None} for the default route.
@type destination: C{str} or L{None}.
@param xs: XML Stream to register the route for.
@type xs: L{EventDispatcher<utility.EventDispatcher>}.
"""
self.routes[destination] = xs
xs.addObserver('/*', self.route)
def removeRoute(self, destination, xs):
"""
Remove a route.
@param destination: Destination of the route that should be removed.
@type destination: C{str}.
@param xs: XML Stream to remove the route for.
@type xs: L{EventDispatcher<utility.EventDispatcher>}.
"""
xs.removeObserver('/*', self.route)
if (xs == self.routes[destination]):
del self.routes[destination]
def route(self, stanza):
"""
Route a stanza.
@param stanza: The stanza to be routed.
@type stanza: L{domish.Element}.
"""
destination = JID(stanza['to'])
log.msg("Routing to %s: %r" % (destination.full(), stanza.toXml()))
if destination.host in self.routes:
self.routes[destination.host].send(stanza)
else:
self.routes[None].send(stanza)
class XMPPComponentServerFactory(xmlstream.XmlStreamServerFactory):
"""
XMPP Component Server factory.
This factory accepts XMPP external component connections and makes
the router service route traffic for a component's bound domain
to that component.
@since: 8.2
"""
logTraffic = False
def __init__(self, router, secret='secret'):
self.router = router
self.secret = secret
def authenticatorFactory():
return ListenComponentAuthenticator(self.secret)
xmlstream.XmlStreamServerFactory.__init__(self, authenticatorFactory)
self.addBootstrap(xmlstream.STREAM_CONNECTED_EVENT,
self.onConnectionMade)
self.addBootstrap(xmlstream.STREAM_AUTHD_EVENT,
self.onAuthenticated)
self.serial = 0
def onConnectionMade(self, xs):
"""
Called when a component connection was made.
This enables traffic debugging on incoming streams.
"""
xs.serial = self.serial
self.serial += 1
def logDataIn(buf):
log.msg("RECV (%d): %r" % (xs.serial, buf))
def logDataOut(buf):
log.msg("SEND (%d): %r" % (xs.serial, buf))
if self.logTraffic:
xs.rawDataInFn = logDataIn
xs.rawDataOutFn = logDataOut
xs.addObserver(xmlstream.STREAM_ERROR_EVENT, self.onError)
def onAuthenticated(self, xs):
"""
Called when a component has successfully authenticated.
Add the component to the routing table and establish a handler
for a closed connection.
"""
destination = xs.thisEntity.host
self.router.addRoute(destination, xs)
xs.addObserver(xmlstream.STREAM_END_EVENT, self.onConnectionLost, 0,
destination, xs)
def onError(self, reason):
log.err(reason, "Stream Error")
def onConnectionLost(self, destination, xs, reason):
self.router.removeRoute(destination, xs)

View file

@ -0,0 +1,331 @@
# -*- test-case-name: twisted.words.test.test_jabbererror -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
XMPP Error support.
"""
from __future__ import absolute_import, division
import copy
from twisted.python.compat import unicode
from twisted.words.xish import domish
NS_XML = "http://www.w3.org/XML/1998/namespace"
NS_XMPP_STREAMS = "urn:ietf:params:xml:ns:xmpp-streams"
NS_XMPP_STANZAS = "urn:ietf:params:xml:ns:xmpp-stanzas"
STANZA_CONDITIONS = {
'bad-request': {'code': '400', 'type': 'modify'},
'conflict': {'code': '409', 'type': 'cancel'},
'feature-not-implemented': {'code': '501', 'type': 'cancel'},
'forbidden': {'code': '403', 'type': 'auth'},
'gone': {'code': '302', 'type': 'modify'},
'internal-server-error': {'code': '500', 'type': 'wait'},
'item-not-found': {'code': '404', 'type': 'cancel'},
'jid-malformed': {'code': '400', 'type': 'modify'},
'not-acceptable': {'code': '406', 'type': 'modify'},
'not-allowed': {'code': '405', 'type': 'cancel'},
'not-authorized': {'code': '401', 'type': 'auth'},
'payment-required': {'code': '402', 'type': 'auth'},
'recipient-unavailable': {'code': '404', 'type': 'wait'},
'redirect': {'code': '302', 'type': 'modify'},
'registration-required': {'code': '407', 'type': 'auth'},
'remote-server-not-found': {'code': '404', 'type': 'cancel'},
'remote-server-timeout': {'code': '504', 'type': 'wait'},
'resource-constraint': {'code': '500', 'type': 'wait'},
'service-unavailable': {'code': '503', 'type': 'cancel'},
'subscription-required': {'code': '407', 'type': 'auth'},
'undefined-condition': {'code': '500', 'type': None},
'unexpected-request': {'code': '400', 'type': 'wait'},
}
CODES_TO_CONDITIONS = {
'302': ('gone', 'modify'),
'400': ('bad-request', 'modify'),
'401': ('not-authorized', 'auth'),
'402': ('payment-required', 'auth'),
'403': ('forbidden', 'auth'),
'404': ('item-not-found', 'cancel'),
'405': ('not-allowed', 'cancel'),
'406': ('not-acceptable', 'modify'),
'407': ('registration-required', 'auth'),
'408': ('remote-server-timeout', 'wait'),
'409': ('conflict', 'cancel'),
'500': ('internal-server-error', 'wait'),
'501': ('feature-not-implemented', 'cancel'),
'502': ('service-unavailable', 'wait'),
'503': ('service-unavailable', 'cancel'),
'504': ('remote-server-timeout', 'wait'),
'510': ('service-unavailable', 'cancel'),
}
class BaseError(Exception):
"""
Base class for XMPP error exceptions.
@cvar namespace: The namespace of the C{error} element generated by
C{getElement}.
@type namespace: C{str}
@ivar condition: The error condition. The valid values are defined by
subclasses of L{BaseError}.
@type contition: C{str}
@ivar text: Optional text message to supplement the condition or application
specific condition.
@type text: C{unicode}
@ivar textLang: Identifier of the language used for the message in C{text}.
Values are as described in RFC 3066.
@type textLang: C{str}
@ivar appCondition: Application specific condition element, supplementing
the error condition in C{condition}.
@type appCondition: object providing L{domish.IElement}.
"""
namespace = None
def __init__(self, condition, text=None, textLang=None, appCondition=None):
Exception.__init__(self)
self.condition = condition
self.text = text
self.textLang = textLang
self.appCondition = appCondition
def __str__(self):
message = "%s with condition %r" % (self.__class__.__name__,
self.condition)
if self.text:
message += ': ' + self.text
return message
def getElement(self):
"""
Get XML representation from self.
The method creates an L{domish} representation of the
error data contained in this exception.
@rtype: L{domish.Element}
"""
error = domish.Element((None, 'error'))
error.addElement((self.namespace, self.condition))
if self.text:
text = error.addElement((self.namespace, 'text'),
content=self.text)
if self.textLang:
text[(NS_XML, 'lang')] = self.textLang
if self.appCondition:
error.addChild(self.appCondition)
return error
class StreamError(BaseError):
"""
Stream Error exception.
Refer to RFC 3920, section 4.7.3, for the allowed values for C{condition}.
"""
namespace = NS_XMPP_STREAMS
def getElement(self):
"""
Get XML representation from self.
Overrides the base L{BaseError.getElement} to make sure the returned
element is in the XML Stream namespace.
@rtype: L{domish.Element}
"""
from twisted.words.protocols.jabber.xmlstream import NS_STREAMS
error = BaseError.getElement(self)
error.uri = NS_STREAMS
return error
class StanzaError(BaseError):
"""
Stanza Error exception.
Refer to RFC 3920, section 9.3, for the allowed values for C{condition} and
C{type}.
@ivar type: The stanza error type. Gives a suggestion to the recipient
of the error on how to proceed.
@type type: C{str}
@ivar code: A numeric identifier for the error condition for backwards
compatibility with pre-XMPP Jabber implementations.
"""
namespace = NS_XMPP_STANZAS
def __init__(self, condition, type=None, text=None, textLang=None,
appCondition=None):
BaseError.__init__(self, condition, text, textLang, appCondition)
if type is None:
try:
type = STANZA_CONDITIONS[condition]['type']
except KeyError:
pass
self.type = type
try:
self.code = STANZA_CONDITIONS[condition]['code']
except KeyError:
self.code = None
self.children = []
self.iq = None
def getElement(self):
"""
Get XML representation from self.
Overrides the base L{BaseError.getElement} to make sure the returned
element has a C{type} attribute and optionally a legacy C{code}
attribute.
@rtype: L{domish.Element}
"""
error = BaseError.getElement(self)
error['type'] = self.type
if self.code:
error['code'] = self.code
return error
def toResponse(self, stanza):
"""
Construct error response stanza.
The C{stanza} is transformed into an error response stanza by
swapping the C{to} and C{from} addresses and inserting an error
element.
@note: This creates a shallow copy of the list of child elements of the
stanza. The child elements themselves are not copied themselves,
and references to their parent element will still point to the
original stanza element.
The serialization of an element does not use the reference to
its parent, so the typical use case of immediately sending out
the constructed error response is not affected.
@param stanza: the stanza to respond to
@type stanza: L{domish.Element}
"""
from twisted.words.protocols.jabber.xmlstream import toResponse
response = toResponse(stanza, stanzaType='error')
response.children = copy.copy(stanza.children)
response.addChild(self.getElement())
return response
def _parseError(error, errorNamespace):
"""
Parses an error element.
@param error: The error element to be parsed
@type error: L{domish.Element}
@param errorNamespace: The namespace of the elements that hold the error
condition and text.
@type errorNamespace: C{str}
@return: Dictionary with extracted error information. If present, keys
C{condition}, C{text}, C{textLang} have a string value,
and C{appCondition} has an L{domish.Element} value.
@rtype: C{dict}
"""
condition = None
text = None
textLang = None
appCondition = None
for element in error.elements():
if element.uri == errorNamespace:
if element.name == 'text':
text = unicode(element)
textLang = element.getAttribute((NS_XML, 'lang'))
else:
condition = element.name
else:
appCondition = element
return {
'condition': condition,
'text': text,
'textLang': textLang,
'appCondition': appCondition,
}
def exceptionFromStreamError(element):
"""
Build an exception object from a stream error.
@param element: the stream error
@type element: L{domish.Element}
@return: the generated exception object
@rtype: L{StreamError}
"""
error = _parseError(element, NS_XMPP_STREAMS)
exception = StreamError(error['condition'],
error['text'],
error['textLang'],
error['appCondition'])
return exception
def exceptionFromStanza(stanza):
"""
Build an exception object from an error stanza.
@param stanza: the error stanza
@type stanza: L{domish.Element}
@return: the generated exception object
@rtype: L{StanzaError}
"""
children = []
condition = text = textLang = appCondition = type = code = None
for element in stanza.elements():
if element.name == 'error' and element.uri == stanza.uri:
code = element.getAttribute('code')
type = element.getAttribute('type')
error = _parseError(element, NS_XMPP_STANZAS)
condition = error['condition']
text = error['text']
textLang = error['textLang']
appCondition = error['appCondition']
if not condition and code:
condition, type = CODES_TO_CONDITIONS[code]
text = unicode(stanza.error)
else:
children.append(element)
if condition is None:
# TODO: raise exception instead?
return StanzaError(None)
exception = StanzaError(condition, type, text, textLang, appCondition)
exception.children = children
exception.stanza = stanza
return exception

View file

@ -0,0 +1,201 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Public Jabber Interfaces.
"""
from zope.interface import Attribute, Interface
class IInitializer(Interface):
"""
Interface for XML stream initializers.
Initializers perform a step in getting the XML stream ready to be
used for the exchange of XML stanzas.
"""
class IInitiatingInitializer(IInitializer):
"""
Interface for XML stream initializers for the initiating entity.
"""
xmlstream = Attribute("""The associated XML stream""")
def initialize():
"""
Initiate the initialization step.
May return a deferred when the initialization is done asynchronously.
"""
class IIQResponseTracker(Interface):
"""
IQ response tracker interface.
The XMPP stanza C{iq} has a request-response nature that fits
naturally with deferreds. You send out a request and when the response
comes back a deferred is fired.
The L{twisted.words.protocols.jabber.client.IQ} class implements a C{send}
method that returns a deferred. This deferred is put in a dictionary that
is kept in an L{XmlStream} object, keyed by the request stanzas C{id}
attribute.
An object providing this interface (usually an instance of L{XmlStream}),
keeps the said dictionary and sets observers on the iq stanzas of type
C{result} and C{error} and lets the callback fire the associated deferred.
"""
iqDeferreds = Attribute("Dictionary of deferreds waiting for an iq "
"response")
class IXMPPHandler(Interface):
"""
Interface for XMPP protocol handlers.
Objects that provide this interface can be added to a stream manager to
handle of (part of) an XMPP extension protocol.
"""
parent = Attribute("""XML stream manager for this handler""")
xmlstream = Attribute("""The managed XML stream""")
def setHandlerParent(parent):
"""
Set the parent of the handler.
@type parent: L{IXMPPHandlerCollection}
"""
def disownHandlerParent(parent):
"""
Remove the parent of the handler.
@type parent: L{IXMPPHandlerCollection}
"""
def makeConnection(xs):
"""
A connection over the underlying transport of the XML stream has been
established.
At this point, no traffic has been exchanged over the XML stream
given in C{xs}.
This should setup L{xmlstream} and call L{connectionMade}.
@type xs:
L{twisted.words.protocols.jabber.xmlstream.XmlStream}
"""
def connectionMade():
"""
Called after a connection has been established.
This method can be used to change properties of the XML Stream, its
authenticator or the stream manager prior to stream initialization
(including authentication).
"""
def connectionInitialized():
"""
The XML stream has been initialized.
At this point, authentication was successful, and XML stanzas can be
exchanged over the XML stream L{xmlstream}. This method can be
used to setup observers for incoming stanzas.
"""
def connectionLost(reason):
"""
The XML stream has been closed.
Subsequent use of C{parent.send} will result in data being queued
until a new connection has been established.
@type reason: L{twisted.python.failure.Failure}
"""
class IXMPPHandlerCollection(Interface):
"""
Collection of handlers.
Contain several handlers and manage their connection.
"""
def __iter__():
"""
Get an iterator over all child handlers.
"""
def addHandler(handler):
"""
Add a child handler.
@type handler: L{IXMPPHandler}
"""
def removeHandler(handler):
"""
Remove a child handler.
@type handler: L{IXMPPHandler}
"""
class IService(Interface):
"""
External server-side component service interface.
Services that provide this interface can be added to L{ServiceManager} to
implement (part of) the functionality of the server-side component.
"""
def componentConnected(xs):
"""
Parent component has established a connection.
At this point, authentication was successful, and XML stanzas
can be exchanged over the XML stream C{xs}. This method can be used
to setup observers for incoming stanzas.
@param xs: XML Stream that represents the established connection.
@type xs: L{xmlstream.XmlStream}
"""
def componentDisconnected():
"""
Parent component has lost the connection to the Jabber server.
Subsequent use of C{self.parent.send} will result in data being
queued until a new connection has been established.
"""
def transportConnected(xs):
"""
Parent component has established a connection over the underlying
transport.
At this point, no traffic has been exchanged over the XML stream. This
method can be used to change properties of the XML Stream (in C{xs}),
the service manager or it's authenticator prior to stream
initialization (including authentication).
"""

View file

@ -0,0 +1,253 @@
# -*- test-case-name: twisted.words.test.test_jabberjid -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Jabber Identifier support.
This module provides an object to represent Jabber Identifiers (JIDs) and
parse string representations into them with proper checking for illegal
characters, case folding and canonicalisation through L{stringprep<twisted.words.protocols.jabber.xmpp_stringprep>}.
"""
from twisted.python.compat import _PY3, unicode
from twisted.words.protocols.jabber.xmpp_stringprep import nodeprep, resourceprep, nameprep
class InvalidFormat(Exception):
"""
The given string could not be parsed into a valid Jabber Identifier (JID).
"""
def parse(jidstring):
"""
Parse given JID string into its respective parts and apply stringprep.
@param jidstring: string representation of a JID.
@type jidstring: L{unicode}
@return: tuple of (user, host, resource), each of type L{unicode} as
the parsed and stringprep'd parts of the given JID. If the
given string did not have a user or resource part, the respective
field in the tuple will hold L{None}.
@rtype: L{tuple}
"""
user = None
host = None
resource = None
# Search for delimiters
user_sep = jidstring.find("@")
res_sep = jidstring.find("/")
if user_sep == -1:
if res_sep == -1:
# host
host = jidstring
else:
# host/resource
host = jidstring[0:res_sep]
resource = jidstring[res_sep + 1:] or None
else:
if res_sep == -1:
# user@host
user = jidstring[0:user_sep] or None
host = jidstring[user_sep + 1:]
else:
if user_sep < res_sep:
# user@host/resource
user = jidstring[0:user_sep] or None
host = jidstring[user_sep + 1:user_sep + (res_sep - user_sep)]
resource = jidstring[res_sep + 1:] or None
else:
# host/resource (with an @ in resource)
host = jidstring[0:res_sep]
resource = jidstring[res_sep + 1:] or None
return prep(user, host, resource)
def prep(user, host, resource):
"""
Perform stringprep on all JID fragments.
@param user: The user part of the JID.
@type user: L{unicode}
@param host: The host part of the JID.
@type host: L{unicode}
@param resource: The resource part of the JID.
@type resource: L{unicode}
@return: The given parts with stringprep applied.
@rtype: L{tuple}
"""
if user:
try:
user = nodeprep.prepare(unicode(user))
except UnicodeError:
raise InvalidFormat("Invalid character in username")
else:
user = None
if not host:
raise InvalidFormat("Server address required.")
else:
try:
host = nameprep.prepare(unicode(host))
except UnicodeError:
raise InvalidFormat("Invalid character in hostname")
if resource:
try:
resource = resourceprep.prepare(unicode(resource))
except UnicodeError:
raise InvalidFormat("Invalid character in resource")
else:
resource = None
return (user, host, resource)
__internJIDs = {}
def internJID(jidstring):
"""
Return interned JID.
@rtype: L{JID}
"""
if jidstring in __internJIDs:
return __internJIDs[jidstring]
else:
j = JID(jidstring)
__internJIDs[jidstring] = j
return j
class JID(object):
"""
Represents a stringprep'd Jabber ID.
JID objects are hashable so they can be used in sets and as keys in
dictionaries.
"""
def __init__(self, str=None, tuple=None):
if not (str or tuple):
raise RuntimeError("You must provide a value for either 'str' or "
"'tuple' arguments.")
if str:
user, host, res = parse(str)
else:
user, host, res = prep(*tuple)
self.user = user
self.host = host
self.resource = res
def userhost(self):
"""
Extract the bare JID as a unicode string.
A bare JID does not have a resource part, so this returns either
C{user@host} or just C{host}.
@rtype: L{unicode}
"""
if self.user:
return u"%s@%s" % (self.user, self.host)
else:
return self.host
def userhostJID(self):
"""
Extract the bare JID.
A bare JID does not have a resource part, so this returns a
L{JID} object representing either C{user@host} or just C{host}.
If the object this method is called upon doesn't have a resource
set, it will return itself. Otherwise, the bare JID object will
be created, interned using L{internJID}.
@rtype: L{JID}
"""
if self.resource:
return internJID(self.userhost())
else:
return self
def full(self):
"""
Return the string representation of this JID.
@rtype: L{unicode}
"""
if self.user:
if self.resource:
return u"%s@%s/%s" % (self.user, self.host, self.resource)
else:
return u"%s@%s" % (self.user, self.host)
else:
if self.resource:
return u"%s/%s" % (self.host, self.resource)
else:
return self.host
def __eq__(self, other):
"""
Equality comparison.
L{JID}s compare equal if their user, host and resource parts all
compare equal. When comparing against instances of other types, it
uses the default comparison.
"""
if isinstance(other, JID):
return (self.user == other.user and
self.host == other.host and
self.resource == other.resource)
else:
return NotImplemented
def __ne__(self, other):
"""
Inequality comparison.
This negates L{__eq__} for comparison with JIDs and uses the default
comparison for other types.
"""
result = self.__eq__(other)
if result is NotImplemented:
return result
else:
return not result
def __hash__(self):
"""
Calculate hash.
L{JID}s with identical constituent user, host and resource parts have
equal hash values. In combination with the comparison defined on JIDs,
this allows for using L{JID}s in sets and as dictionary keys.
"""
return hash((self.user, self.host, self.resource))
def __unicode__(self):
"""
Get unicode representation.
Return the string representation of this JID as a unicode string.
@see: L{full}
"""
return self.full()
if _PY3:
__str__ = __unicode__
def __repr__(self):
"""
Get object representation.
Returns a string that would create a new JID object that compares equal
to this one.
"""
return 'JID(%r)' % self.full()

View file

@ -0,0 +1,33 @@
# -*- test-case-name: twisted.words.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
""" A temporary placeholder for client-capable strports, until we
sufficient use cases get identified """
from __future__ import absolute_import, division
from twisted.internet.endpoints import _parse
def _parseTCPSSL(factory, domain, port):
""" For the moment, parse TCP or SSL connections the same """
return (domain, int(port), factory), {}
def _parseUNIX(factory, address):
return (address, factory), {}
_funcs = { "tcp" : _parseTCPSSL,
"unix" : _parseUNIX,
"ssl" : _parseTCPSSL }
def parse(description, factory):
args, kw = _parse(description)
return (args[0].upper(),) + _funcs[args[0]](factory, *args[1:], **kw)
def client(description, factory):
from twisted.application import internet
name, args, kw = parse(description, factory)
return getattr(internet, name + 'Client')(*args, **kw)

View file

@ -0,0 +1,233 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
XMPP-specific SASL profile.
"""
from __future__ import absolute_import, division
from base64 import b64decode, b64encode
import re
from twisted.internet import defer
from twisted.python.compat import unicode
from twisted.words.protocols.jabber import sasl_mechanisms, xmlstream
from twisted.words.xish import domish
NS_XMPP_SASL = 'urn:ietf:params:xml:ns:xmpp-sasl'
def get_mechanisms(xs):
"""
Parse the SASL feature to extract the available mechanism names.
"""
mechanisms = []
for element in xs.features[(NS_XMPP_SASL, 'mechanisms')].elements():
if element.name == 'mechanism':
mechanisms.append(unicode(element))
return mechanisms
class SASLError(Exception):
"""
SASL base exception.
"""
class SASLNoAcceptableMechanism(SASLError):
"""
The server did not present an acceptable SASL mechanism.
"""
class SASLAuthError(SASLError):
"""
SASL Authentication failed.
"""
def __init__(self, condition=None):
self.condition = condition
def __str__(self):
return "SASLAuthError with condition %r" % self.condition
class SASLIncorrectEncodingError(SASLError):
"""
SASL base64 encoding was incorrect.
RFC 3920 specifies that any characters not in the base64 alphabet
and padding characters present elsewhere than at the end of the string
MUST be rejected. See also L{fromBase64}.
This exception is raised whenever the encoded string does not adhere
to these additional restrictions or when the decoding itself fails.
The recommended behaviour for so-called receiving entities (like servers in
client-to-server connections, see RFC 3920 for terminology) is to fail the
SASL negotiation with a C{'incorrect-encoding'} condition. For initiating
entities, one should assume the receiving entity to be either buggy or
malevolent. The stream should be terminated and reconnecting is not
advised.
"""
base64Pattern = re.compile("^[0-9A-Za-z+/]*[0-9A-Za-z+/=]{,2}$")
def fromBase64(s):
"""
Decode base64 encoded string.
This helper performs regular decoding of a base64 encoded string, but also
rejects any characters that are not in the base64 alphabet and padding
occurring elsewhere from the last or last two characters, as specified in
section 14.9 of RFC 3920. This safeguards against various attack vectors
among which the creation of a covert channel that "leaks" information.
"""
if base64Pattern.match(s) is None:
raise SASLIncorrectEncodingError()
try:
return b64decode(s)
except Exception as e:
raise SASLIncorrectEncodingError(str(e))
class SASLInitiatingInitializer(xmlstream.BaseFeatureInitiatingInitializer):
"""
Stream initializer that performs SASL authentication.
The supported mechanisms by this initializer are C{DIGEST-MD5}, C{PLAIN}
and C{ANONYMOUS}. The C{ANONYMOUS} SASL mechanism is used when the JID, set
on the authenticator, does not have a localpart (username), requesting an
anonymous session where the username is generated by the server.
Otherwise, C{DIGEST-MD5} and C{PLAIN} are attempted, in that order.
"""
feature = (NS_XMPP_SASL, 'mechanisms')
_deferred = None
def setMechanism(self):
"""
Select and setup authentication mechanism.
Uses the authenticator's C{jid} and C{password} attribute for the
authentication credentials. If no supported SASL mechanisms are
advertized by the receiving party, a failing deferred is returned with
a L{SASLNoAcceptableMechanism} exception.
"""
jid = self.xmlstream.authenticator.jid
password = self.xmlstream.authenticator.password
mechanisms = get_mechanisms(self.xmlstream)
if jid.user is not None:
if 'DIGEST-MD5' in mechanisms:
self.mechanism = sasl_mechanisms.DigestMD5('xmpp', jid.host, None,
jid.user, password)
elif 'PLAIN' in mechanisms:
self.mechanism = sasl_mechanisms.Plain(None, jid.user, password)
else:
raise SASLNoAcceptableMechanism()
else:
if 'ANONYMOUS' in mechanisms:
self.mechanism = sasl_mechanisms.Anonymous()
else:
raise SASLNoAcceptableMechanism()
def start(self):
"""
Start SASL authentication exchange.
"""
self.setMechanism()
self._deferred = defer.Deferred()
self.xmlstream.addObserver('/challenge', self.onChallenge)
self.xmlstream.addOnetimeObserver('/success', self.onSuccess)
self.xmlstream.addOnetimeObserver('/failure', self.onFailure)
self.sendAuth(self.mechanism.getInitialResponse())
return self._deferred
def sendAuth(self, data=None):
"""
Initiate authentication protocol exchange.
If an initial client response is given in C{data}, it will be
sent along.
@param data: initial client response.
@type data: C{str} or L{None}.
"""
auth = domish.Element((NS_XMPP_SASL, 'auth'))
auth['mechanism'] = self.mechanism.name
if data is not None:
auth.addContent(b64encode(data).decode('ascii') or u'=')
self.xmlstream.send(auth)
def sendResponse(self, data=b''):
"""
Send response to a challenge.
@param data: client response.
@type data: L{bytes}.
"""
response = domish.Element((NS_XMPP_SASL, 'response'))
if data:
response.addContent(b64encode(data).decode('ascii'))
self.xmlstream.send(response)
def onChallenge(self, element):
"""
Parse challenge and send response from the mechanism.
@param element: the challenge protocol element.
@type element: L{domish.Element}.
"""
try:
challenge = fromBase64(unicode(element))
except SASLIncorrectEncodingError:
self._deferred.errback()
else:
self.sendResponse(self.mechanism.getResponse(challenge))
def onSuccess(self, success):
"""
Clean up observers, reset the XML stream and send a new header.
@param success: the success protocol element. For now unused, but
could hold additional data.
@type success: L{domish.Element}
"""
self.xmlstream.removeObserver('/challenge', self.onChallenge)
self.xmlstream.removeObserver('/failure', self.onFailure)
self.xmlstream.reset()
self.xmlstream.sendHeader()
self._deferred.callback(xmlstream.Reset)
def onFailure(self, failure):
"""
Clean up observers, parse the failure and errback the deferred.
@param failure: the failure protocol element. Holds details on
the error condition.
@type failure: L{domish.Element}
"""
self.xmlstream.removeObserver('/challenge', self.onChallenge)
self.xmlstream.removeObserver('/success', self.onSuccess)
try:
condition = failure.firstChildElement().name
except AttributeError:
condition = None
self._deferred.errback(SASLAuthError(condition))

View file

@ -0,0 +1,293 @@
# -*- test-case-name: twisted.words.test.test_jabbersaslmechanisms -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Protocol agnostic implementations of SASL authentication mechanisms.
"""
from __future__ import absolute_import, division
import binascii, random, time, os
from hashlib import md5
from zope.interface import Interface, Attribute, implementer
from twisted.python.compat import iteritems, networkString
class ISASLMechanism(Interface):
name = Attribute("""Common name for the SASL Mechanism.""")
def getInitialResponse():
"""
Get the initial client response, if defined for this mechanism.
@return: initial client response string.
@rtype: C{str}.
"""
def getResponse(challenge):
"""
Get the response to a server challenge.
@param challenge: server challenge.
@type challenge: C{str}.
@return: client response.
@rtype: C{str}.
"""
@implementer(ISASLMechanism)
class Anonymous(object):
"""
Implements the ANONYMOUS SASL authentication mechanism.
This mechanism is defined in RFC 2245.
"""
name = 'ANONYMOUS'
def getInitialResponse(self):
return None
@implementer(ISASLMechanism)
class Plain(object):
"""
Implements the PLAIN SASL authentication mechanism.
The PLAIN SASL authentication mechanism is defined in RFC 2595.
"""
name = 'PLAIN'
def __init__(self, authzid, authcid, password):
"""
@param authzid: The authorization identity.
@type authzid: L{unicode}
@param authcid: The authentication identity.
@type authcid: L{unicode}
@param password: The plain-text password.
@type password: L{unicode}
"""
self.authzid = authzid or u''
self.authcid = authcid or u''
self.password = password or u''
def getInitialResponse(self):
return (self.authzid.encode('utf-8') + b"\x00" +
self.authcid.encode('utf-8') + b"\x00" +
self.password.encode('utf-8'))
@implementer(ISASLMechanism)
class DigestMD5(object):
"""
Implements the DIGEST-MD5 SASL authentication mechanism.
The DIGEST-MD5 SASL authentication mechanism is defined in RFC 2831.
"""
name = 'DIGEST-MD5'
def __init__(self, serv_type, host, serv_name, username, password):
"""
@param serv_type: An indication of what kind of server authentication
is being attempted against. For example, C{u"xmpp"}.
@type serv_type: C{unicode}
@param host: The authentication hostname. Also known as the realm.
This is used as a scope to help select the right credentials.
@type host: C{unicode}
@param serv_name: An additional identifier for the server.
@type serv_name: C{unicode}
@param username: The authentication username to use to respond to a
challenge.
@type username: C{unicode}
@param username: The authentication password to use to respond to a
challenge.
@type password: C{unicode}
"""
self.username = username
self.password = password
self.defaultRealm = host
self.digest_uri = u'%s/%s' % (serv_type, host)
if serv_name is not None:
self.digest_uri += u'/%s' % (serv_name,)
def getInitialResponse(self):
return None
def getResponse(self, challenge):
directives = self._parse(challenge)
# Compat for implementations that do not send this along with
# a successful authentication.
if b'rspauth' in directives:
return b''
charset = directives[b'charset'].decode('ascii')
try:
realm = directives[b'realm']
except KeyError:
realm = self.defaultRealm.encode(charset)
return self._genResponse(charset,
realm,
directives[b'nonce'])
def _parse(self, challenge):
"""
Parses the server challenge.
Splits the challenge into a dictionary of directives with values.
@return: challenge directives and their values.
@rtype: C{dict} of C{str} to C{str}.
"""
s = challenge
paramDict = {}
cur = 0
remainingParams = True
while remainingParams:
# Parse a param. We can't just split on commas, because there can
# be some commas inside (quoted) param values, e.g.:
# qop="auth,auth-int"
middle = s.index(b"=", cur)
name = s[cur:middle].lstrip()
middle += 1
if s[middle:middle+1] == b'"':
middle += 1
end = s.index(b'"', middle)
value = s[middle:end]
cur = s.find(b',', end) + 1
if cur == 0:
remainingParams = False
else:
end = s.find(b',', middle)
if end == -1:
value = s[middle:].rstrip()
remainingParams = False
else:
value = s[middle:end].rstrip()
cur = end + 1
paramDict[name] = value
for param in (b'qop', b'cipher'):
if param in paramDict:
paramDict[param] = paramDict[param].split(b',')
return paramDict
def _unparse(self, directives):
"""
Create message string from directives.
@param directives: dictionary of directives (names to their values).
For certain directives, extra quotes are added, as
needed.
@type directives: C{dict} of C{str} to C{str}
@return: message string.
@rtype: C{str}.
"""
directive_list = []
for name, value in iteritems(directives):
if name in (b'username', b'realm', b'cnonce',
b'nonce', b'digest-uri', b'authzid', b'cipher'):
directive = name + b'=' + value
else:
directive = name + b'=' + value
directive_list.append(directive)
return b','.join(directive_list)
def _calculateResponse(self, cnonce, nc, nonce,
username, password, realm, uri):
"""
Calculates response with given encoded parameters.
@return: The I{response} field of a response to a Digest-MD5 challenge
of the given parameters.
@rtype: L{bytes}
"""
def H(s):
return md5(s).digest()
def HEX(n):
return binascii.b2a_hex(n)
def KD(k, s):
return H(k + b':' + s)
a1 = (H(username + b":" + realm + b":" + password) + b":" +
nonce + b":" +
cnonce)
a2 = b"AUTHENTICATE:" + uri
response = HEX(KD(HEX(H(a1)),
nonce + b":" + nc + b":" + cnonce + b":" +
b"auth" + b":" + HEX(H(a2))))
return response
def _genResponse(self, charset, realm, nonce):
"""
Generate response-value.
Creates a response to a challenge according to section 2.1.2.1 of
RFC 2831 using the C{charset}, C{realm} and C{nonce} directives
from the challenge.
"""
try:
username = self.username.encode(charset)
password = self.password.encode(charset)
digest_uri = self.digest_uri.encode(charset)
except UnicodeError:
# TODO - add error checking
raise
nc = networkString('%08x' % (1,)) # TODO: support subsequent auth.
cnonce = self._gen_nonce()
qop = b'auth'
# TODO - add support for authzid
response = self._calculateResponse(cnonce, nc, nonce,
username, password, realm,
digest_uri)
directives = {b'username': username,
b'realm' : realm,
b'nonce' : nonce,
b'cnonce' : cnonce,
b'nc' : nc,
b'qop' : qop,
b'digest-uri': digest_uri,
b'response': response,
b'charset': charset.encode('ascii')}
return self._unparse(directives)
def _gen_nonce(self):
nonceString = "%f:%f:%d" % (random.random(), time.time(), os.getpid())
nonceBytes = networkString(nonceString)
return md5(nonceBytes).hexdigest().encode('ascii')

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,244 @@
# -*- test-case-name: twisted.words.test.test_jabberxmppstringprep -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from encodings import idna
from itertools import chain
import stringprep
# We require Unicode version 3.2.
from unicodedata import ucd_3_2_0 as unicodedata
from twisted.python.compat import unichr
from twisted.python.deprecate import deprecatedModuleAttribute
from incremental import Version
from zope.interface import Interface, implementer
crippled = False
deprecatedModuleAttribute(
Version("Twisted", 13, 1, 0),
"crippled is always False",
__name__,
"crippled")
class ILookupTable(Interface):
"""
Interface for character lookup classes.
"""
def lookup(c):
"""
Return whether character is in this table.
"""
class IMappingTable(Interface):
"""
Interface for character mapping classes.
"""
def map(c):
"""
Return mapping for character.
"""
@implementer(ILookupTable)
class LookupTableFromFunction:
def __init__(self, in_table_function):
self.lookup = in_table_function
@implementer(ILookupTable)
class LookupTable:
def __init__(self, table):
self._table = table
def lookup(self, c):
return c in self._table
@implementer(IMappingTable)
class MappingTableFromFunction:
def __init__(self, map_table_function):
self.map = map_table_function
@implementer(IMappingTable)
class EmptyMappingTable:
def __init__(self, in_table_function):
self._in_table_function = in_table_function
def map(self, c):
if self._in_table_function(c):
return None
else:
return c
class Profile:
def __init__(self, mappings=[], normalize=True, prohibiteds=[],
check_unassigneds=True, check_bidi=True):
self.mappings = mappings
self.normalize = normalize
self.prohibiteds = prohibiteds
self.do_check_unassigneds = check_unassigneds
self.do_check_bidi = check_bidi
def prepare(self, string):
result = self.map(string)
if self.normalize:
result = unicodedata.normalize("NFKC", result)
self.check_prohibiteds(result)
if self.do_check_unassigneds:
self.check_unassigneds(result)
if self.do_check_bidi:
self.check_bidirectionals(result)
return result
def map(self, string):
result = []
for c in string:
result_c = c
for mapping in self.mappings:
result_c = mapping.map(c)
if result_c != c:
break
if result_c is not None:
result.append(result_c)
return u"".join(result)
def check_prohibiteds(self, string):
for c in string:
for table in self.prohibiteds:
if table.lookup(c):
raise UnicodeError("Invalid character %s" % repr(c))
def check_unassigneds(self, string):
for c in string:
if stringprep.in_table_a1(c):
raise UnicodeError("Unassigned code point %s" % repr(c))
def check_bidirectionals(self, string):
found_LCat = False
found_RandALCat = False
for c in string:
if stringprep.in_table_d1(c):
found_RandALCat = True
if stringprep.in_table_d2(c):
found_LCat = True
if found_LCat and found_RandALCat:
raise UnicodeError("Violation of BIDI Requirement 2")
if found_RandALCat and not (stringprep.in_table_d1(string[0]) and
stringprep.in_table_d1(string[-1])):
raise UnicodeError("Violation of BIDI Requirement 3")
class NamePrep:
""" Implements preparation of internationalized domain names.
This class implements preparing internationalized domain names using the
rules defined in RFC 3491, section 4 (Conversion operations).
We do not perform step 4 since we deal with unicode representations of
domain names and do not convert from or to ASCII representations using
punycode encoding. When such a conversion is needed, the C{idna} standard
library provides the C{ToUnicode()} and C{ToASCII()} functions. Note that
C{idna} itself assumes UseSTD3ASCIIRules to be false.
The following steps are performed by C{prepare()}:
- Split the domain name in labels at the dots (RFC 3490, 3.1)
- Apply nameprep proper on each label (RFC 3491)
- Enforce the restrictions on ASCII characters in host names by
assuming STD3ASCIIRules to be true. (STD 3)
- Rejoin the labels using the label separator U+002E (full stop).
"""
# Prohibited characters.
prohibiteds = [unichr(n) for n in chain(range(0x00, 0x2c + 1),
range(0x2e, 0x2f + 1),
range(0x3a, 0x40 + 1),
range(0x5b, 0x60 + 1),
range(0x7b, 0x7f + 1))]
def prepare(self, string):
result = []
labels = idna.dots.split(string)
if labels and len(labels[-1]) == 0:
trailing_dot = u'.'
del labels[-1]
else:
trailing_dot = u''
for label in labels:
result.append(self.nameprep(label))
return u".".join(result) + trailing_dot
def check_prohibiteds(self, string):
for c in string:
if c in self.prohibiteds:
raise UnicodeError("Invalid character %s" % repr(c))
def nameprep(self, label):
label = idna.nameprep(label)
self.check_prohibiteds(label)
if label[0] == u'-':
raise UnicodeError("Invalid leading hyphen-minus")
if label[-1] == u'-':
raise UnicodeError("Invalid trailing hyphen-minus")
return label
C_11 = LookupTableFromFunction(stringprep.in_table_c11)
C_12 = LookupTableFromFunction(stringprep.in_table_c12)
C_21 = LookupTableFromFunction(stringprep.in_table_c21)
C_22 = LookupTableFromFunction(stringprep.in_table_c22)
C_3 = LookupTableFromFunction(stringprep.in_table_c3)
C_4 = LookupTableFromFunction(stringprep.in_table_c4)
C_5 = LookupTableFromFunction(stringprep.in_table_c5)
C_6 = LookupTableFromFunction(stringprep.in_table_c6)
C_7 = LookupTableFromFunction(stringprep.in_table_c7)
C_8 = LookupTableFromFunction(stringprep.in_table_c8)
C_9 = LookupTableFromFunction(stringprep.in_table_c9)
B_1 = EmptyMappingTable(stringprep.in_table_b1)
B_2 = MappingTableFromFunction(stringprep.map_table_b2)
nodeprep = Profile(mappings=[B_1, B_2],
prohibiteds=[C_11, C_12, C_21, C_22,
C_3, C_4, C_5, C_6, C_7, C_8, C_9,
LookupTable([u'"', u'&', u"'", u'/',
u':', u'<', u'>', u'@'])])
resourceprep = Profile(mappings=[B_1,],
prohibiteds=[C_12, C_21, C_22,
C_3, C_4, C_5, C_6, C_7, C_8, C_9])
nameprep = NamePrep()

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,74 @@
# -*- test-case-name: twisted.words.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Shiny new words service maker
"""
import sys, socket
from twisted.application import strports
from twisted.application.service import MultiService
from twisted.python import usage
from twisted import plugin
from twisted.words import iwords, service
from twisted.cred import checkers, credentials, portal, strcred
class Options(usage.Options, strcred.AuthOptionMixin):
supportedInterfaces = [credentials.IUsernamePassword]
optParameters = [
('hostname', None, socket.gethostname(),
'Name of this server; purely an informative')]
compData = usage.Completions(multiUse=["group"])
interfacePlugins = {}
plg = None
for plg in plugin.getPlugins(iwords.IProtocolPlugin):
assert plg.name not in interfacePlugins
interfacePlugins[plg.name] = plg
optParameters.append((
plg.name + '-port',
None, None,
'strports description of the port to bind for the ' + plg.name + ' server'))
del plg
def __init__(self, *a, **kw):
usage.Options.__init__(self, *a, **kw)
self['groups'] = []
def opt_group(self, name):
"""Specify a group which should exist
"""
self['groups'].append(name.decode(sys.stdin.encoding))
def opt_passwd(self, filename):
"""
Name of a passwd-style file. (This is for
backwards-compatibility only; you should use the --auth
command instead.)
"""
self.addChecker(checkers.FilePasswordDB(filename))
def makeService(config):
credCheckers = config.get('credCheckers', [])
wordsRealm = service.InMemoryWordsRealm(config['hostname'])
wordsPortal = portal.Portal(wordsRealm, credCheckers)
msvc = MultiService()
# XXX Attribute lookup on config is kind of bad - hrm.
for plgName in config.interfacePlugins:
port = config.get(plgName + '-port')
if port is not None:
factory = config.interfacePlugins[plgName].getFactory(wordsRealm, wordsPortal)
svc = strports.service(port, factory)
svc.setServiceParent(msvc)
# This is bogus. createGroup is async. makeService must be
# allowed to return a Deferred or some crap.
for g in config['groups']:
wordsRealm.createGroup(g)
return msvc

View file

@ -0,0 +1 @@
"Words tests"

View file

@ -0,0 +1,68 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.im.basechat}.
"""
from twisted.trial import unittest
from twisted.words.im import basechat, basesupport
class ChatUITests(unittest.TestCase):
"""
Tests for the L{basechat.ChatUI} chat client.
"""
def setUp(self):
self.ui = basechat.ChatUI()
self.account = basesupport.AbstractAccount("fooAccount", False, "foo",
"password", "host", "port")
self.person = basesupport.AbstractPerson("foo", self.account)
def test_contactChangedNickNoKey(self):
"""
L{basechat.ChatUI.contactChangedNick} on an
L{twisted.words.im.interfaces.IPerson} who doesn't have an account
associated with the L{basechat.ChatUI} instance has no effect.
"""
self.assertEqual(self.person.name, "foo")
self.assertEqual(self.person.account, self.account)
self.ui.contactChangedNick(self.person, "bar")
self.assertEqual(self.person.name, "foo")
self.assertEqual(self.person.account, self.account)
def test_contactChangedNickNoConversation(self):
"""
L{basechat.ChatUI.contactChangedNick} changes the name for an
L{twisted.words.im.interfaces.IPerson}.
"""
self.ui.persons[self.person.name, self.person.account] = self.person
self.assertEqual(self.person.name, "foo")
self.assertEqual(self.person.account, self.account)
self.ui.contactChangedNick(self.person, "bar")
self.assertEqual(self.person.name, "bar")
self.assertEqual(self.person.account, self.account)
def test_contactChangedNickHasConversation(self):
"""
If an L{twisted.words.im.interfaces.IPerson} is in a
L{basechat.Conversation}, L{basechat.ChatUI.contactChangedNick} causes a
name change for that person in both the L{basechat.Conversation} and the
L{basechat.ChatUI}.
"""
self.ui.persons[self.person.name, self.person.account] = self.person
conversation = basechat.Conversation(self.person, self.ui)
self.ui.conversations[self.person] = conversation
self.assertEqual(self.person.name, "foo")
self.assertEqual(self.person.account, self.account)
self.ui.contactChangedNick(self.person, "bar")
self.assertEqual(self.person.name, "bar")
self.assertEqual(self.person.account, self.account)

View file

@ -0,0 +1,97 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.words.im import basesupport
from twisted.internet import error, defer
class DummyAccount(basesupport.AbstractAccount):
"""
An account object that will do nothing when asked to start to log on.
"""
loginHasFailed = False
loginCallbackCalled = False
def _startLogOn(self, *args):
"""
Set self.loginDeferred to the same as the deferred returned, allowing a
testcase to .callback or .errback.
@return: A deferred.
"""
self.loginDeferred = defer.Deferred()
return self.loginDeferred
def _loginFailed(self, result):
self.loginHasFailed = True
return basesupport.AbstractAccount._loginFailed(self, result)
def _cb_logOn(self, result):
self.loginCallbackCalled = True
return basesupport.AbstractAccount._cb_logOn(self, result)
class DummyUI(object):
"""
Provide just the interface required to be passed to AbstractAccount.logOn.
"""
clientRegistered = False
def registerAccountClient(self, result):
self.clientRegistered = True
class ClientMsgTests(unittest.TestCase):
def makeUI(self):
return DummyUI()
def makeAccount(self):
return DummyAccount('la', False, 'la', None, 'localhost', 6667)
def test_connect(self):
"""
Test that account.logOn works, and it calls the right callback when a
connection is established.
"""
account = self.makeAccount()
ui = self.makeUI()
d = account.logOn(ui)
account.loginDeferred.callback(None)
def check(result):
self.assertFalse(account.loginHasFailed,
"Login shouldn't have failed")
self.assertTrue(account.loginCallbackCalled,
"We should be logged in")
d.addCallback(check)
return d
def test_failedConnect(self):
"""
Test that account.logOn works, and it calls the right callback when a
connection is established.
"""
account = self.makeAccount()
ui = self.makeUI()
d = account.logOn(ui)
account.loginDeferred.errback(Exception())
def err(reason):
self.assertTrue(account.loginHasFailed, "Login should have failed")
self.assertFalse(account.loginCallbackCalled,
"We shouldn't be logged in")
self.assertTrue(not ui.clientRegistered,
"Client shouldn't be registered in the UI")
cb = lambda r: self.assertTrue(False, "Shouldn't get called back")
d.addCallbacks(cb, err)
return d
def test_alreadyConnecting(self):
"""
Test that it can fail sensibly when someone tried to connect before
we did.
"""
account = self.makeAccount()
ui = self.makeUI()
account.logOn(ui)
self.assertRaises(error.ConnectError, account.logOn, ui)

View file

@ -0,0 +1,587 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.xish.domish}, a DOM-like library for XMPP.
"""
from __future__ import absolute_import, division
from zope.interface.verify import verifyObject
from twisted.python.compat import _PY3, unicode
from twisted.python.reflect import requireModule
from twisted.trial import unittest
from twisted.words.xish import domish
class ElementTests(unittest.TestCase):
"""
Tests for L{domish.Element}.
"""
def test_interface(self):
"""
L{domish.Element} implements L{domish.IElement}.
"""
verifyObject(domish.IElement, domish.Element((None, u"foo")))
def test_escaping(self):
"""
The built-in entity references are properly encoded.
"""
s = "&<>'\""
self.assertEqual(domish.escapeToXml(s), "&amp;&lt;&gt;'\"")
self.assertEqual(domish.escapeToXml(s, 1), "&amp;&lt;&gt;&apos;&quot;")
def test_namespace(self):
"""
An attribute on L{domish.Namespace} yields a qualified name.
"""
ns = domish.Namespace("testns")
self.assertEqual(ns.foo, ("testns", "foo"))
def test_elementInit(self):
"""
Basic L{domish.Element} initialization tests.
"""
e = domish.Element((None, "foo"))
self.assertEqual(e.name, "foo")
self.assertEqual(e.uri, None)
self.assertEqual(e.defaultUri, None)
self.assertEqual(e.parent, None)
e = domish.Element(("", "foo"))
self.assertEqual(e.name, "foo")
self.assertEqual(e.uri, "")
self.assertEqual(e.defaultUri, "")
self.assertEqual(e.parent, None)
e = domish.Element(("testns", "foo"))
self.assertEqual(e.name, "foo")
self.assertEqual(e.uri, "testns")
self.assertEqual(e.defaultUri, "testns")
self.assertEqual(e.parent, None)
e = domish.Element(("testns", "foo"), "test2ns")
self.assertEqual(e.name, "foo")
self.assertEqual(e.uri, "testns")
self.assertEqual(e.defaultUri, "test2ns")
def test_childOps(self):
"""
Basic L{domish.Element} child tests.
"""
e = domish.Element(("testns", "foo"))
e.addContent(u"somecontent")
b2 = e.addElement(("testns2", "bar2"))
e["attrib1"] = "value1"
e[("testns2", "attrib2")] = "value2"
e.addElement("bar")
e.addElement("bar")
e.addContent(u"abc")
e.addContent(u"123")
# Check content merging
self.assertEqual(e.children[-1], "abc123")
# Check direct child accessor
self.assertEqual(e.bar2, b2)
e.bar2.addContent(u"subcontent")
e.bar2["bar2value"] = "somevalue"
# Check child ops
self.assertEqual(e.children[1], e.bar2)
self.assertEqual(e.children[2], e.bar)
# Check attribute ops
self.assertEqual(e["attrib1"], "value1")
del e["attrib1"]
self.assertEqual(e.hasAttribute("attrib1"), 0)
self.assertEqual(e.hasAttribute("attrib2"), 0)
self.assertEqual(e[("testns2", "attrib2")], "value2")
def test_characterData(self):
"""
Extract character data using L{unicode}.
"""
element = domish.Element((u"testns", u"foo"))
element.addContent(u"somecontent")
text = unicode(element)
self.assertEqual(u"somecontent", text)
self.assertIsInstance(text, unicode)
def test_characterDataNativeString(self):
"""
Extract ascii character data using L{str}.
"""
element = domish.Element((u"testns", u"foo"))
element.addContent(u"somecontent")
text = str(element)
self.assertEqual("somecontent", text)
self.assertIsInstance(text, str)
def test_characterDataUnicode(self):
"""
Extract character data using L{unicode}.
"""
element = domish.Element((u"testns", u"foo"))
element.addContent(u"\N{SNOWMAN}")
text = unicode(element)
self.assertEqual(u"\N{SNOWMAN}", text)
self.assertIsInstance(text, unicode)
def test_characterDataBytes(self):
"""
Extract character data as UTF-8 using L{bytes}.
"""
element = domish.Element((u"testns", u"foo"))
element.addContent(u"\N{SNOWMAN}")
text = bytes(element)
self.assertEqual(u"\N{SNOWMAN}".encode('utf-8'), text)
self.assertIsInstance(text, bytes)
def test_characterDataMixed(self):
"""
Mixing addChild with cdata and element, the first cdata is returned.
"""
element = domish.Element((u"testns", u"foo"))
element.addChild(u"abc")
element.addElement("bar")
element.addChild(u"def")
self.assertEqual(u"abc", unicode(element))
def test_addContent(self):
"""
Unicode strings passed to C{addContent} become the character data.
"""
element = domish.Element((u"testns", u"foo"))
element.addContent(u'unicode')
self.assertEqual(u"unicode", unicode(element))
def test_addContentNativeStringASCII(self):
"""
ASCII native strings passed to C{addContent} become the character data.
"""
element = domish.Element((u"testns", u"foo"))
element.addContent('native')
self.assertEqual(u"native", unicode(element))
def test_addContentBytes(self):
"""
Byte strings passed to C{addContent} are not acceptable on Python 3.
"""
element = domish.Element((u"testns", u"foo"))
self.assertRaises(TypeError, element.addContent, b'bytes')
if not _PY3:
test_addContentBytes.skip = (
"Bytes behavior of addContent only provided on Python 3.")
def test_addContentBytesNonASCII(self):
"""
Non-ASCII byte strings passed to C{addContent} yield L{UnicodeError}.
"""
element = domish.Element((u"testns", u"foo"))
self.assertRaises(UnicodeError, element.addContent, b'\xe2\x98\x83')
if _PY3:
test_addContentBytesNonASCII.skip = (
"Bytes behavior of addContent only provided on Python 2.")
def test_addElementContent(self):
"""
Content passed to addElement becomes character data on the new child.
"""
element = domish.Element((u"testns", u"foo"))
child = element.addElement("bar", content=u"abc")
self.assertEqual(u"abc", unicode(child))
def test_elements(self):
"""
Calling C{elements} without arguments on a L{domish.Element} returns
all child elements, whatever the qualified name.
"""
e = domish.Element((u"testns", u"foo"))
c1 = e.addElement(u"name")
c2 = e.addElement((u"testns2", u"baz"))
c3 = e.addElement(u"quux")
c4 = e.addElement((u"testns", u"name"))
elts = list(e.elements())
self.assertIn(c1, elts)
self.assertIn(c2, elts)
self.assertIn(c3, elts)
self.assertIn(c4, elts)
def test_elementsWithQN(self):
"""
Calling C{elements} with a namespace and local name on a
L{domish.Element} returns all child elements with that qualified name.
"""
e = domish.Element((u"testns", u"foo"))
c1 = e.addElement(u"name")
c2 = e.addElement((u"testns2", u"baz"))
c3 = e.addElement(u"quux")
c4 = e.addElement((u"testns", u"name"))
elts = list(e.elements(u"testns", u"name"))
self.assertIn(c1, elts)
self.assertNotIn(c2, elts)
self.assertNotIn(c3, elts)
self.assertIn(c4, elts)
class DomishStreamTestsMixin:
"""
Mixin defining tests for different stream implementations.
@ivar streamClass: A no-argument callable which will be used to create an
XML parser which can produce a stream of elements from incremental
input.
"""
def setUp(self):
self.doc_started = False
self.doc_ended = False
self.root = None
self.elements = []
self.stream = self.streamClass()
self.stream.DocumentStartEvent = self._docStarted
self.stream.ElementEvent = self.elements.append
self.stream.DocumentEndEvent = self._docEnded
def _docStarted(self, root):
self.root = root
self.doc_started = True
def _docEnded(self):
self.doc_ended = True
def doTest(self, xml):
self.stream.parse(xml)
def testHarness(self):
xml = b"<root><child/><child2/></root>"
self.stream.parse(xml)
self.assertEqual(self.doc_started, True)
self.assertEqual(self.root.name, 'root')
self.assertEqual(self.elements[0].name, 'child')
self.assertEqual(self.elements[1].name, 'child2')
self.assertEqual(self.doc_ended, True)
def testBasic(self):
xml = b"<stream:stream xmlns:stream='etherx' xmlns='jabber'>\n" + \
b" <message to='bar'>" + \
b" <x xmlns='xdelay'>some&amp;data&gt;</x>" + \
b" </message>" + \
b"</stream:stream>"
self.stream.parse(xml)
self.assertEqual(self.root.name, 'stream')
self.assertEqual(self.root.uri, 'etherx')
self.assertEqual(self.elements[0].name, 'message')
self.assertEqual(self.elements[0].uri, 'jabber')
self.assertEqual(self.elements[0]['to'], 'bar')
self.assertEqual(self.elements[0].x.uri, 'xdelay')
self.assertEqual(unicode(self.elements[0].x), 'some&data>')
def testNoRootNS(self):
xml = b"<stream><error xmlns='etherx'/></stream>"
self.stream.parse(xml)
self.assertEqual(self.root.uri, '')
self.assertEqual(self.elements[0].uri, 'etherx')
def testNoDefaultNS(self):
xml = b"<stream:stream xmlns:stream='etherx'><error/></stream:stream>"
self.stream.parse(xml)
self.assertEqual(self.root.uri, 'etherx')
self.assertEqual(self.root.defaultUri, '')
self.assertEqual(self.elements[0].uri, '')
self.assertEqual(self.elements[0].defaultUri, '')
def testChildDefaultNS(self):
xml = b"<root xmlns='testns'><child/></root>"
self.stream.parse(xml)
self.assertEqual(self.root.uri, 'testns')
self.assertEqual(self.elements[0].uri, 'testns')
def testEmptyChildNS(self):
xml = b"""<root xmlns='testns'>
<child1><child2 xmlns=''/></child1>
</root>"""
self.stream.parse(xml)
self.assertEqual(self.elements[0].child2.uri, '')
def test_namespaceWithWhitespace(self):
"""
Whitespace in an xmlns value is preserved in the resulting node's C{uri}
attribute.
"""
xml = b"<root xmlns:foo=' bar baz '><foo:bar foo:baz='quux'/></root>"
self.stream.parse(xml)
self.assertEqual(self.elements[0].uri, " bar baz ")
self.assertEqual(
self.elements[0].attributes, {(" bar baz ", "baz"): "quux"})
def test_attributesWithNamespaces(self):
"""
Attributes with namespace are parsed without Exception.
(https://twistedmatrix.com/trac/ticket/9730 regression test)
"""
xml = b"""<root xmlns:test='http://example.org' xml:lang='en'>
<test:test>test</test:test>
</root>"""
# with Python 3.8 and without #9730 fix, the following error would
# happen at next line:
# ``RuntimeError: dictionary keys changed during iteration``
self.stream.parse(xml)
self.assertEqual(self.elements[0].uri, "http://example.org")
def testChildPrefix(self):
xml = b"<root xmlns='testns' xmlns:foo='testns2'><foo:child/></root>"
self.stream.parse(xml)
self.assertEqual(self.root.localPrefixes['foo'], 'testns2')
self.assertEqual(self.elements[0].uri, 'testns2')
def testUnclosedElement(self):
self.assertRaises(domish.ParserError, self.stream.parse,
b"<root><error></root>")
def test_namespaceReuse(self):
"""
Test that reuse of namespaces does affect an element's serialization.
When one element uses a prefix for a certain namespace, this is
stored in the C{localPrefixes} attribute of the element. We want
to make sure that elements created after such use, won't have this
prefix end up in their C{localPrefixes} attribute, too.
"""
xml = b"""<root>
<foo:child1 xmlns:foo='testns'/>
<child2 xmlns='testns'/>
</root>"""
self.stream.parse(xml)
self.assertEqual('child1', self.elements[0].name)
self.assertEqual('testns', self.elements[0].uri)
self.assertEqual('', self.elements[0].defaultUri)
self.assertEqual({'foo': 'testns'}, self.elements[0].localPrefixes)
self.assertEqual('child2', self.elements[1].name)
self.assertEqual('testns', self.elements[1].uri)
self.assertEqual('testns', self.elements[1].defaultUri)
self.assertEqual({}, self.elements[1].localPrefixes)
class DomishExpatStreamTests(DomishStreamTestsMixin, unittest.TestCase):
"""
Tests for L{domish.ExpatElementStream}, the expat-based element stream
implementation.
"""
streamClass = domish.ExpatElementStream
if requireModule('pyexpat', default=None) is None:
skip = "pyexpat is required for ExpatElementStream tests."
else:
skip = None
class DomishSuxStreamTests(DomishStreamTestsMixin, unittest.TestCase):
"""
Tests for L{domish.SuxElementStream}, the L{twisted.web.sux}-based element
stream implementation.
"""
streamClass = domish.SuxElementStream
class SerializerTests(unittest.TestCase):
def testNoNamespace(self):
e = domish.Element((None, "foo"))
self.assertEqual(e.toXml(), "<foo/>")
self.assertEqual(e.toXml(closeElement = 0), "<foo>")
def testDefaultNamespace(self):
e = domish.Element(("testns", "foo"))
self.assertEqual(e.toXml(), "<foo xmlns='testns'/>")
def testOtherNamespace(self):
e = domish.Element(("testns", "foo"), "testns2")
self.assertEqual(e.toXml({'testns': 'bar'}),
"<bar:foo xmlns:bar='testns' xmlns='testns2'/>")
def testChildDefaultNamespace(self):
e = domish.Element(("testns", "foo"))
e.addElement("bar")
self.assertEqual(e.toXml(), "<foo xmlns='testns'><bar/></foo>")
def testChildSameNamespace(self):
e = domish.Element(("testns", "foo"))
e.addElement(("testns", "bar"))
self.assertEqual(e.toXml(), "<foo xmlns='testns'><bar/></foo>")
def testChildSameDefaultNamespace(self):
e = domish.Element(("testns", "foo"))
e.addElement("bar", "testns")
self.assertEqual(e.toXml(), "<foo xmlns='testns'><bar/></foo>")
def testChildOtherDefaultNamespace(self):
e = domish.Element(("testns", "foo"))
e.addElement(("testns2", "bar"), 'testns2')
self.assertEqual(e.toXml(), "<foo xmlns='testns'><bar xmlns='testns2'/></foo>")
def testOnlyChildDefaultNamespace(self):
e = domish.Element((None, "foo"))
e.addElement(("ns2", "bar"), 'ns2')
self.assertEqual(e.toXml(), "<foo><bar xmlns='ns2'/></foo>")
def testOnlyChildDefaultNamespace2(self):
e = domish.Element((None, "foo"))
e.addElement("bar")
self.assertEqual(e.toXml(), "<foo><bar/></foo>")
def testChildInDefaultNamespace(self):
e = domish.Element(("testns", "foo"), "testns2")
e.addElement(("testns2", "bar"))
self.assertEqual(e.toXml(), "<xn0:foo xmlns:xn0='testns' xmlns='testns2'><bar/></xn0:foo>")
def testQualifiedAttribute(self):
e = domish.Element((None, "foo"),
attribs = {("testns2", "bar"): "baz"})
self.assertEqual(e.toXml(), "<foo xmlns:xn0='testns2' xn0:bar='baz'/>")
def testQualifiedAttributeDefaultNS(self):
e = domish.Element(("testns", "foo"),
attribs = {("testns", "bar"): "baz"})
self.assertEqual(e.toXml(), "<foo xmlns='testns' xmlns:xn0='testns' xn0:bar='baz'/>")
def testTwoChilds(self):
e = domish.Element(('', "foo"))
child1 = e.addElement(("testns", "bar"), "testns2")
child1.addElement(('testns2', 'quux'))
child2 = e.addElement(("testns3", "baz"), "testns4")
child2.addElement(('testns', 'quux'))
self.assertEqual(e.toXml(), "<foo><xn0:bar xmlns:xn0='testns' xmlns='testns2'><quux/></xn0:bar><xn1:baz xmlns:xn1='testns3' xmlns='testns4'><xn0:quux xmlns:xn0='testns'/></xn1:baz></foo>")
def testXMLNamespace(self):
e = domish.Element((None, "foo"),
attribs = {("http://www.w3.org/XML/1998/namespace",
"lang"): "en_US"})
self.assertEqual(e.toXml(), "<foo xml:lang='en_US'/>")
def testQualifiedAttributeGivenListOfPrefixes(self):
e = domish.Element((None, "foo"),
attribs = {("testns2", "bar"): "baz"})
self.assertEqual(e.toXml({"testns2": "qux"}),
"<foo xmlns:qux='testns2' qux:bar='baz'/>")
def testNSPrefix(self):
e = domish.Element((None, "foo"),
attribs = {("testns2", "bar"): "baz"})
c = e.addElement(("testns2", "qux"))
c[("testns2", "bar")] = "quux"
self.assertEqual(e.toXml(), "<foo xmlns:xn0='testns2' xn0:bar='baz'><xn0:qux xn0:bar='quux'/></foo>")
def testDefaultNSPrefix(self):
e = domish.Element((None, "foo"),
attribs = {("testns2", "bar"): "baz"})
c = e.addElement(("testns2", "qux"))
c[("testns2", "bar")] = "quux"
c.addElement('foo')
self.assertEqual(e.toXml(), "<foo xmlns:xn0='testns2' xn0:bar='baz'><xn0:qux xn0:bar='quux'><xn0:foo/></xn0:qux></foo>")
def testPrefixScope(self):
e = domish.Element(('testns', 'foo'))
self.assertEqual(e.toXml(prefixes={'testns': 'bar'},
prefixesInScope=['bar']),
"<bar:foo/>")
def testLocalPrefixes(self):
e = domish.Element(('testns', 'foo'), localPrefixes={'bar': 'testns'})
self.assertEqual(e.toXml(), "<bar:foo xmlns:bar='testns'/>")
def testLocalPrefixesWithChild(self):
e = domish.Element(('testns', 'foo'), localPrefixes={'bar': 'testns'})
e.addElement('baz')
self.assertIdentical(e.baz.defaultUri, None)
self.assertEqual(e.toXml(), "<bar:foo xmlns:bar='testns'><baz/></bar:foo>")
def test_prefixesReuse(self):
"""
Test that prefixes passed to serialization are not modified.
This test makes sure that passing a dictionary of prefixes repeatedly
to C{toXml} of elements does not cause serialization errors. A
previous implementation changed the passed in dictionary internally,
causing havoc later on.
"""
prefixes = {'testns': 'foo'}
# test passing of dictionary
s = domish.SerializerClass(prefixes=prefixes)
self.assertNotIdentical(prefixes, s.prefixes)
# test proper serialization on prefixes reuse
e = domish.Element(('testns2', 'foo'),
localPrefixes={'quux': 'testns2'})
self.assertEqual("<quux:foo xmlns:quux='testns2'/>",
e.toXml(prefixes=prefixes))
e = domish.Element(('testns2', 'foo'))
self.assertEqual("<foo xmlns='testns2'/>",
e.toXml(prefixes=prefixes))
def testRawXMLSerialization(self):
e = domish.Element((None, "foo"))
e.addRawXml("<abc123>")
# The testcase below should NOT generate valid XML -- that's
# the whole point of using the raw XML call -- it's the callers
# responsibility to ensure that the data inserted is valid
self.assertEqual(e.toXml(), "<foo><abc123></foo>")
def testRawXMLWithUnicodeSerialization(self):
e = domish.Element((None, "foo"))
e.addRawXml(u"<degree>\u00B0</degree>")
self.assertEqual(e.toXml(), u"<foo><degree>\u00B0</degree></foo>")
def testUnicodeSerialization(self):
e = domish.Element((None, "foo"))
e["test"] = u"my value\u0221e"
e.addContent(u"A degree symbol...\u00B0")
self.assertEqual(e.toXml(),
u"<foo test='my value\u0221e'>A degree symbol...\u00B0</foo>")

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,291 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for IRC portions of L{twisted.words.service}.
"""
from twisted.cred import checkers, portal
from twisted.test import proto_helpers
from twisted.words.protocols import irc
from twisted.words.service import InMemoryWordsRealm, IRCFactory, IRCUser
from twisted.words.test.test_irc import IRCTestCase
class IRCUserTests(IRCTestCase):
"""
Isolated tests for L{IRCUser}
"""
def setUp(self):
"""
Sets up a Realm, Portal, Factory, IRCUser, Transport, and Connection
for our tests.
"""
self.realm = InMemoryWordsRealm("example.com")
self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
self.portal = portal.Portal(self.realm, [self.checker])
self.checker.addUser(u"john", u"pass")
self.factory = IRCFactory(self.realm, self.portal)
self.ircUser = self.factory.buildProtocol(None)
self.stringTransport = proto_helpers.StringTransport()
self.ircUser.makeConnection(self.stringTransport)
def test_sendMessage(self):
"""
Sending a message to a user after they have sent NICK, but before they
have authenticated, results in a message from "example.com".
"""
self.ircUser.irc_NICK("", ["mynick"])
self.stringTransport.clear()
self.ircUser.sendMessage("foo")
self.assertEqualBufferValue(self.stringTransport.value(), ":example.com foo mynick\r\n")
def test_utf8Messages(self):
"""
When a UTF8 message is sent with sendMessage and the current IRCUser
has a UTF8 nick and is set to UTF8 encoding, the message will be
written to the transport.
"""
expectedResult = (u":example.com \u0442\u0435\u0441\u0442 "
u"\u043d\u0438\u043a\r\n").encode('utf-8')
self.ircUser.irc_NICK("", [u"\u043d\u0438\u043a".encode('utf-8')])
self.stringTransport.clear()
self.ircUser.sendMessage(u"\u0442\u0435\u0441\u0442".encode('utf-8'))
self.assertEqualBufferValue(self.stringTransport.value(), expectedResult)
def test_invalidEncodingNick(self):
"""
A NICK command sent with a nickname that cannot be decoded with the
current IRCUser's encoding results in a PRIVMSG from NickServ
indicating that the nickname could not be decoded.
"""
self.ircUser.irc_NICK("", [b"\xd4\xc5\xd3\xd4"])
self.assertRaises(UnicodeError)
def response(self):
"""
Grabs our responses and then clears the transport
"""
response = self.ircUser.transport.value()
self.ircUser.transport.clear()
if bytes != str and isinstance(response, bytes):
response = response.decode("utf-8")
response = response.splitlines()
return [irc.parsemsg(r) for r in response]
def scanResponse(self, response, messageType):
"""
Gets messages out of a response
@param response: The parsed IRC messages of the response, as returned
by L{IRCUserTests.response}
@param messageType: The string type of the desired messages.
@return: An iterator which yields 2-tuples of C{(index, ircMessage)}
"""
for n, message in enumerate(response):
if (message[1] == messageType):
yield n, message
def test_sendNickSendsGreeting(self):
"""
Receiving NICK without authenticating sends the MOTD Start and MOTD End
messages, which is required by certain popular IRC clients (such as
Pidgin) before a connection is considered to be fully established.
"""
self.ircUser.irc_NICK("", ["mynick"])
response = self.response()
start = list(self.scanResponse(response, irc.RPL_MOTDSTART))
end = list(self.scanResponse(response, irc.RPL_ENDOFMOTD))
self.assertEqual(start,
[(0, ('example.com', '375', ['mynick', '- example.com Message of the Day - ']))])
self.assertEqual(end,
[(1, ('example.com', '376', ['mynick', 'End of /MOTD command.']))])
def test_fullLogin(self):
"""
Receiving USER, PASS, NICK will log in the user, and transmit the
appropriate response messages.
"""
self.ircUser.irc_USER("", ["john doe"])
self.ircUser.irc_PASS("", ["pass"])
self.ircUser.irc_NICK("", ["john"])
version = ('Your host is example.com, running version %s' %
(self.factory._serverInfo["serviceVersion"],))
creation = ('This server was created on %s' %
(self.factory._serverInfo["creationDate"],))
self.assertEqual(self.response(),
[('example.com', '375',
['john', '- example.com Message of the Day - ']),
('example.com', '376', ['john', 'End of /MOTD command.']),
('example.com', '001', ['john', 'connected to Twisted IRC']),
('example.com', '002', ['john', version]),
('example.com', '003', ['john', creation]),
('example.com', '004',
['john', 'example.com', self.factory._serverInfo["serviceVersion"],
'w', 'n'])])
def test_PART(self):
"""
irc_PART
"""
self.ircUser.irc_NICK("testuser", ["mynick"])
response = self.response()
self.ircUser.transport.clear()
self.assertEqual(response[0][1], irc.RPL_MOTDSTART)
self.ircUser.irc_JOIN("testuser", ["somechannel"])
response = self.response()
self.ircUser.transport.clear()
self.assertEqual(response[0][1], irc.ERR_NOSUCHCHANNEL)
self.ircUser.irc_PART("testuser", [b"somechannel", b"booga"])
response = self.response()
self.ircUser.transport.clear()
self.assertEqual(response[0][1], irc.ERR_NOTONCHANNEL)
self.ircUser.irc_PART("testuser", [u"somechannel", u"booga"])
response = self.response()
self.ircUser.transport.clear()
self.assertEqual(response[0][1], irc.ERR_NOTONCHANNEL)
def test_NAMES(self):
"""
irc_NAMES
"""
self.ircUser.irc_NICK("", ["testuser"])
self.ircUser.irc_JOIN("", ["somechannel"])
self.ircUser.transport.clear()
self.ircUser.irc_NAMES("", ["somechannel"])
response = self.response()
self.assertEqual(response[0][1], irc.RPL_ENDOFNAMES)
class MocksyIRCUser(IRCUser):
def __init__(self):
self.realm = InMemoryWordsRealm("example.com")
self.mockedCodes = []
def sendMessage(self, code, *_, **__):
self.mockedCodes.append(code)
BADTEXT = b'\xff'
class IRCUserBadEncodingTests(IRCTestCase):
"""
Verifies that L{IRCUser} sends the correct error messages back to clients
when given indecipherable bytes
"""
# TODO: irc_NICK -- but NICKSERV is used for that, so it isn't as easy.
def setUp(self):
self.ircUser = MocksyIRCUser()
def assertChokesOnBadBytes(self, irc_x, error):
"""
Asserts that IRCUser sends the relevant error code when a given irc_x
dispatch method is given undecodable bytes.
@param irc_x: the name of the irc_FOO method to test.
For example, irc_x = 'PRIVMSG' will check irc_PRIVMSG
@param error: the error code irc_x should send. For example,
irc.ERR_NOTONCHANNEL
"""
getattr(self.ircUser, 'irc_%s' % irc_x)(None, [BADTEXT])
self.assertEqual(self.ircUser.mockedCodes, [error])
# No such channel
def test_JOIN(self):
"""
Tests that irc_JOIN sends ERR_NOSUCHCHANNEL if the channel name can't
be decoded.
"""
self.assertChokesOnBadBytes('JOIN', irc.ERR_NOSUCHCHANNEL)
def test_NAMES(self):
"""
Tests that irc_NAMES sends ERR_NOSUCHCHANNEL if the channel name can't
be decoded.
"""
self.assertChokesOnBadBytes('NAMES', irc.ERR_NOSUCHCHANNEL)
def test_TOPIC(self):
"""
Tests that irc_TOPIC sends ERR_NOSUCHCHANNEL if the channel name can't
be decoded.
"""
self.assertChokesOnBadBytes('TOPIC', irc.ERR_NOSUCHCHANNEL)
def test_LIST(self):
"""
Tests that irc_LIST sends ERR_NOSUCHCHANNEL if the channel name can't
be decoded.
"""
self.assertChokesOnBadBytes('LIST', irc.ERR_NOSUCHCHANNEL)
# No such nick
def test_MODE(self):
"""
Tests that irc_MODE sends ERR_NOSUCHNICK if the target name can't
be decoded.
"""
self.assertChokesOnBadBytes('MODE', irc.ERR_NOSUCHNICK)
def test_PRIVMSG(self):
"""
Tests that irc_PRIVMSG sends ERR_NOSUCHNICK if the target name can't
be decoded.
"""
self.assertChokesOnBadBytes('PRIVMSG', irc.ERR_NOSUCHNICK)
def test_WHOIS(self):
"""
Tests that irc_WHOIS sends ERR_NOSUCHNICK if the target name can't
be decoded.
"""
self.assertChokesOnBadBytes('WHOIS', irc.ERR_NOSUCHNICK)
# Not on channel
def test_PART(self):
"""
Tests that irc_PART sends ERR_NOTONCHANNEL if the target name can't
be decoded.
"""
self.assertChokesOnBadBytes('PART', irc.ERR_NOTONCHANNEL)
# Probably nothing
def test_WHO(self):
"""
Tests that irc_WHO immediately ends the WHO list if the target name
can't be decoded.
"""
self.assertChokesOnBadBytes('WHO', irc.RPL_ENDOFWHO)

View file

@ -0,0 +1,292 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.im.ircsupport}.
"""
from twisted.test.proto_helpers import StringTransport
from twisted.words.im.basechat import ChatUI, Conversation, GroupConversation
from twisted.words.im.ircsupport import IRCAccount, IRCProto
from twisted.words.im.locals import OfflineError
from twisted.words.test.test_irc import IRCTestCase
class StubConversation(Conversation):
def show(self):
pass
def showMessage(self, message, metadata):
self.message = message
self.metadata = metadata
class StubGroupConversation(GroupConversation):
def setTopic(self, topic, nickname):
self.topic = topic
self.topicSetBy = nickname
def show(self):
pass
def showGroupMessage(self, sender, text, metadata=None):
self.metadata = metadata
self.text = text
self.metadata = metadata
class StubChatUI(ChatUI):
def getConversation(self, group, Class=StubConversation, stayHidden=0):
return ChatUI.getGroupConversation(self, group, Class, stayHidden)
def getGroupConversation(self, group, Class=StubGroupConversation, stayHidden=0):
return ChatUI.getGroupConversation(self, group, Class, stayHidden)
class IRCProtoTests(IRCTestCase):
"""
Tests for L{IRCProto}.
"""
def setUp(self):
self.account = IRCAccount(
"Some account", False, "alice", None, "example.com", 6667)
self.proto = IRCProto(self.account, StubChatUI(), None)
self.transport = StringTransport()
def test_login(self):
"""
When L{IRCProto} is connected to a transport, it sends I{NICK} and
I{USER} commands with the username from the account object.
"""
self.proto.makeConnection(self.transport)
self.assertEqualBufferValue(
self.transport.value(),
"NICK alice\r\n"
"USER alice foo bar :Twisted-IM user\r\n")
def test_authenticate(self):
"""
If created with an account with a password, L{IRCProto} sends a
I{PASS} command before the I{NICK} and I{USER} commands.
"""
self.account.password = "secret"
self.proto.makeConnection(self.transport)
self.assertEqualBufferValue(
self.transport.value(),
"PASS secret\r\n"
"NICK alice\r\n"
"USER alice foo bar :Twisted-IM user\r\n")
def test_channels(self):
"""
If created with an account with a list of channels, L{IRCProto}
joins each of those channels after registering.
"""
self.account.channels = ['#foo', '#bar']
self.proto.makeConnection(self.transport)
self.assertEqualBufferValue(
self.transport.value(),
"NICK alice\r\n"
"USER alice foo bar :Twisted-IM user\r\n"
"JOIN #foo\r\n"
"JOIN #bar\r\n")
def test_isupport(self):
"""
L{IRCProto} can interpret I{ISUPPORT} (I{005}) messages from the server
and reflect their information in its C{supported} attribute.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(
":irc.example.com 005 alice MODES=4 CHANLIMIT=#:20\r\n")
self.assertEqual(4, self.proto.supported.getFeature("MODES"))
def test_nick(self):
"""
IRC NICK command changes the nickname of a user.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(":alice JOIN #group1\r\n")
self.proto.dataReceived(":alice1 JOIN #group1\r\n")
self.proto.dataReceived(":alice1 NICK newnick\r\n")
self.proto.dataReceived(":alice3 NICK newnick3\r\n")
self.assertIn("newnick", self.proto._ingroups)
self.assertNotIn("alice1", self.proto._ingroups)
def test_part(self):
"""
IRC PART command removes a user from an IRC channel.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(":alice1 JOIN #group1\r\n")
self.assertIn("group1", self.proto._ingroups["alice1"])
self.assertNotIn("group2", self.proto._ingroups["alice1"])
self.proto.dataReceived(":alice PART #group1\r\n")
self.proto.dataReceived(":alice1 PART #group1\r\n")
self.proto.dataReceived(":alice1 PART #group2\r\n")
self.assertNotIn("group1", self.proto._ingroups["alice1"])
self.assertNotIn("group2", self.proto._ingroups["alice1"])
def test_quit(self):
"""
IRC QUIT command removes a user from all IRC channels.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(":alice1 JOIN #group1\r\n")
self.assertIn("group1", self.proto._ingroups["alice1"])
self.assertNotIn("group2", self.proto._ingroups["alice1"])
self.proto.dataReceived(":alice1 JOIN #group3\r\n")
self.assertIn("group3", self.proto._ingroups["alice1"])
self.proto.dataReceived(":alice1 QUIT\r\n")
self.assertTrue(len(self.proto._ingroups["alice1"]) == 0)
self.proto.dataReceived(":alice3 QUIT\r\n")
def test_topic(self):
"""
IRC TOPIC command changes the topic of an IRC channel.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(":alice1 JOIN #group1\r\n")
self.proto.dataReceived(":alice1 TOPIC #group1 newtopic\r\n")
groupConversation = self.proto.getGroupConversation("group1")
self.assertEqual(groupConversation.topic, "newtopic")
self.assertEqual(groupConversation.topicSetBy, "alice1")
def test_privmsg(self):
"""
PRIVMSG sends a private message to a user or channel.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(":alice1 PRIVMSG t2 test_message_1\r\n")
conversation = self.proto.chat.getConversation(
self.proto.getPerson("alice1"))
self.assertEqual(conversation.message, "test_message_1")
self.proto.dataReceived(":alice1 PRIVMSG #group1 test_message_2\r\n")
groupConversation = self.proto.getGroupConversation("group1")
self.assertEqual(groupConversation.text, "test_message_2")
self.proto.setNick("alice")
self.proto.dataReceived(":alice PRIVMSG #foo test_message_3\r\n")
groupConversation = self.proto.getGroupConversation("foo")
self.assertFalse(hasattr(groupConversation, "text"))
conversation = self.proto.chat.getConversation(
self.proto.getPerson("alice"))
self.assertFalse(hasattr(conversation, "message"))
def test_action(self):
"""
CTCP ACTION to a user or channel.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(":alice1 PRIVMSG alice1 :\01ACTION smiles\01\r\n")
conversation = self.proto.chat.getConversation(
self.proto.getPerson("alice1"))
self.assertEqual(conversation.message, "smiles")
self.proto.dataReceived(":alice1 PRIVMSG #group1 :\01ACTION laughs\01\r\n")
groupConversation = self.proto.getGroupConversation("group1")
self.assertEqual(groupConversation.text, "laughs")
self.proto.setNick("alice")
self.proto.dataReceived(":alice PRIVMSG #group1 :\01ACTION cries\01\r\n")
groupConversation = self.proto.getGroupConversation("foo")
self.assertFalse(hasattr(groupConversation, "text"))
conversation = self.proto.chat.getConversation(
self.proto.getPerson("alice"))
self.assertFalse(hasattr(conversation, "message"))
def test_rplNamreply(self):
"""
RPL_NAMREPLY server response (353) lists all the users in a channel.
RPL_ENDOFNAMES server response (363) is sent at the end of RPL_NAMREPLY
to indicate that there are no more names.
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(
":example.com 353 z3p = #bnl :pSwede Dan- SkOyg @MrOp +MrPlus\r\n")
expectedInGroups = {'Dan-': ['bnl'],
'pSwede': ['bnl'],
'SkOyg': ['bnl'],
'MrOp': ['bnl'],
'MrPlus': ['bnl']}
expectedNamReplies = {
'bnl': ['pSwede', 'Dan-', 'SkOyg', 'MrOp', 'MrPlus']}
self.assertEqual(expectedInGroups, self.proto._ingroups)
self.assertEqual(expectedNamReplies, self.proto._namreplies)
self.proto.dataReceived(
":example.com 366 alice #bnl :End of /NAMES list\r\n")
self.assertEqual({}, self.proto._namreplies)
groupConversation = self.proto.getGroupConversation("bnl")
self.assertEqual(expectedNamReplies['bnl'], groupConversation.members)
def test_rplTopic(self):
"""
RPL_TOPIC server response (332) is sent when a channel's topic is changed
"""
self.proto.makeConnection(self.transport)
self.proto.dataReceived(
":example.com 332 alice, #foo :Some random topic\r\n")
self.assertEqual("Some random topic", self.proto._topics["foo"])
def test_sendMessage(self):
"""
L{IRCPerson.sendMessage}
"""
self.proto.makeConnection(self.transport)
person = self.proto.getPerson("alice")
self.assertRaises(OfflineError, person.sendMessage, "Some message")
person.account.client = self.proto
self.transport.clear()
person.sendMessage("Some message 2")
self.assertEqual(self.transport.io.getvalue(),
b'PRIVMSG alice :Some message 2\r\n')
self.transport.clear()
person.sendMessage("smiles", {"style": "emote"})
self.assertEqual(self.transport.io.getvalue(),
b'PRIVMSG alice :\x01ACTION smiles\x01\r\n')
def test_sendGroupMessage(self):
"""
L{IRCGroup.sendGroupMessage}
"""
self.proto.makeConnection(self.transport)
group = self.proto.chat.getGroup("#foo", self.proto)
self.assertRaises(OfflineError, group.sendGroupMessage, "Some message")
group.account.client = self.proto
self.transport.clear()
group.sendGroupMessage("Some message 2")
self.assertEqual(self.transport.io.getvalue(),
b'PRIVMSG #foo :Some message 2\r\n')
self.transport.clear()
group.sendGroupMessage("smiles", {"style": "emote"})
self.assertEqual(self.transport.io.getvalue(),
b'PRIVMSG #foo :\x01ACTION smiles\x01\r\n')

View file

@ -0,0 +1,497 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.protocols.jabber.client}
"""
from __future__ import absolute_import, division
from hashlib import sha1
from twisted.internet import defer
from twisted.python.compat import unicode
from twisted.trial import unittest
from twisted.words.protocols.jabber import client, error, jid, xmlstream
from twisted.words.protocols.jabber.sasl import SASLInitiatingInitializer
from twisted.words.xish import utility
try:
from twisted.internet import ssl
except ImportError:
ssl = None
skipWhenNoSSL = "SSL not available"
else:
skipWhenNoSSL = None
IQ_AUTH_GET = '/iq[@type="get"]/query[@xmlns="jabber:iq:auth"]'
IQ_AUTH_SET = '/iq[@type="set"]/query[@xmlns="jabber:iq:auth"]'
NS_BIND = 'urn:ietf:params:xml:ns:xmpp-bind'
IQ_BIND_SET = '/iq[@type="set"]/bind[@xmlns="%s"]' % NS_BIND
NS_SESSION = 'urn:ietf:params:xml:ns:xmpp-session'
IQ_SESSION_SET = '/iq[@type="set"]/session[@xmlns="%s"]' % NS_SESSION
class CheckVersionInitializerTests(unittest.TestCase):
def setUp(self):
a = xmlstream.Authenticator()
xs = xmlstream.XmlStream(a)
self.init = client.CheckVersionInitializer(xs)
def testSupported(self):
"""
Test supported version number 1.0
"""
self.init.xmlstream.version = (1, 0)
self.init.initialize()
def testNotSupported(self):
"""
Test unsupported version number 0.0, and check exception.
"""
self.init.xmlstream.version = (0, 0)
exc = self.assertRaises(error.StreamError, self.init.initialize)
self.assertEqual('unsupported-version', exc.condition)
class InitiatingInitializerHarness(object):
"""
Testing harness for interacting with XML stream initializers.
This sets up an L{utility.XmlPipe} to create a communication channel between
the initializer and the stubbed receiving entity. It features a sink and
source side that both act similarly to a real L{xmlstream.XmlStream}. The
sink is augmented with an authenticator to which initializers can be added.
The harness also provides some utility methods to work with event observers
and deferreds.
"""
def setUp(self):
self.output = []
self.pipe = utility.XmlPipe()
self.xmlstream = self.pipe.sink
self.authenticator = xmlstream.ConnectAuthenticator('example.org')
self.xmlstream.authenticator = self.authenticator
def waitFor(self, event, handler):
"""
Observe an output event, returning a deferred.
The returned deferred will be fired when the given event has been
observed on the source end of the L{XmlPipe} tied to the protocol
under test. The handler is added as the first callback.
@param event: The event to be observed. See
L{utility.EventDispatcher.addOnetimeObserver}.
@param handler: The handler to be called with the observed event object.
@rtype: L{defer.Deferred}.
"""
d = defer.Deferred()
d.addCallback(handler)
self.pipe.source.addOnetimeObserver(event, d.callback)
return d
class IQAuthInitializerTests(InitiatingInitializerHarness, unittest.TestCase):
"""
Tests for L{client.IQAuthInitializer}.
"""
def setUp(self):
super(IQAuthInitializerTests, self).setUp()
self.init = client.IQAuthInitializer(self.xmlstream)
self.authenticator.jid = jid.JID('user@example.com/resource')
self.authenticator.password = u'secret'
def testPlainText(self):
"""
Test plain-text authentication.
Act as a server supporting plain-text authentication and expect the
C{password} field to be filled with the password. Then act as if
authentication succeeds.
"""
def onAuthGet(iq):
"""
Called when the initializer sent a query for authentication methods.
The response informs the client that plain-text authentication
is supported.
"""
# Create server response
response = xmlstream.toResponse(iq, 'result')
response.addElement(('jabber:iq:auth', 'query'))
response.query.addElement('username')
response.query.addElement('password')
response.query.addElement('resource')
# Set up an observer for the next request we expect.
d = self.waitFor(IQ_AUTH_SET, onAuthSet)
# Send server response
self.pipe.source.send(response)
return d
def onAuthSet(iq):
"""
Called when the initializer sent the authentication request.
The server checks the credentials and responds with an empty result
signalling success.
"""
self.assertEqual('user', unicode(iq.query.username))
self.assertEqual('secret', unicode(iq.query.password))
self.assertEqual('resource', unicode(iq.query.resource))
# Send server response
response = xmlstream.toResponse(iq, 'result')
self.pipe.source.send(response)
# Set up an observer for the request for authentication fields
d1 = self.waitFor(IQ_AUTH_GET, onAuthGet)
# Start the initializer
d2 = self.init.initialize()
return defer.gatherResults([d1, d2])
def testDigest(self):
"""
Test digest authentication.
Act as a server supporting digest authentication and expect the
C{digest} field to be filled with a sha1 digest of the concatenated
stream session identifier and password. Then act as if authentication
succeeds.
"""
def onAuthGet(iq):
"""
Called when the initializer sent a query for authentication methods.
The response informs the client that digest authentication is
supported.
"""
# Create server response
response = xmlstream.toResponse(iq, 'result')
response.addElement(('jabber:iq:auth', 'query'))
response.query.addElement('username')
response.query.addElement('digest')
response.query.addElement('resource')
# Set up an observer for the next request we expect.
d = self.waitFor(IQ_AUTH_SET, onAuthSet)
# Send server response
self.pipe.source.send(response)
return d
def onAuthSet(iq):
"""
Called when the initializer sent the authentication request.
The server checks the credentials and responds with an empty result
signalling success.
"""
self.assertEqual('user', unicode(iq.query.username))
self.assertEqual(sha1(b'12345secret').hexdigest(),
unicode(iq.query.digest))
self.assertEqual('resource', unicode(iq.query.resource))
# Send server response
response = xmlstream.toResponse(iq, 'result')
self.pipe.source.send(response)
# Digest authentication relies on the stream session identifier. Set it.
self.xmlstream.sid = u'12345'
# Set up an observer for the request for authentication fields
d1 = self.waitFor(IQ_AUTH_GET, onAuthGet)
# Start the initializer
d2 = self.init.initialize()
return defer.gatherResults([d1, d2])
def testFailRequestFields(self):
"""
Test initializer failure of request for fields for authentication.
"""
def onAuthGet(iq):
"""
Called when the initializer sent a query for authentication methods.
The server responds that the client is not authorized to authenticate.
"""
response = error.StanzaError('not-authorized').toResponse(iq)
self.pipe.source.send(response)
# Set up an observer for the request for authentication fields
d1 = self.waitFor(IQ_AUTH_GET, onAuthGet)
# Start the initializer
d2 = self.init.initialize()
# The initialized should fail with a stanza error.
self.assertFailure(d2, error.StanzaError)
return defer.gatherResults([d1, d2])
def testFailAuth(self):
"""
Test initializer failure to authenticate.
"""
def onAuthGet(iq):
"""
Called when the initializer sent a query for authentication methods.
The response informs the client that plain-text authentication
is supported.
"""
# Send server response
response = xmlstream.toResponse(iq, 'result')
response.addElement(('jabber:iq:auth', 'query'))
response.query.addElement('username')
response.query.addElement('password')
response.query.addElement('resource')
# Set up an observer for the next request we expect.
d = self.waitFor(IQ_AUTH_SET, onAuthSet)
# Send server response
self.pipe.source.send(response)
return d
def onAuthSet(iq):
"""
Called when the initializer sent the authentication request.
The server checks the credentials and responds with a not-authorized
stanza error.
"""
response = error.StanzaError('not-authorized').toResponse(iq)
self.pipe.source.send(response)
# Set up an observer for the request for authentication fields
d1 = self.waitFor(IQ_AUTH_GET, onAuthGet)
# Start the initializer
d2 = self.init.initialize()
# The initializer should fail with a stanza error.
self.assertFailure(d2, error.StanzaError)
return defer.gatherResults([d1, d2])
class BindInitializerTests(InitiatingInitializerHarness, unittest.TestCase):
"""
Tests for L{client.BindInitializer}.
"""
def setUp(self):
super(BindInitializerTests, self).setUp()
self.init = client.BindInitializer(self.xmlstream)
self.authenticator.jid = jid.JID('user@example.com/resource')
def testBasic(self):
"""
Set up a stream, and act as if resource binding succeeds.
"""
def onBind(iq):
response = xmlstream.toResponse(iq, 'result')
response.addElement((NS_BIND, 'bind'))
response.bind.addElement('jid',
content=u'user@example.com/other resource')
self.pipe.source.send(response)
def cb(result):
self.assertEqual(jid.JID('user@example.com/other resource'),
self.authenticator.jid)
d1 = self.waitFor(IQ_BIND_SET, onBind)
d2 = self.init.start()
d2.addCallback(cb)
return defer.gatherResults([d1, d2])
def testFailure(self):
"""
Set up a stream, and act as if resource binding fails.
"""
def onBind(iq):
response = error.StanzaError('conflict').toResponse(iq)
self.pipe.source.send(response)
d1 = self.waitFor(IQ_BIND_SET, onBind)
d2 = self.init.start()
self.assertFailure(d2, error.StanzaError)
return defer.gatherResults([d1, d2])
class SessionInitializerTests(InitiatingInitializerHarness, unittest.TestCase):
"""
Tests for L{client.SessionInitializer}.
"""
def setUp(self):
super(SessionInitializerTests, self).setUp()
self.init = client.SessionInitializer(self.xmlstream)
def testSuccess(self):
"""
Set up a stream, and act as if session establishment succeeds.
"""
def onSession(iq):
response = xmlstream.toResponse(iq, 'result')
self.pipe.source.send(response)
d1 = self.waitFor(IQ_SESSION_SET, onSession)
d2 = self.init.start()
return defer.gatherResults([d1, d2])
def testFailure(self):
"""
Set up a stream, and act as if session establishment fails.
"""
def onSession(iq):
response = error.StanzaError('forbidden').toResponse(iq)
self.pipe.source.send(response)
d1 = self.waitFor(IQ_SESSION_SET, onSession)
d2 = self.init.start()
self.assertFailure(d2, error.StanzaError)
return defer.gatherResults([d1, d2])
class BasicAuthenticatorTests(unittest.TestCase):
"""
Test for both BasicAuthenticator and basicClientFactory.
"""
def test_basic(self):
"""
Authenticator and stream are properly constructed by the factory.
The L{xmlstream.XmlStream} protocol created by the factory has the new
L{client.BasicAuthenticator} instance in its C{authenticator}
attribute. It is set up with C{jid} and C{password} as passed to the
factory, C{otherHost} taken from the client JID. The stream futher has
two initializers, for TLS and authentication, of which the first has
its C{required} attribute set to C{True}.
"""
self.client_jid = jid.JID('user@example.com/resource')
# Get an XmlStream instance. Note that it gets initialized with the
# XMPPAuthenticator (that has its associateWithXmlStream called) that
# is in turn initialized with the arguments to the factory.
xs = client.basicClientFactory(self.client_jid,
'secret').buildProtocol(None)
# test authenticator's instance variables
self.assertEqual('example.com', xs.authenticator.otherHost)
self.assertEqual(self.client_jid, xs.authenticator.jid)
self.assertEqual('secret', xs.authenticator.password)
# test list of initializers
tls, auth = xs.initializers
self.assertIsInstance(tls, xmlstream.TLSInitiatingInitializer)
self.assertIsInstance(auth, client.IQAuthInitializer)
self.assertFalse(tls.required)
class XMPPAuthenticatorTests(unittest.TestCase):
"""
Test for both XMPPAuthenticator and XMPPClientFactory.
"""
def test_basic(self):
"""
Test basic operations.
Setup an XMPPClientFactory, which sets up an XMPPAuthenticator, and let
it produce a protocol instance. Then inspect the instance variables of
the authenticator and XML stream objects.
"""
self.client_jid = jid.JID('user@example.com/resource')
# Get an XmlStream instance. Note that it gets initialized with the
# XMPPAuthenticator (that has its associateWithXmlStream called) that
# is in turn initialized with the arguments to the factory.
xs = client.XMPPClientFactory(self.client_jid,
'secret').buildProtocol(None)
# test authenticator's instance variables
self.assertEqual('example.com', xs.authenticator.otherHost)
self.assertEqual(self.client_jid, xs.authenticator.jid)
self.assertEqual('secret', xs.authenticator.password)
# test list of initializers
version, tls, sasl, bind, session = xs.initializers
self.assertIsInstance(tls, xmlstream.TLSInitiatingInitializer)
self.assertIsInstance(sasl, SASLInitiatingInitializer)
self.assertIsInstance(bind, client.BindInitializer)
self.assertIsInstance(session, client.SessionInitializer)
self.assertTrue(tls.required)
self.assertTrue(sasl.required)
self.assertTrue(bind.required)
self.assertFalse(session.required)
def test_tlsConfiguration(self):
"""
A TLS configuration is passed to the TLS initializer.
"""
configs = []
def init(self, xs, required=True, configurationForTLS=None):
configs.append(configurationForTLS)
self.client_jid = jid.JID('user@example.com/resource')
# Get an XmlStream instance. Note that it gets initialized with the
# XMPPAuthenticator (that has its associateWithXmlStream called) that
# is in turn initialized with the arguments to the factory.
configurationForTLS = ssl.CertificateOptions()
factory = client.XMPPClientFactory(
self.client_jid, 'secret',
configurationForTLS=configurationForTLS)
self.patch(xmlstream.TLSInitiatingInitializer, "__init__", init)
xs = factory.buildProtocol(None)
# test list of initializers
version, tls, sasl, bind, session = xs.initializers
self.assertIsInstance(tls, xmlstream.TLSInitiatingInitializer)
self.assertIs(configurationForTLS, configs[0])
test_tlsConfiguration.skip = skipWhenNoSSL

View file

@ -0,0 +1,440 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.protocols.jabber.component}
"""
from hashlib import sha1
from zope.interface.verify import verifyObject
from twisted.python import failure
from twisted.python.compat import unicode
from twisted.trial import unittest
from twisted.words.protocols.jabber import component, ijabber, xmlstream
from twisted.words.protocols.jabber.jid import JID
from twisted.words.xish import domish
from twisted.words.xish.utility import XmlPipe
class DummyTransport:
def __init__(self, list):
self.list = list
def write(self, bytes):
self.list.append(bytes)
class ComponentInitiatingInitializerTests(unittest.TestCase):
def setUp(self):
self.output = []
self.authenticator = xmlstream.Authenticator()
self.authenticator.password = u'secret'
self.xmlstream = xmlstream.XmlStream(self.authenticator)
self.xmlstream.namespace = 'test:component'
self.xmlstream.send = self.output.append
self.xmlstream.connectionMade()
self.xmlstream.dataReceived(
"<stream:stream xmlns='test:component' "
"xmlns:stream='http://etherx.jabber.org/streams' "
"from='example.com' id='12345' version='1.0'>")
self.xmlstream.sid = u'12345'
self.init = component.ComponentInitiatingInitializer(self.xmlstream)
def testHandshake(self):
"""
Test basic operations of component handshake.
"""
d = self.init.initialize()
# the initializer should have sent the handshake request
handshake = self.output[-1]
self.assertEqual('handshake', handshake.name)
self.assertEqual('test:component', handshake.uri)
self.assertEqual(sha1(b'12345' + b'secret').hexdigest(),
unicode(handshake))
# successful authentication
handshake.children = []
self.xmlstream.dataReceived(handshake.toXml())
return d
class ComponentAuthTests(unittest.TestCase):
def authPassed(self, stream):
self.authComplete = True
def testAuth(self):
self.authComplete = False
outlist = []
ca = component.ConnectComponentAuthenticator(u"cjid", u"secret")
xs = xmlstream.XmlStream(ca)
xs.transport = DummyTransport(outlist)
xs.addObserver(xmlstream.STREAM_AUTHD_EVENT,
self.authPassed)
# Go...
xs.connectionMade()
xs.dataReceived(b"<stream:stream xmlns='jabber:component:accept' xmlns:stream='http://etherx.jabber.org/streams' from='cjid' id='12345'>")
# Calculate what we expect the handshake value to be
hv = sha1(b"12345" + b"secret").hexdigest().encode('ascii')
self.assertEqual(outlist[1], b"<handshake>" + hv + b"</handshake>")
xs.dataReceived("<handshake/>")
self.assertEqual(self.authComplete, True)
class ServiceTests(unittest.TestCase):
"""
Tests for L{component.Service}.
"""
def test_interface(self):
"""
L{component.Service} implements L{ijabber.IService}.
"""
service = component.Service()
verifyObject(ijabber.IService, service)
class JabberServiceHarness(component.Service):
def __init__(self):
self.componentConnectedFlag = False
self.componentDisconnectedFlag = False
self.transportConnectedFlag = False
def componentConnected(self, xmlstream):
self.componentConnectedFlag = True
def componentDisconnected(self):
self.componentDisconnectedFlag = True
def transportConnected(self, xmlstream):
self.transportConnectedFlag = True
class JabberServiceManagerTests(unittest.TestCase):
def testSM(self):
# Setup service manager and test harnes
sm = component.ServiceManager("foo", "password")
svc = JabberServiceHarness()
svc.setServiceParent(sm)
# Create a write list
wlist = []
# Setup a XmlStream
xs = sm.getFactory().buildProtocol(None)
xs.transport = self
xs.transport.write = wlist.append
# Indicate that it's connected
xs.connectionMade()
# Ensure the test service harness got notified
self.assertEqual(True, svc.transportConnectedFlag)
# Jump ahead and pretend like the stream got auth'd
xs.dispatch(xs, xmlstream.STREAM_AUTHD_EVENT)
# Ensure the test service harness got notified
self.assertEqual(True, svc.componentConnectedFlag)
# Pretend to drop the connection
xs.connectionLost(None)
# Ensure the test service harness got notified
self.assertEqual(True, svc.componentDisconnectedFlag)
class RouterTests(unittest.TestCase):
"""
Tests for L{component.Router}.
"""
def test_addRoute(self):
"""
Test route registration and routing on incoming stanzas.
"""
router = component.Router()
routed = []
router.route = lambda element: routed.append(element)
pipe = XmlPipe()
router.addRoute('example.org', pipe.sink)
self.assertEqual(1, len(router.routes))
self.assertEqual(pipe.sink, router.routes['example.org'])
element = domish.Element(('testns', 'test'))
pipe.source.send(element)
self.assertEqual([element], routed)
def test_route(self):
"""
Test routing of a message.
"""
component1 = XmlPipe()
component2 = XmlPipe()
router = component.Router()
router.addRoute('component1.example.org', component1.sink)
router.addRoute('component2.example.org', component2.sink)
outgoing = []
component2.source.addObserver('/*',
lambda element: outgoing.append(element))
stanza = domish.Element((None, 'presence'))
stanza['from'] = 'component1.example.org'
stanza['to'] = 'component2.example.org'
component1.source.send(stanza)
self.assertEqual([stanza], outgoing)
def test_routeDefault(self):
"""
Test routing of a message using the default route.
The default route is the one with L{None} as its key in the
routing table. It is taken when there is no more specific route
in the routing table that matches the stanza's destination.
"""
component1 = XmlPipe()
s2s = XmlPipe()
router = component.Router()
router.addRoute('component1.example.org', component1.sink)
router.addRoute(None, s2s.sink)
outgoing = []
s2s.source.addObserver('/*', lambda element: outgoing.append(element))
stanza = domish.Element((None, 'presence'))
stanza['from'] = 'component1.example.org'
stanza['to'] = 'example.com'
component1.source.send(stanza)
self.assertEqual([stanza], outgoing)
class ListenComponentAuthenticatorTests(unittest.TestCase):
"""
Tests for L{component.ListenComponentAuthenticator}.
"""
def setUp(self):
self.output = []
authenticator = component.ListenComponentAuthenticator('secret')
self.xmlstream = xmlstream.XmlStream(authenticator)
self.xmlstream.send = self.output.append
def loseConnection(self):
"""
Stub loseConnection because we are a transport.
"""
self.xmlstream.connectionLost("no reason")
def test_streamStarted(self):
"""
The received stream header should set several attributes.
"""
observers = []
def addOnetimeObserver(event, observerfn):
observers.append((event, observerfn))
xs = self.xmlstream
xs.addOnetimeObserver = addOnetimeObserver
xs.makeConnection(self)
self.assertIdentical(None, xs.sid)
self.assertFalse(xs._headerSent)
xs.dataReceived("<stream:stream xmlns='jabber:component:accept' "
"xmlns:stream='http://etherx.jabber.org/streams' "
"to='component.example.org'>")
self.assertEqual((0, 0), xs.version)
self.assertNotIdentical(None, xs.sid)
self.assertTrue(xs._headerSent)
self.assertEqual(('/*', xs.authenticator.onElement), observers[-1])
def test_streamStartedWrongNamespace(self):
"""
The received stream header should have a correct namespace.
"""
streamErrors = []
xs = self.xmlstream
xs.sendStreamError = streamErrors.append
xs.makeConnection(self)
xs.dataReceived("<stream:stream xmlns='jabber:client' "
"xmlns:stream='http://etherx.jabber.org/streams' "
"to='component.example.org'>")
self.assertEqual(1, len(streamErrors))
self.assertEqual('invalid-namespace', streamErrors[-1].condition)
def test_streamStartedNoTo(self):
"""
The received stream header should have a 'to' attribute.
"""
streamErrors = []
xs = self.xmlstream
xs.sendStreamError = streamErrors.append
xs.makeConnection(self)
xs.dataReceived("<stream:stream xmlns='jabber:component:accept' "
"xmlns:stream='http://etherx.jabber.org/streams'>")
self.assertEqual(1, len(streamErrors))
self.assertEqual('improper-addressing', streamErrors[-1].condition)
def test_onElement(self):
"""
We expect a handshake element with a hash.
"""
handshakes = []
xs = self.xmlstream
xs.authenticator.onHandshake = handshakes.append
handshake = domish.Element(('jabber:component:accept', 'handshake'))
handshake.addContent(u'1234')
xs.authenticator.onElement(handshake)
self.assertEqual('1234', handshakes[-1])
def test_onElementNotHandshake(self):
"""
Reject elements that are not handshakes
"""
handshakes = []
streamErrors = []
xs = self.xmlstream
xs.authenticator.onHandshake = handshakes.append
xs.sendStreamError = streamErrors.append
element = domish.Element(('jabber:component:accept', 'message'))
xs.authenticator.onElement(element)
self.assertFalse(handshakes)
self.assertEqual('not-authorized', streamErrors[-1].condition)
def test_onHandshake(self):
"""
Receiving a handshake matching the secret authenticates the stream.
"""
authd = []
def authenticated(xs):
authd.append(xs)
xs = self.xmlstream
xs.addOnetimeObserver(xmlstream.STREAM_AUTHD_EVENT, authenticated)
xs.sid = u'1234'
theHash = '32532c0f7dbf1253c095b18b18e36d38d94c1256'
xs.authenticator.onHandshake(theHash)
self.assertEqual('<handshake/>', self.output[-1])
self.assertEqual(1, len(authd))
def test_onHandshakeWrongHash(self):
"""
Receiving a bad handshake should yield a stream error.
"""
streamErrors = []
authd = []
def authenticated(xs):
authd.append(xs)
xs = self.xmlstream
xs.addOnetimeObserver(xmlstream.STREAM_AUTHD_EVENT, authenticated)
xs.sendStreamError = streamErrors.append
xs.sid = u'1234'
theHash = '1234'
xs.authenticator.onHandshake(theHash)
self.assertEqual('not-authorized', streamErrors[-1].condition)
self.assertEqual(0, len(authd))
class XMPPComponentServerFactoryTests(unittest.TestCase):
"""
Tests for L{component.XMPPComponentServerFactory}.
"""
def setUp(self):
self.router = component.Router()
self.factory = component.XMPPComponentServerFactory(self.router,
'secret')
self.xmlstream = self.factory.buildProtocol(None)
self.xmlstream.thisEntity = JID('component.example.org')
def test_makeConnection(self):
"""
A new connection increases the stream serial count. No logs by default.
"""
self.xmlstream.dispatch(self.xmlstream,
xmlstream.STREAM_CONNECTED_EVENT)
self.assertEqual(0, self.xmlstream.serial)
self.assertEqual(1, self.factory.serial)
self.assertIdentical(None, self.xmlstream.rawDataInFn)
self.assertIdentical(None, self.xmlstream.rawDataOutFn)
def test_makeConnectionLogTraffic(self):
"""
Setting logTraffic should set up raw data loggers.
"""
self.factory.logTraffic = True
self.xmlstream.dispatch(self.xmlstream,
xmlstream.STREAM_CONNECTED_EVENT)
self.assertNotIdentical(None, self.xmlstream.rawDataInFn)
self.assertNotIdentical(None, self.xmlstream.rawDataOutFn)
def test_onError(self):
"""
An observer for stream errors should trigger onError to log it.
"""
self.xmlstream.dispatch(self.xmlstream,
xmlstream.STREAM_CONNECTED_EVENT)
class TestError(Exception):
pass
reason = failure.Failure(TestError())
self.xmlstream.dispatch(reason, xmlstream.STREAM_ERROR_EVENT)
self.assertEqual(1, len(self.flushLoggedErrors(TestError)))
def test_connectionInitialized(self):
"""
Make sure a new stream is added to the routing table.
"""
self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT)
self.assertIn('component.example.org', self.router.routes)
self.assertIdentical(self.xmlstream,
self.router.routes['component.example.org'])
def test_connectionLost(self):
"""
Make sure a stream is removed from the routing table on disconnect.
"""
self.xmlstream.dispatch(self.xmlstream, xmlstream.STREAM_AUTHD_EVENT)
self.xmlstream.dispatch(None, xmlstream.STREAM_END_EVENT)
self.assertNotIn('component.example.org', self.router.routes)

View file

@ -0,0 +1,333 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.protocols.jabber.error}.
"""
from __future__ import absolute_import, division
from twisted.python.compat import unicode
from twisted.trial import unittest
from twisted.words.protocols.jabber import error
from twisted.words.xish import domish
NS_XML = 'http://www.w3.org/XML/1998/namespace'
NS_STREAMS = 'http://etherx.jabber.org/streams'
NS_XMPP_STREAMS = 'urn:ietf:params:xml:ns:xmpp-streams'
NS_XMPP_STANZAS = 'urn:ietf:params:xml:ns:xmpp-stanzas'
class BaseErrorTests(unittest.TestCase):
def test_getElementPlain(self):
"""
Test getting an element for a plain error.
"""
e = error.BaseError('feature-not-implemented')
element = e.getElement()
self.assertIdentical(element.uri, None)
self.assertEqual(len(element.children), 1)
def test_getElementText(self):
"""
Test getting an element for an error with a text.
"""
e = error.BaseError('feature-not-implemented', u'text')
element = e.getElement()
self.assertEqual(len(element.children), 2)
self.assertEqual(unicode(element.text), 'text')
self.assertEqual(element.text.getAttribute((NS_XML, 'lang')), None)
def test_getElementTextLang(self):
"""
Test getting an element for an error with a text and language.
"""
e = error.BaseError('feature-not-implemented', u'text', 'en_US')
element = e.getElement()
self.assertEqual(len(element.children), 2)
self.assertEqual(unicode(element.text), 'text')
self.assertEqual(element.text[(NS_XML, 'lang')], 'en_US')
def test_getElementAppCondition(self):
"""
Test getting an element for an error with an app specific condition.
"""
ac = domish.Element(('testns', 'myerror'))
e = error.BaseError('feature-not-implemented', appCondition=ac)
element = e.getElement()
self.assertEqual(len(element.children), 2)
self.assertEqual(element.myerror, ac)
class StreamErrorTests(unittest.TestCase):
def test_getElementPlain(self):
"""
Test namespace of the element representation of an error.
"""
e = error.StreamError('feature-not-implemented')
element = e.getElement()
self.assertEqual(element.uri, NS_STREAMS)
def test_getElementConditionNamespace(self):
"""
Test that the error condition element has the correct namespace.
"""
e = error.StreamError('feature-not-implemented')
element = e.getElement()
self.assertEqual(NS_XMPP_STREAMS, getattr(element, 'feature-not-implemented').uri)
def test_getElementTextNamespace(self):
"""
Test that the error text element has the correct namespace.
"""
e = error.StreamError('feature-not-implemented', u'text')
element = e.getElement()
self.assertEqual(NS_XMPP_STREAMS, element.text.uri)
class StanzaErrorTests(unittest.TestCase):
"""
Tests for L{error.StreamError}.
"""
def test_typeRemoteServerTimeout(self):
"""
Remote Server Timeout should yield type wait, code 504.
"""
e = error.StanzaError('remote-server-timeout')
self.assertEqual('wait', e.type)
self.assertEqual('504', e.code)
def test_getElementPlain(self):
"""
Test getting an element for a plain stanza error.
"""
e = error.StanzaError('feature-not-implemented')
element = e.getElement()
self.assertEqual(element.uri, None)
self.assertEqual(element['type'], 'cancel')
self.assertEqual(element['code'], '501')
def test_getElementType(self):
"""
Test getting an element for a stanza error with a given type.
"""
e = error.StanzaError('feature-not-implemented', 'auth')
element = e.getElement()
self.assertEqual(element.uri, None)
self.assertEqual(element['type'], 'auth')
self.assertEqual(element['code'], '501')
def test_getElementConditionNamespace(self):
"""
Test that the error condition element has the correct namespace.
"""
e = error.StanzaError('feature-not-implemented')
element = e.getElement()
self.assertEqual(NS_XMPP_STANZAS, getattr(element, 'feature-not-implemented').uri)
def test_getElementTextNamespace(self):
"""
Test that the error text element has the correct namespace.
"""
e = error.StanzaError('feature-not-implemented', text=u'text')
element = e.getElement()
self.assertEqual(NS_XMPP_STANZAS, element.text.uri)
def test_toResponse(self):
"""
Test an error response is generated from a stanza.
The addressing on the (new) response stanza should be reversed, an
error child (with proper properties) added and the type set to
C{'error'}.
"""
stanza = domish.Element(('jabber:client', 'message'))
stanza['type'] = 'chat'
stanza['to'] = 'user1@example.com'
stanza['from'] = 'user2@example.com/resource'
e = error.StanzaError('service-unavailable')
response = e.toResponse(stanza)
self.assertNotIdentical(response, stanza)
self.assertEqual(response['from'], 'user1@example.com')
self.assertEqual(response['to'], 'user2@example.com/resource')
self.assertEqual(response['type'], 'error')
self.assertEqual(response.error.children[0].name,
'service-unavailable')
self.assertEqual(response.error['type'], 'cancel')
self.assertNotEqual(stanza.children, response.children)
class ParseErrorTests(unittest.TestCase):
"""
Tests for L{error._parseError}.
"""
def setUp(self):
self.error = domish.Element((None, 'error'))
def test_empty(self):
"""
Test parsing of the empty error element.
"""
result = error._parseError(self.error, 'errorns')
self.assertEqual({'condition': None,
'text': None,
'textLang': None,
'appCondition': None}, result)
def test_condition(self):
"""
Test parsing of an error element with a condition.
"""
self.error.addElement(('errorns', 'bad-request'))
result = error._parseError(self.error, 'errorns')
self.assertEqual('bad-request', result['condition'])
def test_text(self):
"""
Test parsing of an error element with a text.
"""
text = self.error.addElement(('errorns', 'text'))
text.addContent(u'test')
result = error._parseError(self.error, 'errorns')
self.assertEqual('test', result['text'])
self.assertEqual(None, result['textLang'])
def test_textLang(self):
"""
Test parsing of an error element with a text with a defined language.
"""
text = self.error.addElement(('errorns', 'text'))
text[NS_XML, 'lang'] = 'en_US'
text.addContent(u'test')
result = error._parseError(self.error, 'errorns')
self.assertEqual('en_US', result['textLang'])
def test_appCondition(self):
"""
Test parsing of an error element with an app specific condition.
"""
condition = self.error.addElement(('testns', 'condition'))
result = error._parseError(self.error, 'errorns')
self.assertEqual(condition, result['appCondition'])
def test_appConditionMultiple(self):
"""
Test parsing of an error element with multiple app specific conditions.
"""
self.error.addElement(('testns', 'condition'))
condition = self.error.addElement(('testns', 'condition2'))
result = error._parseError(self.error, 'errorns')
self.assertEqual(condition, result['appCondition'])
class ExceptionFromStanzaTests(unittest.TestCase):
def test_basic(self):
"""
Test basic operations of exceptionFromStanza.
Given a realistic stanza, check if a sane exception is returned.
Using this stanza::
<iq type='error'
from='pubsub.shakespeare.lit'
to='francisco@denmark.lit/barracks'
id='subscriptions1'>
<pubsub xmlns='http://jabber.org/protocol/pubsub'>
<subscriptions/>
</pubsub>
<error type='cancel'>
<feature-not-implemented
xmlns='urn:ietf:params:xml:ns:xmpp-stanzas'/>
<unsupported xmlns='http://jabber.org/protocol/pubsub#errors'
feature='retrieve-subscriptions'/>
</error>
</iq>
"""
stanza = domish.Element((None, 'stanza'))
p = stanza.addElement(('http://jabber.org/protocol/pubsub', 'pubsub'))
p.addElement('subscriptions')
e = stanza.addElement('error')
e['type'] = 'cancel'
e.addElement((NS_XMPP_STANZAS, 'feature-not-implemented'))
uc = e.addElement(('http://jabber.org/protocol/pubsub#errors',
'unsupported'))
uc['feature'] = 'retrieve-subscriptions'
result = error.exceptionFromStanza(stanza)
self.assertIsInstance(result, error.StanzaError)
self.assertEqual('feature-not-implemented', result.condition)
self.assertEqual('cancel', result.type)
self.assertEqual(uc, result.appCondition)
self.assertEqual([p], result.children)
def test_legacy(self):
"""
Test legacy operations of exceptionFromStanza.
Given a realistic stanza with only legacy (pre-XMPP) error information,
check if a sane exception is returned.
Using this stanza::
<message type='error'
to='piers@pipetree.com/Home'
from='qmacro@jaber.org'>
<body>Are you there?</body>
<error code='502'>Unable to resolve hostname.</error>
</message>
"""
stanza = domish.Element((None, 'stanza'))
p = stanza.addElement('body', content=u'Are you there?')
e = stanza.addElement('error', content=u'Unable to resolve hostname.')
e['code'] = '502'
result = error.exceptionFromStanza(stanza)
self.assertIsInstance(result, error.StanzaError)
self.assertEqual('service-unavailable', result.condition)
self.assertEqual('wait', result.type)
self.assertEqual('Unable to resolve hostname.', result.text)
self.assertEqual([p], result.children)
class ExceptionFromStreamErrorTests(unittest.TestCase):
def test_basic(self):
"""
Test basic operations of exceptionFromStreamError.
Given a realistic stream error, check if a sane exception is returned.
Using this error::
<stream:error xmlns:stream='http://etherx.jabber.org/streams'>
<xml-not-well-formed xmlns='urn:ietf:params:xml:ns:xmpp-streams'/>
</stream:error>
"""
e = domish.Element(('http://etherx.jabber.org/streams', 'error'))
e.addElement((NS_XMPP_STREAMS, 'xml-not-well-formed'))
result = error.exceptionFromStreamError(e)
self.assertIsInstance(result, error.StreamError)
self.assertEqual('xml-not-well-formed', result.condition)

View file

@ -0,0 +1,226 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.protocols.jabber.jid}.
"""
from twisted.python.compat import unicode
from twisted.trial import unittest
from twisted.words.protocols.jabber import jid
class JIDParsingTests(unittest.TestCase):
def test_parse(self):
"""
Test different forms of JIDs.
"""
# Basic forms
self.assertEqual(jid.parse("user@host/resource"),
("user", "host", "resource"))
self.assertEqual(jid.parse("user@host"),
("user", "host", None))
self.assertEqual(jid.parse("host"),
(None, "host", None))
self.assertEqual(jid.parse("host/resource"),
(None, "host", "resource"))
# More interesting forms
self.assertEqual(jid.parse("foo/bar@baz"),
(None, "foo", "bar@baz"))
self.assertEqual(jid.parse("boo@foo/bar@baz"),
("boo", "foo", "bar@baz"))
self.assertEqual(jid.parse("boo@foo/bar/baz"),
("boo", "foo", "bar/baz"))
self.assertEqual(jid.parse("boo/foo@bar@baz"),
(None, "boo", "foo@bar@baz"))
self.assertEqual(jid.parse("boo/foo/bar"),
(None, "boo", "foo/bar"))
self.assertEqual(jid.parse("boo//foo"),
(None, "boo", "/foo"))
def test_noHost(self):
"""
Test for failure on no host part.
"""
self.assertRaises(jid.InvalidFormat, jid.parse, "user@")
def test_doubleAt(self):
"""
Test for failure on double @ signs.
This should fail because @ is not a valid character for the host
part of the JID.
"""
self.assertRaises(jid.InvalidFormat, jid.parse, "user@@host")
def test_multipleAt(self):
"""
Test for failure on two @ signs.
This should fail because @ is not a valid character for the host
part of the JID.
"""
self.assertRaises(jid.InvalidFormat, jid.parse, "user@host@host")
# Basic tests for case mapping. These are fallback tests for the
# prepping done in twisted.words.protocols.jabber.xmpp_stringprep
def test_prepCaseMapUser(self):
"""
Test case mapping of the user part of the JID.
"""
self.assertEqual(jid.prep("UsEr", "host", "resource"),
("user", "host", "resource"))
def test_prepCaseMapHost(self):
"""
Test case mapping of the host part of the JID.
"""
self.assertEqual(jid.prep("user", "hoST", "resource"),
("user", "host", "resource"))
def test_prepNoCaseMapResource(self):
"""
Test no case mapping of the resourcce part of the JID.
"""
self.assertEqual(jid.prep("user", "hoST", "resource"),
("user", "host", "resource"))
self.assertNotEqual(jid.prep("user", "host", "Resource"),
("user", "host", "resource"))
class JIDTests(unittest.TestCase):
def test_noneArguments(self):
"""
Test that using no arguments raises an exception.
"""
self.assertRaises(RuntimeError, jid.JID)
def test_attributes(self):
"""
Test that the attributes correspond with the JID parts.
"""
j = jid.JID("user@host/resource")
self.assertEqual(j.user, "user")
self.assertEqual(j.host, "host")
self.assertEqual(j.resource, "resource")
def test_userhost(self):
"""
Test the extraction of the bare JID.
"""
j = jid.JID("user@host/resource")
self.assertEqual("user@host", j.userhost())
def test_userhostOnlyHost(self):
"""
Test the extraction of the bare JID of the full form host/resource.
"""
j = jid.JID("host/resource")
self.assertEqual("host", j.userhost())
def test_userhostJID(self):
"""
Test getting a JID object of the bare JID.
"""
j1 = jid.JID("user@host/resource")
j2 = jid.internJID("user@host")
self.assertIdentical(j2, j1.userhostJID())
def test_userhostJIDNoResource(self):
"""
Test getting a JID object of the bare JID when there was no resource.
"""
j = jid.JID("user@host")
self.assertIdentical(j, j.userhostJID())
def test_fullHost(self):
"""
Test giving a string representation of the JID with only a host part.
"""
j = jid.JID(tuple=(None, 'host', None))
self.assertEqual('host', j.full())
def test_fullHostResource(self):
"""
Test giving a string representation of the JID with host, resource.
"""
j = jid.JID(tuple=(None, 'host', 'resource'))
self.assertEqual('host/resource', j.full())
def test_fullUserHost(self):
"""
Test giving a string representation of the JID with user, host.
"""
j = jid.JID(tuple=('user', 'host', None))
self.assertEqual('user@host', j.full())
def test_fullAll(self):
"""
Test giving a string representation of the JID.
"""
j = jid.JID(tuple=('user', 'host', 'resource'))
self.assertEqual('user@host/resource', j.full())
def test_equality(self):
"""
Test JID equality.
"""
j1 = jid.JID("user@host/resource")
j2 = jid.JID("user@host/resource")
self.assertNotIdentical(j1, j2)
self.assertEqual(j1, j2)
def test_equalityWithNonJIDs(self):
"""
Test JID equality.
"""
j = jid.JID("user@host/resource")
self.assertFalse(j == 'user@host/resource')
def test_inequality(self):
"""
Test JID inequality.
"""
j1 = jid.JID("user1@host/resource")
j2 = jid.JID("user2@host/resource")
self.assertNotEqual(j1, j2)
def test_inequalityWithNonJIDs(self):
"""
Test JID equality.
"""
j = jid.JID("user@host/resource")
self.assertNotEqual(j, 'user@host/resource')
def test_hashable(self):
"""
Test JID hashability.
"""
j1 = jid.JID("user@host/resource")
j2 = jid.JID("user@host/resource")
self.assertEqual(hash(j1), hash(j2))
def test_unicode(self):
"""
Test unicode representation of JIDs.
"""
j = jid.JID(tuple=('user', 'host', 'resource'))
self.assertEqual(u"user@host/resource", unicode(j))
def test_repr(self):
"""
Test representation of JID objects.
"""
j = jid.JID(tuple=('user', 'host', 'resource'))
self.assertEqual("JID(%s)" % repr(u'user@host/resource'), repr(j))
class InternJIDTests(unittest.TestCase):
def test_identity(self):
"""
Test that two interned JIDs yield the same object.
"""
j1 = jid.internJID("user@host")
j2 = jid.internJID("user@host")
self.assertIdentical(j1, j2)

View file

@ -0,0 +1,36 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.protocols.jabber.jstrports}.
"""
from __future__ import absolute_import, division
from twisted.trial import unittest
from twisted.words.protocols.jabber import jstrports
from twisted.application.internet import TCPClient
class JabberStrPortsPlaceHolderTests(unittest.TestCase):
"""
Tests for L{jstrports}
"""
def test_parse(self):
"""
L{jstrports.parse} accepts an endpoint description string and returns a
tuple and dict of parsed endpoint arguments.
"""
expected = ('TCP', ('DOMAIN', 65535, 'Factory'), {})
got = jstrports.parse("tcp:DOMAIN:65535", "Factory")
self.assertEqual(expected, got)
def test_client(self):
"""
L{jstrports.client} returns a L{TCPClient} service.
"""
got = jstrports.client("tcp:DOMAIN:65535", "Factory")
self.assertIsInstance(got, TCPClient)

View file

@ -0,0 +1,292 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import absolute_import, division
from zope.interface import implementer
from twisted.internet import defer
from twisted.python.compat import unicode
from twisted.trial import unittest
from twisted.words.protocols.jabber import sasl, sasl_mechanisms, xmlstream, jid
from twisted.words.xish import domish
NS_XMPP_SASL = 'urn:ietf:params:xml:ns:xmpp-sasl'
@implementer(sasl_mechanisms.ISASLMechanism)
class DummySASLMechanism(object):
"""
Dummy SASL mechanism.
This just returns the initialResponse passed on creation, stores any
challenges and replies with the value of C{response}.
@ivar challenge: Last received challenge.
@type challenge: C{unicode}.
@ivar initialResponse: Initial response to be returned when requested
via C{getInitialResponse} or L{None}.
@type initialResponse: C{unicode}
"""
challenge = None
name = u"DUMMY"
response = b""
def __init__(self, initialResponse):
self.initialResponse = initialResponse
def getInitialResponse(self):
return self.initialResponse
def getResponse(self, challenge):
self.challenge = challenge
return self.response
class DummySASLInitiatingInitializer(sasl.SASLInitiatingInitializer):
"""
Dummy SASL Initializer for initiating entities.
This hardwires the SASL mechanism to L{DummySASLMechanism}, that is
instantiated with the value of C{initialResponse}.
@ivar initialResponse: The initial response to be returned by the
dummy SASL mechanism or L{None}.
@type initialResponse: C{unicode}.
"""
initialResponse = None
def setMechanism(self):
self.mechanism = DummySASLMechanism(self.initialResponse)
class SASLInitiatingInitializerTests(unittest.TestCase):
"""
Tests for L{sasl.SASLInitiatingInitializer}
"""
def setUp(self):
self.output = []
self.authenticator = xmlstream.Authenticator()
self.xmlstream = xmlstream.XmlStream(self.authenticator)
self.xmlstream.send = self.output.append
self.xmlstream.connectionMade()
self.xmlstream.dataReceived(b"<stream:stream xmlns='jabber:client' "
b"xmlns:stream='http://etherx.jabber.org/streams' "
b"from='example.com' id='12345' version='1.0'>")
self.init = DummySASLInitiatingInitializer(self.xmlstream)
def test_onFailure(self):
"""
Test that the SASL error condition is correctly extracted.
"""
failure = domish.Element(('urn:ietf:params:xml:ns:xmpp-sasl',
'failure'))
failure.addElement('not-authorized')
self.init._deferred = defer.Deferred()
self.init.onFailure(failure)
self.assertFailure(self.init._deferred, sasl.SASLAuthError)
self.init._deferred.addCallback(lambda e:
self.assertEqual('not-authorized',
e.condition))
return self.init._deferred
def test_sendAuthInitialResponse(self):
"""
Test starting authentication with an initial response.
"""
self.init.initialResponse = b"dummy"
self.init.start()
auth = self.output[0]
self.assertEqual(NS_XMPP_SASL, auth.uri)
self.assertEqual(u'auth', auth.name)
self.assertEqual(u'DUMMY', auth['mechanism'])
self.assertEqual(u'ZHVtbXk=', unicode(auth))
def test_sendAuthNoInitialResponse(self):
"""
Test starting authentication without an initial response.
"""
self.init.initialResponse = None
self.init.start()
auth = self.output[0]
self.assertEqual(u'', str(auth))
def test_sendAuthEmptyInitialResponse(self):
"""
Test starting authentication where the initial response is empty.
"""
self.init.initialResponse = b""
self.init.start()
auth = self.output[0]
self.assertEqual('=', unicode(auth))
def test_onChallenge(self):
"""
Test receiving a challenge message.
"""
d = self.init.start()
challenge = domish.Element((NS_XMPP_SASL, 'challenge'))
challenge.addContent(u'bXkgY2hhbGxlbmdl')
self.init.onChallenge(challenge)
self.assertEqual(b'my challenge', self.init.mechanism.challenge)
self.init.onSuccess(None)
return d
def test_onChallengeResponse(self):
"""
A non-empty response gets encoded and included as character data.
"""
d = self.init.start()
challenge = domish.Element((NS_XMPP_SASL, 'challenge'))
challenge.addContent(u'bXkgY2hhbGxlbmdl')
self.init.mechanism.response = b"response"
self.init.onChallenge(challenge)
response = self.output[1]
self.assertEqual(u'cmVzcG9uc2U=', unicode(response))
self.init.onSuccess(None)
return d
def test_onChallengeEmpty(self):
"""
Test receiving an empty challenge message.
"""
d = self.init.start()
challenge = domish.Element((NS_XMPP_SASL, 'challenge'))
self.init.onChallenge(challenge)
self.assertEqual(b'', self.init.mechanism.challenge)
self.init.onSuccess(None)
return d
def test_onChallengeIllegalPadding(self):
"""
Test receiving a challenge message with illegal padding.
"""
d = self.init.start()
challenge = domish.Element((NS_XMPP_SASL, 'challenge'))
challenge.addContent(u'bXkg=Y2hhbGxlbmdl')
self.init.onChallenge(challenge)
self.assertFailure(d, sasl.SASLIncorrectEncodingError)
return d
def test_onChallengeIllegalCharacters(self):
"""
Test receiving a challenge message with illegal characters.
"""
d = self.init.start()
challenge = domish.Element((NS_XMPP_SASL, 'challenge'))
challenge.addContent(u'bXkg*Y2hhbGxlbmdl')
self.init.onChallenge(challenge)
self.assertFailure(d, sasl.SASLIncorrectEncodingError)
return d
def test_onChallengeMalformed(self):
"""
Test receiving a malformed challenge message.
"""
d = self.init.start()
challenge = domish.Element((NS_XMPP_SASL, 'challenge'))
challenge.addContent(u'a')
self.init.onChallenge(challenge)
self.assertFailure(d, sasl.SASLIncorrectEncodingError)
return d
class SASLInitiatingInitializerSetMechanismTests(unittest.TestCase):
"""
Test for L{sasl.SASLInitiatingInitializer.setMechanism}.
"""
def setUp(self):
self.output = []
self.authenticator = xmlstream.Authenticator()
self.xmlstream = xmlstream.XmlStream(self.authenticator)
self.xmlstream.send = self.output.append
self.xmlstream.connectionMade()
self.xmlstream.dataReceived("<stream:stream xmlns='jabber:client' "
"xmlns:stream='http://etherx.jabber.org/streams' "
"from='example.com' id='12345' version='1.0'>")
self.init = sasl.SASLInitiatingInitializer(self.xmlstream)
def _setMechanism(self, name):
"""
Set up the XML Stream to have a SASL feature with the given mechanism.
"""
feature = domish.Element((NS_XMPP_SASL, 'mechanisms'))
feature.addElement('mechanism', content=name)
self.xmlstream.features[(feature.uri, feature.name)] = feature
self.init.setMechanism()
return self.init.mechanism.name
def test_anonymous(self):
"""
Test setting ANONYMOUS as the authentication mechanism.
"""
self.authenticator.jid = jid.JID('example.com')
self.authenticator.password = None
name = u"ANONYMOUS"
self.assertEqual(name, self._setMechanism(name))
def test_plain(self):
"""
Test setting PLAIN as the authentication mechanism.
"""
self.authenticator.jid = jid.JID('test@example.com')
self.authenticator.password = 'secret'
name = u"PLAIN"
self.assertEqual(name, self._setMechanism(name))
def test_digest(self):
"""
Test setting DIGEST-MD5 as the authentication mechanism.
"""
self.authenticator.jid = jid.JID('test@example.com')
self.authenticator.password = 'secret'
name = u"DIGEST-MD5"
self.assertEqual(name, self._setMechanism(name))
def test_notAcceptable(self):
"""
Test using an unacceptable SASL authentication mechanism.
"""
self.authenticator.jid = jid.JID('test@example.com')
self.authenticator.password = u'secret'
self.assertRaises(sasl.SASLNoAcceptableMechanism,
self._setMechanism, u'SOMETHING_UNACCEPTABLE')
def test_notAcceptableWithoutUser(self):
"""
Test using an unacceptable SASL authentication mechanism with no JID.
"""
self.authenticator.jid = jid.JID('example.com')
self.authenticator.password = u'secret'
self.assertRaises(sasl.SASLNoAcceptableMechanism,
self._setMechanism, u'SOMETHING_UNACCEPTABLE')

View file

@ -0,0 +1,163 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.protocols.jabber.sasl_mechanisms}.
"""
from __future__ import absolute_import, division
from twisted.trial import unittest
from twisted.python.compat import networkString
from twisted.words.protocols.jabber import sasl_mechanisms
class PlainTests(unittest.TestCase):
"""
Tests for L{twisted.words.protocols.jabber.sasl_mechanisms.Plain}.
"""
def test_getInitialResponse(self):
"""
Test the initial response.
"""
m = sasl_mechanisms.Plain(None, u'test', u'secret')
self.assertEqual(m.getInitialResponse(), b'\x00test\x00secret')
class AnonymousTests(unittest.TestCase):
"""
Tests for L{twisted.words.protocols.jabber.sasl_mechanisms.Anonymous}.
"""
def test_getInitialResponse(self):
"""
Test the initial response to be empty.
"""
m = sasl_mechanisms.Anonymous()
self.assertEqual(m.getInitialResponse(), None)
class DigestMD5Tests(unittest.TestCase):
"""
Tests for L{twisted.words.protocols.jabber.sasl_mechanisms.DigestMD5}.
"""
def setUp(self):
self.mechanism = sasl_mechanisms.DigestMD5(
u'xmpp', u'example.org', None, u'test', u'secret')
def test_getInitialResponse(self):
"""
Test that no initial response is generated.
"""
self.assertIdentical(self.mechanism.getInitialResponse(), None)
def test_getResponse(self):
"""
The response to a Digest-MD5 challenge includes the parameters from the
challenge.
"""
challenge = (
b'realm="localhost",nonce="1234",qop="auth",charset=utf-8,'
b'algorithm=md5-sess')
directives = self.mechanism._parse(
self.mechanism.getResponse(challenge))
del directives[b"cnonce"], directives[b"response"]
self.assertEqual({
b'username': b'test', b'nonce': b'1234', b'nc': b'00000001',
b'qop': [b'auth'], b'charset': b'utf-8',
b'realm': b'localhost', b'digest-uri': b'xmpp/example.org'
}, directives)
def test_getResponseNonAsciiRealm(self):
"""
Bytes outside the ASCII range in the challenge are nevertheless
included in the response.
"""
challenge = (b'realm="\xc3\xa9chec.example.org",nonce="1234",'
b'qop="auth",charset=utf-8,algorithm=md5-sess')
directives = self.mechanism._parse(
self.mechanism.getResponse(challenge))
del directives[b"cnonce"], directives[b"response"]
self.assertEqual({
b'username': b'test', b'nonce': b'1234', b'nc': b'00000001',
b'qop': [b'auth'], b'charset': b'utf-8',
b'realm': b'\xc3\xa9chec.example.org',
b'digest-uri': b'xmpp/example.org'}, directives)
def test_getResponseNoRealm(self):
"""
The response to a challenge without a realm uses the host part of the
JID as the realm.
"""
challenge = b'nonce="1234",qop="auth",charset=utf-8,algorithm=md5-sess'
directives = self.mechanism._parse(
self.mechanism.getResponse(challenge))
self.assertEqual(directives[b'realm'], b'example.org')
def test_getResponseNoRealmIDN(self):
"""
If the challenge does not include a realm and the host part of the JID
includes bytes outside of the ASCII range, the response still includes
the host part of the JID as the realm.
"""
self.mechanism = sasl_mechanisms.DigestMD5(
u'xmpp', u'\u00e9chec.example.org', None, u'test', u'secret')
challenge = b'nonce="1234",qop="auth",charset=utf-8,algorithm=md5-sess'
directives = self.mechanism._parse(
self.mechanism.getResponse(challenge))
self.assertEqual(directives[b'realm'], b'\xc3\xa9chec.example.org')
def test_getResponseRspauth(self):
"""
If the challenge just has a rspauth directive, the response is empty.
"""
challenge = \
b'rspauth=cnNwYXV0aD1lYTQwZjYwMzM1YzQyN2I1NTI3Yjg0ZGJhYmNkZmZmZA=='
response = self.mechanism.getResponse(challenge)
self.assertEqual(b"", response)
def test_calculateResponse(self):
"""
The response to a Digest-MD5 challenge is computed according to RFC
2831.
"""
charset = 'utf-8'
nonce = b'OA6MG9tEQGm2hh'
nc = networkString('%08x' % (1,))
cnonce = b'OA6MHXh6VqTrRk'
username = u'\u0418chris'
password = u'\u0418secret'
host = u'\u0418elwood.innosoft.com'
digestURI = u'imap/\u0418elwood.innosoft.com'.encode(charset)
mechanism = sasl_mechanisms.DigestMD5(
b'imap', host, None, username, password)
response = mechanism._calculateResponse(
cnonce, nc, nonce, username.encode(charset),
password.encode(charset), host.encode(charset), digestURI)
self.assertEqual(response, b'7928f233258be88392424d094453c5e3')
def test_parse(self):
"""
A challenge can be parsed into a L{dict} with L{bytes} or L{list}
values.
"""
challenge = (
b'nonce="1234",qop="auth,auth-conf",charset=utf-8,'
b'algorithm=md5-sess,cipher="des,3des"')
directives = self.mechanism._parse(challenge)
self.assertEqual({
b"algorithm": b"md5-sess", b"nonce": b"1234",
b"charset": b"utf-8", b"qop": [b'auth', b'auth-conf'],
b"cipher": [b'des', b'3des']
}, directives)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,115 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.trial import unittest
from twisted.words.protocols.jabber.xmpp_stringprep import (
nodeprep, resourceprep, nameprep)
class DeprecationTests(unittest.TestCase):
"""
Deprecations in L{twisted.words.protocols.jabber.xmpp_stringprep}.
"""
def test_crippled(self):
"""
L{xmpp_stringprep.crippled} is deprecated and always returns C{False}.
"""
from twisted.words.protocols.jabber.xmpp_stringprep import crippled
warnings = self.flushWarnings(
offendingFunctions=[self.test_crippled])
self.assertEqual(DeprecationWarning, warnings[0]['category'])
self.assertEqual(
"twisted.words.protocols.jabber.xmpp_stringprep.crippled was "
"deprecated in Twisted 13.1.0: crippled is always False",
warnings[0]['message'])
self.assertEqual(1, len(warnings))
self.assertEqual(crippled, False)
class XMPPStringPrepTests(unittest.TestCase):
"""
The nodeprep stringprep profile is similar to the resourceprep profile,
but does an extra mapping of characters (table B.2) and disallows
more characters (table C.1.1 and eight extra punctuation characters).
Due to this similarity, the resourceprep tests are more extensive, and
the nodeprep tests only address the mappings additional restrictions.
The nameprep profile is nearly identical to the nameprep implementation in
L{encodings.idna}, but that implementation assumes the C{UseSTD4ASCIIRules}
flag to be false. This implementation assumes it to be true, and restricts
the allowed set of characters. The tests here only check for the
differences.
"""
def testResourcePrep(self):
self.assertEqual(resourceprep.prepare(u'resource'), u'resource')
self.assertNotEqual(resourceprep.prepare(u'Resource'), u'resource')
self.assertEqual(resourceprep.prepare(u' '), u' ')
self.assertEqual(resourceprep.prepare(u'Henry \u2163'), u'Henry IV')
self.assertEqual(resourceprep.prepare(u'foo\xad\u034f\u1806\u180b'
u'bar\u200b\u2060'
u'baz\ufe00\ufe08\ufe0f\ufeff'),
u'foobarbaz')
self.assertEqual(resourceprep.prepare(u'\u00a0'), u' ')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u1680')
self.assertEqual(resourceprep.prepare(u'\u2000'), u' ')
self.assertEqual(resourceprep.prepare(u'\u200b'), u'')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u0010\u007f')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u0085')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u180e')
self.assertEqual(resourceprep.prepare(u'\ufeff'), u'')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\uf123')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\U000f1234')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\U0010f234')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\U0008fffe')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\U0010ffff')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\udf42')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\ufffd')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u2ff5')
self.assertEqual(resourceprep.prepare(u'\u0341'), u'\u0301')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u200e')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u202a')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\U000e0001')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\U000e0042')
self.assertRaises(UnicodeError, resourceprep.prepare, u'foo\u05bebar')
self.assertRaises(UnicodeError, resourceprep.prepare, u'foo\ufd50bar')
#self.assertEqual(resourceprep.prepare(u'foo\ufb38bar'),
# u'foo\u064ebar')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\u06271')
self.assertEqual(resourceprep.prepare(u'\u06271\u0628'),
u'\u06271\u0628')
self.assertRaises(UnicodeError, resourceprep.prepare, u'\U000e0002')
def testNodePrep(self):
self.assertEqual(nodeprep.prepare(u'user'), u'user')
self.assertEqual(nodeprep.prepare(u'User'), u'user')
self.assertRaises(UnicodeError, nodeprep.prepare, u'us&er')
def test_nodeprepUnassignedInUnicode32(self):
"""
Make sure unassigned code points from Unicode 3.2 are rejected.
"""
self.assertRaises(UnicodeError, nodeprep.prepare, u'\u1d39')
def testNamePrep(self):
self.assertEqual(nameprep.prepare(u'example.com'), u'example.com')
self.assertEqual(nameprep.prepare(u'Example.com'), u'example.com')
self.assertRaises(UnicodeError, nameprep.prepare, u'ex@mple.com')
self.assertRaises(UnicodeError, nameprep.prepare, u'-example.com')
self.assertRaises(UnicodeError, nameprep.prepare, u'example-.com')
self.assertEqual(nameprep.prepare(u'stra\u00dfe.example.com'),
u'strasse.example.com')
def test_nameprepTrailingDot(self):
"""
A trailing dot in domain names is preserved.
"""
self.assertEqual(nameprep.prepare(u'example.com.'), u'example.com.')

View file

@ -0,0 +1,843 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.service}.
"""
import time
from twisted.cred import portal, credentials, checkers
from twisted.internet import address, defer, reactor
from twisted.internet.defer import Deferred, DeferredList, maybeDeferred, succeed
from twisted.python.compat import unicode
from twisted.spread import pb
from twisted.test import proto_helpers
from twisted.trial import unittest
from twisted.words import ewords, service
from twisted.words.protocols import irc
class RealmTests(unittest.TestCase):
def _entityCreationTest(self, kind):
# Kind is "user" or "group"
realm = service.InMemoryWordsRealm("realmname")
name = u'test' + kind.lower()
create = getattr(realm, 'create' + kind.title())
get = getattr(realm, 'get' + kind.title())
flag = 'create' + kind.title() + 'OnRequest'
dupExc = getattr(ewords, 'Duplicate' + kind.title())
noSuchExc = getattr(ewords, 'NoSuch' + kind.title())
# Creating should succeed
p = self.successResultOf(create(name))
self.assertEqual(name, p.name)
# Creating the same user again should not
self.failureResultOf(create(name)).trap(dupExc)
# Getting a non-existent user should succeed if createUserOnRequest is True
setattr(realm, flag, True)
p = self.successResultOf(get(u"new" + kind.lower()))
self.assertEqual("new" + kind.lower(), p.name)
# Getting that user again should return the same object
newp = self.successResultOf(get(u"new" + kind.lower()))
self.assertIdentical(p, newp)
# Getting a non-existent user should fail if createUserOnRequest is False
setattr(realm, flag, False)
self.failureResultOf(get(u"another" + kind.lower())).trap(noSuchExc)
def testUserCreation(self):
return self._entityCreationTest("User")
def testGroupCreation(self):
return self._entityCreationTest("Group")
def testUserRetrieval(self):
realm = service.InMemoryWordsRealm("realmname")
# Make a user to play around with
user = self.successResultOf(realm.createUser(u"testuser"))
# Make sure getting the user returns the same object
retrieved = self.successResultOf(realm.getUser(u"testuser"))
self.assertIdentical(user, retrieved)
# Make sure looking up the user also returns the same object
lookedUp = self.successResultOf(realm.lookupUser(u"testuser"))
self.assertIdentical(retrieved, lookedUp)
# Make sure looking up a user who does not exist fails
(self.failureResultOf(realm.lookupUser(u"nosuchuser"))
.trap(ewords.NoSuchUser))
def testUserAddition(self):
realm = service.InMemoryWordsRealm("realmname")
# Create and manually add a user to the realm
p = service.User("testuser")
user = self.successResultOf(realm.addUser(p))
self.assertIdentical(p, user)
# Make sure getting that user returns the same object
retrieved = self.successResultOf(realm.getUser(u"testuser"))
self.assertIdentical(user, retrieved)
# Make sure looking up that user returns the same object
lookedUp = self.successResultOf(realm.lookupUser(u"testuser"))
self.assertIdentical(retrieved, lookedUp)
def testGroupRetrieval(self):
realm = service.InMemoryWordsRealm("realmname")
group = self.successResultOf(realm.createGroup(u"testgroup"))
retrieved = self.successResultOf(realm.getGroup(u"testgroup"))
self.assertIdentical(group, retrieved)
(self.failureResultOf(realm.getGroup(u"nosuchgroup"))
.trap(ewords.NoSuchGroup))
def testGroupAddition(self):
realm = service.InMemoryWordsRealm("realmname")
p = service.Group("testgroup")
self.successResultOf(realm.addGroup(p))
group = self.successResultOf(realm.getGroup(u"testGroup"))
self.assertIdentical(p, group)
def testGroupUsernameCollision(self):
"""
Try creating a group with the same name as an existing user and
assert that it succeeds, since users and groups should not be in the
same namespace and collisions should be impossible.
"""
realm = service.InMemoryWordsRealm("realmname")
self.successResultOf(realm.createUser(u"test"))
self.successResultOf(realm.createGroup(u"test"))
def testEnumeration(self):
realm = service.InMemoryWordsRealm("realmname")
self.successResultOf(realm.createGroup(u"groupone"))
self.successResultOf(realm.createGroup(u"grouptwo"))
groups = self.successResultOf(realm.itergroups())
n = [g.name for g in groups]
n.sort()
self.assertEqual(n, ["groupone", "grouptwo"])
class TestCaseUserAgg(object):
def __init__(self, user, realm, factory, address=address.IPv4Address('TCP', '127.0.0.1', 54321)):
self.user = user
self.transport = proto_helpers.StringTransportWithDisconnection()
self.protocol = factory.buildProtocol(address)
self.transport.protocol = self.protocol
self.user.mind = self.protocol
self.protocol.makeConnection(self.transport)
def write(self, stuff):
self.protocol.dataReceived(stuff)
class IRCProtocolTests(unittest.TestCase):
STATIC_USERS = [
u'useruser', u'otheruser', u'someguy', u'firstuser', u'username',
u'userone', u'usertwo', u'userthree', 'userfour', b'userfive', u'someuser']
def setUp(self):
self.realm = service.InMemoryWordsRealm("realmname")
self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
self.portal = portal.Portal(self.realm, [self.checker])
self.factory = service.IRCFactory(self.realm, self.portal)
c = []
for nick in self.STATIC_USERS:
if isinstance(nick, bytes):
nick = nick.decode("utf-8")
c.append(self.realm.createUser(nick))
self.checker.addUser(nick, nick + u"_password")
return DeferredList(c)
def _assertGreeting(self, user):
"""
The user has been greeted with the four messages that are (usually)
considered to start an IRC session.
Asserts that the required responses were received.
"""
# Make sure we get 1-4 at least
response = self._response(user)
expected = [irc.RPL_WELCOME, irc.RPL_YOURHOST, irc.RPL_CREATED,
irc.RPL_MYINFO]
for (prefix, command, args) in response:
if command in expected:
expected.remove(command)
self.assertFalse(expected, "Missing responses for %r" % (expected,))
def _login(self, user, nick, password=None):
if password is None:
password = nick + "_password"
user.write(u'PASS %s\r\n' % (password,))
user.write(u'NICK %s extrainfo\r\n' % (nick,))
def _loggedInUser(self, name):
user = self.successResultOf(self.realm.lookupUser(name))
agg = TestCaseUserAgg(user, self.realm, self.factory)
self._login(agg, name)
return agg
def _response(self, user, messageType=None):
"""
Extracts the user's response, and returns a list of parsed lines.
If messageType is defined, only messages of that type will be returned.
"""
response = user.transport.value()
if bytes != str and isinstance(response, bytes):
response = response.decode("utf-8")
response = response.splitlines()
user.transport.clear()
result = []
for message in map(irc.parsemsg, response):
if messageType is None or message[1] == messageType:
result.append(message)
return result
def testPASSLogin(self):
user = self._loggedInUser(u'firstuser')
self._assertGreeting(user)
def test_nickServLogin(self):
"""
Sending NICK without PASS will prompt the user for their password.
When the user sends their password to NickServ, it will respond with a
Greeting.
"""
firstuser = self.successResultOf(self.realm.lookupUser(u'firstuser'))
user = TestCaseUserAgg(firstuser, self.realm, self.factory)
user.write('NICK firstuser extrainfo\r\n')
response = self._response(user, 'PRIVMSG')
self.assertEqual(len(response), 1)
self.assertEqual(response[0][0], service.NICKSERV)
self.assertEqual(response[0][1], 'PRIVMSG')
self.assertEqual(response[0][2], ['firstuser', 'Password?'])
user.transport.clear()
user.write('PRIVMSG nickserv firstuser_password\r\n')
self._assertGreeting(user)
def testFailedLogin(self):
firstuser = self.successResultOf(self.realm.lookupUser(u'firstuser'))
user = TestCaseUserAgg(firstuser, self.realm, self.factory)
self._login(user, u"firstuser", u"wrongpass")
response = self._response(user, "PRIVMSG")
self.assertEqual(len(response), 1)
self.assertEqual(response[0][2], ['firstuser', 'Login failed. Goodbye.'])
def testLogout(self):
logout = []
firstuser = self.successResultOf(self.realm.lookupUser(u'firstuser'))
user = TestCaseUserAgg(firstuser, self.realm, self.factory)
self._login(user, "firstuser")
user.protocol.logout = lambda: logout.append(True)
user.write('QUIT\r\n')
self.assertEqual(logout, [True])
def testJoin(self):
firstuser = self.successResultOf(self.realm.lookupUser(u'firstuser'))
somechannel = self.successResultOf(
self.realm.createGroup(u"somechannel"))
somechannel.meta['topic'] = 'some random topic'
# Bring in one user, make sure he gets into the channel sanely
user = TestCaseUserAgg(firstuser, self.realm, self.factory)
self._login(user, "firstuser")
user.transport.clear()
user.write('JOIN #somechannel\r\n')
response = self._response(user)
self.assertEqual(len(response), 5)
# Join message
self.assertEqual(response[0][0], 'firstuser!firstuser@realmname')
self.assertEqual(response[0][1], 'JOIN')
self.assertEqual(response[0][2], ['#somechannel'])
# User list
self.assertEqual(response[1][1], '353')
self.assertEqual(response[2][1], '366')
# Topic (or lack thereof, as the case may be)
self.assertEqual(response[3][1], '332')
self.assertEqual(response[4][1], '333')
# Hook up another client! It is a CHAT SYSTEM!!!!!!!
other = self._loggedInUser(u'otheruser')
other.transport.clear()
user.transport.clear()
other.write('JOIN #somechannel\r\n')
# At this point, both users should be in the channel
response = self._response(other)
event = self._response(user)
self.assertEqual(len(event), 1)
self.assertEqual(event[0][0], 'otheruser!otheruser@realmname')
self.assertEqual(event[0][1], 'JOIN')
self.assertEqual(event[0][2], ['#somechannel'])
self.assertEqual(response[1][0], 'realmname')
self.assertEqual(response[1][1], '353')
self.assertIn(response[1][2], [
['otheruser', '=', '#somechannel', 'firstuser otheruser'],
['otheruser', '=', '#somechannel', 'otheruser firstuser'],
])
def test_joinTopicless(self):
"""
When a user joins a group without a topic, no topic information is
sent to that user.
"""
firstuser = self.successResultOf(self.realm.lookupUser(u'firstuser'))
self.successResultOf(self.realm.createGroup(u"somechannel"))
# Bring in one user, make sure he gets into the channel sanely
user = TestCaseUserAgg(firstuser, self.realm, self.factory)
self._login(user, "firstuser")
user.transport.clear()
user.write('JOIN #somechannel\r\n')
response = self._response(user)
responseCodes = [r[1] for r in response]
self.assertNotIn('332', responseCodes)
self.assertNotIn('333', responseCodes)
def testLeave(self):
user = self._loggedInUser(u'useruser')
self.successResultOf(self.realm.createGroup(u"somechannel"))
user.write('JOIN #somechannel\r\n')
user.transport.clear()
other = self._loggedInUser(u'otheruser')
other.write('JOIN #somechannel\r\n')
user.transport.clear()
other.transport.clear()
user.write('PART #somechannel\r\n')
response = self._response(user)
event = self._response(other)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][0], 'useruser!useruser@realmname')
self.assertEqual(response[0][1], 'PART')
self.assertEqual(response[0][2], ['#somechannel', 'leaving'])
self.assertEqual(response, event)
# Now again, with a part message
user.write('JOIN #somechannel\r\n')
user.transport.clear()
other.transport.clear()
user.write('PART #somechannel :goodbye stupidheads\r\n')
response = self._response(user)
event = self._response(other)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][0], 'useruser!useruser@realmname')
self.assertEqual(response[0][1], 'PART')
self.assertEqual(response[0][2], ['#somechannel', 'goodbye stupidheads'])
self.assertEqual(response, event)
user.write(b'JOIN #somechannel\r\n')
user.transport.clear()
other.transport.clear()
user.write(b'PART #somechannel :goodbye stupidheads1\r\n')
response = self._response(user)
event = self._response(other)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][0], 'useruser!useruser@realmname')
self.assertEqual(response[0][1], 'PART')
self.assertEqual(response[0][2], ['#somechannel', 'goodbye stupidheads1'])
self.assertEqual(response, event)
def testGetTopic(self):
user = self._loggedInUser(u'useruser')
group = service.Group("somechannel")
group.meta["topic"] = "This is a test topic."
group.meta["topic_author"] = "some_fellow"
group.meta["topic_date"] = 77777777
self.successResultOf(self.realm.addGroup(group))
user.transport.clear()
user.write("JOIN #somechannel\r\n")
response = self._response(user)
self.assertEqual(response[3][0], 'realmname')
self.assertEqual(response[3][1], '332')
# XXX Sigh. irc.parsemsg() is not as correct as one might hope.
self.assertEqual(response[3][2], ['useruser', '#somechannel', 'This is a test topic.'])
self.assertEqual(response[4][1], '333')
self.assertEqual(response[4][2], ['useruser', '#somechannel', 'some_fellow', '77777777'])
user.transport.clear()
user.write('TOPIC #somechannel\r\n')
response = self._response(user)
self.assertEqual(response[0][1], '332')
self.assertEqual(response[0][2], ['useruser', '#somechannel', 'This is a test topic.'])
self.assertEqual(response[1][1], '333')
self.assertEqual(response[1][2], ['useruser', '#somechannel', 'some_fellow', '77777777'])
def testSetTopic(self):
user = self._loggedInUser(u'useruser')
somechannel = self.successResultOf(
self.realm.createGroup(u"somechannel"))
user.write("JOIN #somechannel\r\n")
other = self._loggedInUser(u'otheruser')
other.write("JOIN #somechannel\r\n")
user.transport.clear()
other.transport.clear()
other.write('TOPIC #somechannel :This is the new topic.\r\n')
response = self._response(other)
event = self._response(user)
self.assertEqual(response, event)
self.assertEqual(response[0][0], 'otheruser!otheruser@realmname')
self.assertEqual(response[0][1], 'TOPIC')
self.assertEqual(response[0][2], ['#somechannel', 'This is the new topic.'])
other.transport.clear()
somechannel.meta['topic_date'] = 12345
other.write('TOPIC #somechannel\r\n')
response = self._response(other)
self.assertEqual(response[0][1], '332')
self.assertEqual(response[0][2], ['otheruser', '#somechannel', 'This is the new topic.'])
self.assertEqual(response[1][1], '333')
self.assertEqual(response[1][2], ['otheruser', '#somechannel', 'otheruser', '12345'])
other.transport.clear()
other.write('TOPIC #asdlkjasd\r\n')
response = self._response(other)
self.assertEqual(response[0][1], '403')
def testGroupMessage(self):
user = self._loggedInUser(u'useruser')
self.successResultOf(self.realm.createGroup(u"somechannel"))
user.write("JOIN #somechannel\r\n")
other = self._loggedInUser(u'otheruser')
other.write("JOIN #somechannel\r\n")
user.transport.clear()
other.transport.clear()
user.write('PRIVMSG #somechannel :Hello, world.\r\n')
response = self._response(user)
event = self._response(other)
self.assertFalse(response)
self.assertEqual(len(event), 1)
self.assertEqual(event[0][0], 'useruser!useruser@realmname')
self.assertEqual(event[0][1], 'PRIVMSG', -1)
self.assertEqual(event[0][2], ['#somechannel', 'Hello, world.'])
def testPrivateMessage(self):
user = self._loggedInUser(u'useruser')
other = self._loggedInUser(u'otheruser')
user.transport.clear()
other.transport.clear()
user.write('PRIVMSG otheruser :Hello, monkey.\r\n')
response = self._response(user)
event = self._response(other)
self.assertFalse(response)
self.assertEqual(len(event), 1)
self.assertEqual(event[0][0], 'useruser!useruser@realmname')
self.assertEqual(event[0][1], 'PRIVMSG')
self.assertEqual(event[0][2], ['otheruser', 'Hello, monkey.'])
user.write('PRIVMSG nousernamedthis :Hello, monkey.\r\n')
response = self._response(user)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][0], 'realmname')
self.assertEqual(response[0][1], '401')
self.assertEqual(response[0][2], ['useruser', 'nousernamedthis', 'No such nick/channel.'])
def testOper(self):
user = self._loggedInUser(u'useruser')
user.transport.clear()
user.write('OPER user pass\r\n')
response = self._response(user)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][1], '491')
def testGetUserMode(self):
user = self._loggedInUser(u'useruser')
user.transport.clear()
user.write('MODE useruser\r\n')
response = self._response(user)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][0], 'realmname')
self.assertEqual(response[0][1], '221')
self.assertEqual(response[0][2], ['useruser', '+'])
def testSetUserMode(self):
user = self._loggedInUser(u'useruser')
user.transport.clear()
user.write('MODE useruser +abcd\r\n')
response = self._response(user)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][1], '472')
def testGetGroupMode(self):
user = self._loggedInUser(u'useruser')
self.successResultOf(self.realm.createGroup(u"somechannel"))
user.write('JOIN #somechannel\r\n')
user.transport.clear()
user.write('MODE #somechannel\r\n')
response = self._response(user)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][1], '324')
def testSetGroupMode(self):
user = self._loggedInUser(u'useruser')
self.successResultOf(self.realm.createGroup(u"groupname"))
user.write('JOIN #groupname\r\n')
user.transport.clear()
user.write('MODE #groupname +abcd\r\n')
response = self._response(user)
self.assertEqual(len(response), 1)
self.assertEqual(response[0][1], '472')
def testWho(self):
group = service.Group('groupname')
self.successResultOf(self.realm.addGroup(group))
users = []
for nick in u'userone', u'usertwo', u'userthree':
u = self._loggedInUser(nick)
users.append(u)
users[-1].write('JOIN #groupname\r\n')
for user in users:
user.transport.clear()
users[0].write('WHO #groupname\r\n')
r = self._response(users[0])
self.assertFalse(self._response(users[1]))
self.assertFalse(self._response(users[2]))
wantusers = ['userone', 'usertwo', 'userthree']
for (prefix, code, stuff) in r[:-1]:
self.assertEqual(prefix, 'realmname')
self.assertEqual(code, '352')
(myname, group, theirname, theirhost, theirserver, theirnick, flag, extra) = stuff
self.assertEqual(myname, 'userone')
self.assertEqual(group, '#groupname')
self.assertTrue(theirname in wantusers)
self.assertEqual(theirhost, 'realmname')
self.assertEqual(theirserver, 'realmname')
wantusers.remove(theirnick)
self.assertEqual(flag, 'H')
self.assertEqual(extra, '0 ' + theirnick)
self.assertFalse(wantusers)
prefix, code, stuff = r[-1]
self.assertEqual(prefix, 'realmname')
self.assertEqual(code, '315')
myname, channel, extra = stuff
self.assertEqual(myname, 'userone')
self.assertEqual(channel, '#groupname')
self.assertEqual(extra, 'End of /WHO list.')
def testList(self):
user = self._loggedInUser(u"someuser")
user.transport.clear()
somegroup = self.successResultOf(self.realm.createGroup(u"somegroup"))
somegroup.size = lambda: succeed(17)
somegroup.meta['topic'] = 'this is the topic woo'
# Test one group
user.write('LIST #somegroup\r\n')
r = self._response(user)
self.assertEqual(len(r), 2)
resp, end = r
self.assertEqual(resp[0], 'realmname')
self.assertEqual(resp[1], '322')
self.assertEqual(resp[2][0], 'someuser')
self.assertEqual(resp[2][1], 'somegroup')
self.assertEqual(resp[2][2], '17')
self.assertEqual(resp[2][3], 'this is the topic woo')
self.assertEqual(end[0], 'realmname')
self.assertEqual(end[1], '323')
self.assertEqual(end[2][0], 'someuser')
self.assertEqual(end[2][1], 'End of /LIST')
user.transport.clear()
# Test all groups
user.write('LIST\r\n')
r = self._response(user)
self.assertEqual(len(r), 2)
fg1, end = r
self.assertEqual(fg1[1], '322')
self.assertEqual(fg1[2][1], 'somegroup')
self.assertEqual(fg1[2][2], '17')
self.assertEqual(fg1[2][3], 'this is the topic woo')
self.assertEqual(end[1], '323')
def testWhois(self):
user = self._loggedInUser(u'someguy')
otherguy = service.User("otherguy")
otherguy.itergroups = lambda: iter([
service.Group('groupA'),
service.Group('groupB')])
otherguy.signOn = 10
otherguy.lastMessage = time.time() - 15
self.successResultOf(self.realm.addUser(otherguy))
user.transport.clear()
user.write('WHOIS otherguy\r\n')
r = self._response(user)
self.assertEqual(len(r), 5)
wuser, wserver, idle, channels, end = r
self.assertEqual(wuser[0], 'realmname')
self.assertEqual(wuser[1], '311')
self.assertEqual(wuser[2][0], 'someguy')
self.assertEqual(wuser[2][1], 'otherguy')
self.assertEqual(wuser[2][2], 'otherguy')
self.assertEqual(wuser[2][3], 'realmname')
self.assertEqual(wuser[2][4], '*')
self.assertEqual(wuser[2][5], 'otherguy')
self.assertEqual(wserver[0], 'realmname')
self.assertEqual(wserver[1], '312')
self.assertEqual(wserver[2][0], 'someguy')
self.assertEqual(wserver[2][1], 'otherguy')
self.assertEqual(wserver[2][2], 'realmname')
self.assertEqual(wserver[2][3], 'Hi mom!')
self.assertEqual(idle[0], 'realmname')
self.assertEqual(idle[1], '317')
self.assertEqual(idle[2][0], 'someguy')
self.assertEqual(idle[2][1], 'otherguy')
self.assertEqual(idle[2][2], '15')
self.assertEqual(idle[2][3], '10')
self.assertEqual(idle[2][4], "seconds idle, signon time")
self.assertEqual(channels[0], 'realmname')
self.assertEqual(channels[1], '319')
self.assertEqual(channels[2][0], 'someguy')
self.assertEqual(channels[2][1], 'otherguy')
self.assertEqual(channels[2][2], '#groupA #groupB')
self.assertEqual(end[0], 'realmname')
self.assertEqual(end[1], '318')
self.assertEqual(end[2][0], 'someguy')
self.assertEqual(end[2][1], 'otherguy')
self.assertEqual(end[2][2], 'End of WHOIS list.')
class TestMind(service.PBMind):
def __init__(self, *a, **kw):
self.joins = []
self.parts = []
self.messages = []
self.meta = []
def remote_userJoined(self, user, group):
self.joins.append((user, group))
def remote_userLeft(self, user, group, reason):
self.parts.append((user, group, reason))
def remote_receive(self, sender, recipient, message):
self.messages.append((sender, recipient, message))
def remote_groupMetaUpdate(self, group, meta):
self.meta.append((group, meta))
pb.setUnjellyableForClass(TestMind, service.PBMindReference)
class PBProtocolTests(unittest.TestCase):
def setUp(self):
self.realm = service.InMemoryWordsRealm("realmname")
self.checker = checkers.InMemoryUsernamePasswordDatabaseDontUse()
self.portal = portal.Portal(
self.realm, [self.checker])
self.serverFactory = pb.PBServerFactory(self.portal)
self.serverFactory.protocol = self._protocolFactory
self.serverFactory.unsafeTracebacks = True
self.clientFactory = pb.PBClientFactory()
self.clientFactory.unsafeTracebacks = True
self.serverPort = reactor.listenTCP(0, self.serverFactory)
self.clientConn = reactor.connectTCP(
'127.0.0.1',
self.serverPort.getHost().port,
self.clientFactory)
def _protocolFactory(self, *args, **kw):
self._serverProtocol = pb.Broker(0)
return self._serverProtocol
def tearDown(self):
d3 = Deferred()
self._serverProtocol.notifyOnDisconnect(lambda: d3.callback(None))
return DeferredList([
maybeDeferred(self.serverPort.stopListening),
maybeDeferred(self.clientConn.disconnect), d3])
def _loggedInAvatar(self, name, password, mind):
nameBytes = name
if isinstance(name, unicode):
nameBytes = name.encode("ascii")
creds = credentials.UsernamePassword(nameBytes, password)
self.checker.addUser(nameBytes, password)
d = self.realm.createUser(name)
d.addCallback(lambda ign: self.clientFactory.login(creds, mind))
return d
@defer.inlineCallbacks
def testGroups(self):
mindone = TestMind()
one = yield self._loggedInAvatar(u"one", b"p1", mindone)
mindtwo = TestMind()
two = yield self._loggedInAvatar(u"two", b"p2", mindtwo)
mindThree = TestMind()
three = yield self._loggedInAvatar(b"three", b"p3", mindThree)
yield self.realm.createGroup(u"foobar")
yield self.realm.createGroup(b"barfoo")
groupone = yield one.join(u"foobar")
grouptwo = yield two.join(b"barfoo")
yield two.join(u"foobar")
yield two.join(b"barfoo")
yield three.join(u"foobar")
yield groupone.send({b"text": b"hello, monkeys"})
yield groupone.leave()
yield grouptwo.leave()

View file

@ -0,0 +1,78 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.cred import credentials, error
from twisted.words import tap
from twisted.trial import unittest
class WordsTapTests(unittest.TestCase):
"""
Ensures that the twisted.words.tap API works.
"""
PASSWD_TEXT = b"admin:admin\njoe:foo\n"
admin = credentials.UsernamePassword(b'admin', b'admin')
joeWrong = credentials.UsernamePassword(b'joe', b'bar')
def setUp(self):
"""
Create a file with two users.
"""
self.filename = self.mktemp()
self.file = open(self.filename, 'wb')
self.file.write(self.PASSWD_TEXT)
self.file.flush()
def tearDown(self):
"""
Close the dummy user database.
"""
self.file.close()
def test_hostname(self):
"""
Tests that the --hostname parameter gets passed to Options.
"""
opt = tap.Options()
opt.parseOptions(['--hostname', 'myhost'])
self.assertEqual(opt['hostname'], 'myhost')
def test_passwd(self):
"""
Tests the --passwd command for backwards-compatibility.
"""
opt = tap.Options()
opt.parseOptions(['--passwd', self.file.name])
self._loginTest(opt)
def test_auth(self):
"""
Tests that the --auth command generates a checker.
"""
opt = tap.Options()
opt.parseOptions(['--auth', 'file:'+self.file.name])
self._loginTest(opt)
def _loginTest(self, opt):
"""
This method executes both positive and negative authentication
tests against whatever credentials checker has been stored in
the Options class.
@param opt: An instance of L{tap.Options}.
"""
self.assertEqual(len(opt['credCheckers']), 1)
checker = opt['credCheckers'][0]
self.assertFailure(checker.requestAvatarId(self.joeWrong),
error.UnauthorizedLogin)
def _gotAvatar(username):
self.assertEqual(username, self.admin.username)
return checker.requestAvatarId(self.admin).addCallback(_gotAvatar)

View file

@ -0,0 +1,348 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for twisted.words.xish.utility
"""
from __future__ import absolute_import, division
from collections import OrderedDict
from twisted.trial import unittest
from twisted.words.xish import utility
from twisted.words.xish.domish import Element
from twisted.words.xish.utility import EventDispatcher
class CallbackTracker:
"""
Test helper for tracking callbacks.
Increases a counter on each call to L{call} and stores the object
passed in the call.
"""
def __init__(self):
self.called = 0
self.obj = None
def call(self, obj):
self.called = self.called + 1
self.obj = obj
class OrderedCallbackTracker:
"""
Test helper for tracking callbacks and their order.
"""
def __init__(self):
self.callList = []
def call1(self, object):
self.callList.append(self.call1)
def call2(self, object):
self.callList.append(self.call2)
def call3(self, object):
self.callList.append(self.call3)
class EventDispatcherTests(unittest.TestCase):
"""
Tests for L{EventDispatcher}.
"""
def testStuff(self):
d = EventDispatcher()
cb1 = CallbackTracker()
cb2 = CallbackTracker()
cb3 = CallbackTracker()
d.addObserver("/message/body", cb1.call)
d.addObserver("/message", cb1.call)
d.addObserver("/presence", cb2.call)
d.addObserver("//event/testevent", cb3.call)
msg = Element(("ns", "message"))
msg.addElement("body")
pres = Element(("ns", "presence"))
pres.addElement("presence")
d.dispatch(msg)
self.assertEqual(cb1.called, 2)
self.assertEqual(cb1.obj, msg)
self.assertEqual(cb2.called, 0)
d.dispatch(pres)
self.assertEqual(cb1.called, 2)
self.assertEqual(cb2.called, 1)
self.assertEqual(cb2.obj, pres)
self.assertEqual(cb3.called, 0)
d.dispatch(d, "//event/testevent")
self.assertEqual(cb3.called, 1)
self.assertEqual(cb3.obj, d)
d.removeObserver("/presence", cb2.call)
d.dispatch(pres)
self.assertEqual(cb2.called, 1)
def test_addObserverTwice(self):
"""
Test adding two observers for the same query.
When the event is dispatched both of the observers need to be called.
"""
d = EventDispatcher()
cb1 = CallbackTracker()
cb2 = CallbackTracker()
d.addObserver("//event/testevent", cb1.call)
d.addObserver("//event/testevent", cb2.call)
d.dispatch(d, "//event/testevent")
self.assertEqual(cb1.called, 1)
self.assertEqual(cb1.obj, d)
self.assertEqual(cb2.called, 1)
self.assertEqual(cb2.obj, d)
def test_addObserverInDispatch(self):
"""
Test for registration of an observer during dispatch.
"""
d = EventDispatcher()
msg = Element(("ns", "message"))
cb = CallbackTracker()
def onMessage(_):
d.addObserver("/message", cb.call)
d.addOnetimeObserver("/message", onMessage)
d.dispatch(msg)
self.assertEqual(cb.called, 0)
d.dispatch(msg)
self.assertEqual(cb.called, 1)
d.dispatch(msg)
self.assertEqual(cb.called, 2)
def test_addOnetimeObserverInDispatch(self):
"""
Test for registration of a onetime observer during dispatch.
"""
d = EventDispatcher()
msg = Element(("ns", "message"))
cb = CallbackTracker()
def onMessage(msg):
d.addOnetimeObserver("/message", cb.call)
d.addOnetimeObserver("/message", onMessage)
d.dispatch(msg)
self.assertEqual(cb.called, 0)
d.dispatch(msg)
self.assertEqual(cb.called, 1)
d.dispatch(msg)
self.assertEqual(cb.called, 1)
def testOnetimeDispatch(self):
d = EventDispatcher()
msg = Element(("ns", "message"))
cb = CallbackTracker()
d.addOnetimeObserver("/message", cb.call)
d.dispatch(msg)
self.assertEqual(cb.called, 1)
d.dispatch(msg)
self.assertEqual(cb.called, 1)
def testDispatcherResult(self):
d = EventDispatcher()
msg = Element(("ns", "message"))
pres = Element(("ns", "presence"))
cb = CallbackTracker()
d.addObserver("/presence", cb.call)
result = d.dispatch(msg)
self.assertEqual(False, result)
result = d.dispatch(pres)
self.assertEqual(True, result)
def testOrderedXPathDispatch(self):
d = EventDispatcher()
cb = OrderedCallbackTracker()
d.addObserver("/message/body", cb.call2)
d.addObserver("/message", cb.call3, -1)
d.addObserver("/message/body", cb.call1, 1)
msg = Element(("ns", "message"))
msg.addElement("body")
d.dispatch(msg)
self.assertEqual(cb.callList, [cb.call1, cb.call2, cb.call3],
"Calls out of order: %s" %
repr([c.__name__ for c in cb.callList]))
# Observers are put into CallbackLists that are then put into dictionaries
# keyed by the event trigger. Upon removal of the last observer for a
# particular event trigger, the (now empty) CallbackList and corresponding
# event trigger should be removed from those dictionaries to prevent
# slowdown and memory leakage.
def test_cleanUpRemoveEventObserver(self):
"""
Test observer clean-up after removeObserver for named events.
"""
d = EventDispatcher()
cb = CallbackTracker()
d.addObserver('//event/test', cb.call)
d.dispatch(None, '//event/test')
self.assertEqual(1, cb.called)
d.removeObserver('//event/test', cb.call)
self.assertEqual(0, len(d._eventObservers.pop(0)))
def test_cleanUpRemoveXPathObserver(self):
"""
Test observer clean-up after removeObserver for XPath events.
"""
d = EventDispatcher()
cb = CallbackTracker()
msg = Element((None, "message"))
d.addObserver('/message', cb.call)
d.dispatch(msg)
self.assertEqual(1, cb.called)
d.removeObserver('/message', cb.call)
self.assertEqual(0, len(d._xpathObservers.pop(0)))
def test_cleanUpOnetimeEventObserver(self):
"""
Test observer clean-up after onetime named events.
"""
d = EventDispatcher()
cb = CallbackTracker()
d.addOnetimeObserver('//event/test', cb.call)
d.dispatch(None, '//event/test')
self.assertEqual(1, cb.called)
self.assertEqual(0, len(d._eventObservers.pop(0)))
def test_cleanUpOnetimeXPathObserver(self):
"""
Test observer clean-up after onetime XPath events.
"""
d = EventDispatcher()
cb = CallbackTracker()
msg = Element((None, "message"))
d.addOnetimeObserver('/message', cb.call)
d.dispatch(msg)
self.assertEqual(1, cb.called)
self.assertEqual(0, len(d._xpathObservers.pop(0)))
def test_observerRaisingException(self):
"""
Test that exceptions in observers do not bubble up to dispatch.
The exceptions raised in observers should be logged and other
observers should be called as if nothing happened.
"""
class OrderedCallbackList(utility.CallbackList):
def __init__(self):
self.callbacks = OrderedDict()
class TestError(Exception):
pass
def raiseError(_):
raise TestError()
d = EventDispatcher()
cb = CallbackTracker()
originalCallbackList = utility.CallbackList
try:
utility.CallbackList = OrderedCallbackList
d.addObserver('//event/test', raiseError)
d.addObserver('//event/test', cb.call)
try:
d.dispatch(None, '//event/test')
except TestError:
self.fail("TestError raised. Should have been logged instead.")
self.assertEqual(1, len(self.flushLoggedErrors(TestError)))
self.assertEqual(1, cb.called)
finally:
utility.CallbackList = originalCallbackList
class XmlPipeTests(unittest.TestCase):
"""
Tests for L{twisted.words.xish.utility.XmlPipe}.
"""
def setUp(self):
self.pipe = utility.XmlPipe()
def test_sendFromSource(self):
"""
Send an element from the source and observe it from the sink.
"""
def cb(obj):
called.append(obj)
called = []
self.pipe.sink.addObserver('/test[@xmlns="testns"]', cb)
element = Element(('testns', 'test'))
self.pipe.source.send(element)
self.assertEqual([element], called)
def test_sendFromSink(self):
"""
Send an element from the sink and observe it from the source.
"""
def cb(obj):
called.append(obj)
called = []
self.pipe.source.addObserver('/test[@xmlns="testns"]', cb)
element = Element(('testns', 'test'))
self.pipe.sink.send(element)
self.assertEqual([element], called)

View file

@ -0,0 +1,226 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.xish.xmlstream}.
"""
from __future__ import absolute_import, division
from twisted.internet import protocol
from twisted.python import failure
from twisted.trial import unittest
from twisted.words.xish import domish, utility, xmlstream
class XmlStreamTests(unittest.TestCase):
def setUp(self):
self.connectionLostMsg = "no reason"
self.outlist = []
self.xmlstream = xmlstream.XmlStream()
self.xmlstream.transport = self
self.xmlstream.transport.write = self.outlist.append
def loseConnection(self):
"""
Stub loseConnection because we are a transport.
"""
self.xmlstream.connectionLost(failure.Failure(
Exception(self.connectionLostMsg)))
def test_send(self):
"""
Calling L{xmlstream.XmlStream.send} results in the data being written
to the transport.
"""
self.xmlstream.connectionMade()
self.xmlstream.send(b"<root>")
self.assertEqual(self.outlist[0], b"<root>")
def test_receiveRoot(self):
"""
Receiving the starttag of the root element results in stream start.
"""
streamStarted = []
def streamStartEvent(rootelem):
streamStarted.append(None)
self.xmlstream.addObserver(xmlstream.STREAM_START_EVENT,
streamStartEvent)
self.xmlstream.connectionMade()
self.xmlstream.dataReceived("<root>")
self.assertEqual(1, len(streamStarted))
def test_receiveBadXML(self):
"""
Receiving malformed XML results in an L{STREAM_ERROR_EVENT}.
"""
streamError = []
streamEnd = []
def streamErrorEvent(reason):
streamError.append(reason)
def streamEndEvent(_):
streamEnd.append(None)
self.xmlstream.addObserver(xmlstream.STREAM_ERROR_EVENT,
streamErrorEvent)
self.xmlstream.addObserver(xmlstream.STREAM_END_EVENT,
streamEndEvent)
self.xmlstream.connectionMade()
self.xmlstream.dataReceived("<root>")
self.assertEqual(0, len(streamError))
self.assertEqual(0, len(streamEnd))
self.xmlstream.dataReceived("<child><unclosed></child>")
self.assertEqual(1, len(streamError))
self.assertTrue(streamError[0].check(domish.ParserError))
self.assertEqual(1, len(streamEnd))
def test_streamEnd(self):
"""
Ending the stream fires a L{STREAM_END_EVENT}.
"""
streamEnd = []
def streamEndEvent(reason):
streamEnd.append(reason)
self.xmlstream.addObserver(xmlstream.STREAM_END_EVENT,
streamEndEvent)
self.xmlstream.connectionMade()
self.loseConnection()
self.assertEqual(1, len(streamEnd))
self.assertIsInstance(streamEnd[0], failure.Failure)
self.assertEqual(streamEnd[0].getErrorMessage(),
self.connectionLostMsg)
class DummyProtocol(protocol.Protocol, utility.EventDispatcher):
"""
I am a protocol with an event dispatcher without further processing.
This protocol is only used for testing XmlStreamFactoryMixin to make
sure the bootstrap observers are added to the protocol instance.
"""
def __init__(self, *args, **kwargs):
self.args = args
self.kwargs = kwargs
self.observers = []
utility.EventDispatcher.__init__(self)
class BootstrapMixinTests(unittest.TestCase):
"""
Tests for L{xmlstream.BootstrapMixin}.
@ivar factory: Instance of the factory or mixin under test.
"""
def setUp(self):
self.factory = xmlstream.BootstrapMixin()
def test_installBootstraps(self):
"""
Dispatching an event fires registered bootstrap observers.
"""
called = []
def cb(data):
called.append(data)
dispatcher = DummyProtocol()
self.factory.addBootstrap('//event/myevent', cb)
self.factory.installBootstraps(dispatcher)
dispatcher.dispatch(None, '//event/myevent')
self.assertEqual(1, len(called))
def test_addAndRemoveBootstrap(self):
"""
Test addition and removal of a bootstrap event handler.
"""
called = []
def cb(data):
called.append(data)
self.factory.addBootstrap('//event/myevent', cb)
self.factory.removeBootstrap('//event/myevent', cb)
dispatcher = DummyProtocol()
self.factory.installBootstraps(dispatcher)
dispatcher.dispatch(None, '//event/myevent')
self.assertFalse(called)
class GenericXmlStreamFactoryTestsMixin(BootstrapMixinTests):
"""
Generic tests for L{XmlStream} factories.
"""
def setUp(self):
self.factory = xmlstream.XmlStreamFactory()
def test_buildProtocolInstallsBootstraps(self):
"""
The protocol factory installs bootstrap event handlers on the protocol.
"""
called = []
def cb(data):
called.append(data)
self.factory.addBootstrap('//event/myevent', cb)
xs = self.factory.buildProtocol(None)
xs.dispatch(None, '//event/myevent')
self.assertEqual(1, len(called))
def test_buildProtocolStoresFactory(self):
"""
The protocol factory is saved in the protocol.
"""
xs = self.factory.buildProtocol(None)
self.assertIdentical(self.factory, xs.factory)
class XmlStreamFactoryMixinTests(GenericXmlStreamFactoryTestsMixin):
"""
Tests for L{xmlstream.XmlStreamFactoryMixin}.
"""
def setUp(self):
self.factory = xmlstream.XmlStreamFactoryMixin(None, test=None)
self.factory.protocol = DummyProtocol
def test_buildProtocolFactoryArguments(self):
"""
Arguments passed to the factory are passed to protocol on
instantiation.
"""
xs = self.factory.buildProtocol(None)
self.assertEqual((None,), xs.args)
self.assertEqual({'test': None}, xs.kwargs)

View file

@ -0,0 +1,84 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.words.xmpproutertap}.
"""
from twisted.application import internet
from twisted.trial import unittest
from twisted.words import xmpproutertap as tap
from twisted.words.protocols.jabber import component
class XMPPRouterTapTests(unittest.TestCase):
def test_port(self):
"""
The port option is recognised as a parameter.
"""
opt = tap.Options()
opt.parseOptions(['--port', '7001'])
self.assertEqual(opt['port'], '7001')
def test_portDefault(self):
"""
The port option has '5347' as default value
"""
opt = tap.Options()
opt.parseOptions([])
self.assertEqual(opt['port'], 'tcp:5347:interface=127.0.0.1')
def test_secret(self):
"""
The secret option is recognised as a parameter.
"""
opt = tap.Options()
opt.parseOptions(['--secret', 'hushhush'])
self.assertEqual(opt['secret'], 'hushhush')
def test_secretDefault(self):
"""
The secret option has 'secret' as default value
"""
opt = tap.Options()
opt.parseOptions([])
self.assertEqual(opt['secret'], 'secret')
def test_verbose(self):
"""
The verbose option is recognised as a flag.
"""
opt = tap.Options()
opt.parseOptions(['--verbose'])
self.assertTrue(opt['verbose'])
def test_makeService(self):
"""
The service gets set up with a router and factory.
"""
opt = tap.Options()
opt.parseOptions([])
s = tap.makeService(opt)
self.assertIsInstance(s, internet.StreamServerEndpointService)
self.assertEqual('127.0.0.1', s.endpoint._interface)
self.assertEqual(5347, s.endpoint._port)
factory = s.factory
self.assertIsInstance(factory, component.XMPPComponentServerFactory)
self.assertIsInstance(factory.router, component.Router)
self.assertEqual('secret', factory.secret)
self.assertFalse(factory.logTraffic)
def test_makeServiceVerbose(self):
"""
The verbose flag enables traffic logging.
"""
opt = tap.Options()
opt.parseOptions(['--verbose'])
s = tap.makeService(opt)
self.assertTrue(s.factory.logTraffic)

View file

@ -0,0 +1,298 @@
# -*- coding: utf-8 -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import absolute_import, division
from twisted.trial import unittest
from twisted.words.xish import xpath
from twisted.words.xish.domish import Element
from twisted.words.xish.xpath import XPathQuery
from twisted.words.xish.xpathparser import SyntaxError
class XPathTests(unittest.TestCase):
def setUp(self):
# Build element:
# <foo xmlns='testns' attrib1='value1' attrib3="user@host/resource">
# somecontent
# <bar>
# <foo>
# <gar>DEF</gar>
# </foo>
# </bar>
# somemorecontent
# <bar attrib2="value2">
# <bar>
# <foo/>
# <gar>ABC</gar>
# </bar>
# <bar/>
# <bar attrib4='value4' attrib5='value5'>
# <foo/>
# <gar>JKL</gar>
# </bar>
# <bar attrib4='value4' attrib5='value4'>
# <foo/>
# <gar>MNO</gar>
# </bar>
# <bar attrib4='value4' attrib5='value6' attrib6='á'>
# <quux>☃</quux>
# </bar>
# </foo>
self.e = Element(("testns", "foo"))
self.e["attrib1"] = "value1"
self.e["attrib3"] = "user@host/resource"
self.e.addContent(u"somecontent")
self.bar1 = self.e.addElement("bar")
self.subfoo = self.bar1.addElement("foo")
self.gar1 = self.subfoo.addElement("gar")
self.gar1.addContent(u"DEF")
self.e.addContent(u"somemorecontent")
self.bar2 = self.e.addElement("bar")
self.bar2["attrib2"] = "value2"
self.bar3 = self.bar2.addElement("bar")
self.subfoo2 = self.bar3.addElement("foo")
self.gar2 = self.bar3.addElement("gar")
self.gar2.addContent(u"ABC")
self.bar4 = self.e.addElement("bar")
self.bar5 = self.e.addElement("bar")
self.bar5["attrib4"] = "value4"
self.bar5["attrib5"] = "value5"
self.subfoo3 = self.bar5.addElement("foo")
self.gar3 = self.bar5.addElement("gar")
self.gar3.addContent(u"JKL")
self.bar6 = self.e.addElement("bar")
self.bar6["attrib4"] = "value4"
self.bar6["attrib5"] = "value4"
self.subfoo4 = self.bar6.addElement("foo")
self.gar4 = self.bar6.addElement("gar")
self.gar4.addContent(u"MNO")
self.bar7 = self.e.addElement("bar")
self.bar7["attrib4"] = "value4"
self.bar7["attrib5"] = "value6"
self.bar7["attrib6"] = u"á"
self.quux = self.bar7.addElement("quux")
self.quux.addContent(u"\N{SNOWMAN}")
def test_staticMethods(self):
"""
Test basic operation of the static methods.
"""
self.assertEqual(xpath.matches("/foo/bar", self.e),
True)
self.assertEqual(xpath.queryForNodes("/foo/bar", self.e),
[self.bar1, self.bar2, self.bar4,
self.bar5, self.bar6, self.bar7])
self.assertEqual(xpath.queryForString("/foo", self.e),
"somecontent")
self.assertEqual(xpath.queryForStringList("/foo", self.e),
["somecontent", "somemorecontent"])
def test_locationFooBar(self):
"""
Test matching foo with child bar.
"""
xp = XPathQuery("/foo/bar")
self.assertEqual(xp.matches(self.e), 1)
def test_locationFooBarFoo(self):
"""
Test finding foos at the second level.
"""
xp = XPathQuery("/foo/bar/foo")
self.assertEqual(xp.matches(self.e), 1)
self.assertEqual(xp.queryForNodes(self.e), [self.subfoo,
self.subfoo3,
self.subfoo4])
def test_locationNoBar3(self):
"""
Test not finding bar3.
"""
xp = XPathQuery("/foo/bar3")
self.assertEqual(xp.matches(self.e), 0)
def test_locationAllChilds(self):
"""
Test finding childs of foo.
"""
xp = XPathQuery("/foo/*")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar1, self.bar2,
self.bar4, self.bar5,
self.bar6, self.bar7])
def test_attribute(self):
"""
Test matching foo with attribute.
"""
xp = XPathQuery("/foo[@attrib1]")
self.assertEqual(xp.matches(self.e), True)
def test_attributeWithValueAny(self):
"""
Test find nodes with attribute having value.
"""
xp = XPathQuery("/foo/*[@attrib2='value2']")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar2])
def test_locationWithValueUnicode(self):
"""
Nodes' attributes can be matched with non-ASCII values.
"""
xp = XPathQuery(u"/foo/*[@attrib6='á']")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar7])
def test_namespaceFound(self):
"""
Test matching node with namespace.
"""
xp = XPathQuery("/foo[@xmlns='testns']/bar")
self.assertEqual(xp.matches(self.e), 1)
def test_namespaceNotFound(self):
"""
Test not matching node with wrong namespace.
"""
xp = XPathQuery("/foo[@xmlns='badns']/bar2")
self.assertEqual(xp.matches(self.e), 0)
def test_attributeWithValue(self):
"""
Test matching node with attribute having value.
"""
xp = XPathQuery("/foo[@attrib1='value1']")
self.assertEqual(xp.matches(self.e), 1)
def test_queryForString(self):
"""
queryforString on absolute paths returns their first CDATA.
"""
xp = XPathQuery("/foo")
self.assertEqual(xp.queryForString(self.e), "somecontent")
def test_queryForStringList(self):
"""
queryforStringList on absolute paths returns all their CDATA.
"""
xp = XPathQuery("/foo")
self.assertEqual(xp.queryForStringList(self.e),
["somecontent", "somemorecontent"])
def test_queryForStringListAnyLocation(self):
"""
queryforStringList on relative paths returns all their CDATA.
"""
xp = XPathQuery("//foo")
self.assertEqual(xp.queryForStringList(self.e),
["somecontent", "somemorecontent"])
def test_queryForNodes(self):
"""
Test finding nodes.
"""
xp = XPathQuery("/foo/bar")
self.assertEqual(xp.queryForNodes(self.e), [self.bar1, self.bar2,
self.bar4, self.bar5,
self.bar6, self.bar7])
def test_textCondition(self):
"""
Test matching a node with given text.
"""
xp = XPathQuery("/foo[text() = 'somecontent']")
self.assertEqual(xp.matches(self.e), True)
def test_textConditionUnicode(self):
"""
A node can be matched by text with non-ascii code points.
"""
xp = XPathQuery(u"//*[text()='\N{SNOWMAN}']")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.quux])
def test_textNotOperator(self):
"""
Test for not operator.
"""
xp = XPathQuery("/foo[not(@nosuchattrib)]")
self.assertEqual(xp.matches(self.e), True)
def test_anyLocationAndText(self):
"""
Test finding any nodes named gar and getting their text contents.
"""
xp = XPathQuery("//gar")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.gar1, self.gar2,
self.gar3, self.gar4])
self.assertEqual(xp.queryForStringList(self.e), ["DEF", "ABC",
"JKL", "MNO"])
def test_anyLocation(self):
"""
Test finding any nodes named bar.
"""
xp = XPathQuery("//bar")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar1, self.bar2,
self.bar3, self.bar4,
self.bar5, self.bar6,
self.bar7])
def test_anyLocationQueryForString(self):
"""
L{XPathQuery.queryForString} should raise a L{NotImplementedError}
for any location.
"""
xp = XPathQuery("//bar")
self.assertRaises(NotImplementedError, xp.queryForString, None)
def test_andOperator(self):
"""
Test boolean and operator in condition.
"""
xp = XPathQuery("//bar[@attrib4='value4' and @attrib5='value5']")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar5])
def test_orOperator(self):
"""
Test boolean or operator in condition.
"""
xp = XPathQuery("//bar[@attrib5='value4' or @attrib5='value5']")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar5, self.bar6])
def test_booleanOperatorsParens(self):
"""
Test multiple boolean operators in condition with parens.
"""
xp = XPathQuery("""//bar[@attrib4='value4' and
(@attrib5='value4' or @attrib5='value6')]""")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar6, self.bar7])
def test_booleanOperatorsNoParens(self):
"""
Test multiple boolean operators in condition without parens.
"""
xp = XPathQuery("""//bar[@attrib5='value4' or
@attrib5='value5' or
@attrib5='value6']""")
self.assertEqual(xp.matches(self.e), True)
self.assertEqual(xp.queryForNodes(self.e), [self.bar5, self.bar6, self.bar7])
def test_badXPathNoClosingBracket(self):
"""
A missing closing bracket raises a SyntaxError.
This test excercises the most common failure mode.
"""
exc = self.assertRaises(SyntaxError, XPathQuery, """//bar[@attrib1""")
self.assertTrue(exc.msg.startswith("Trying to find one of"),
("SyntaxError message '%s' doesn't start with "
"'Trying to find one of'") % exc.msg)

View file

@ -0,0 +1,10 @@
# -*- test-case-name: twisted.words.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted X-ish: XML-ish DOM and XPath-ish engine
"""

View file

@ -0,0 +1,899 @@
# -*- test-case-name: twisted.words.test.test_domish -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
DOM-like XML processing support.
This module provides support for parsing XML into DOM-like object structures
and serializing such structures to an XML string representation, optimized
for use in streaming XML applications.
"""
from __future__ import absolute_import, division
from zope.interface import implementer, Interface, Attribute
from twisted.python.compat import (_PY3, StringType, _coercedUnicode,
iteritems, itervalues, unicode)
def _splitPrefix(name):
""" Internal method for splitting a prefixed Element name into its
respective parts """
ntok = name.split(":", 1)
if len(ntok) == 2:
return ntok
else:
return (None, ntok[0])
# Global map of prefixes that always get injected
# into the serializers prefix map (note, that doesn't
# mean they're always _USED_)
G_PREFIXES = { "http://www.w3.org/XML/1998/namespace":"xml" }
class _ListSerializer:
""" Internal class which serializes an Element tree into a buffer """
def __init__(self, prefixes=None, prefixesInScope=None):
self.writelist = []
self.prefixes = {}
if prefixes:
self.prefixes.update(prefixes)
self.prefixes.update(G_PREFIXES)
self.prefixStack = [G_PREFIXES.values()] + (prefixesInScope or [])
self.prefixCounter = 0
def getValue(self):
return u"".join(self.writelist)
def getPrefix(self, uri):
if uri not in self.prefixes:
self.prefixes[uri] = "xn%d" % (self.prefixCounter)
self.prefixCounter = self.prefixCounter + 1
return self.prefixes[uri]
def prefixInScope(self, prefix):
stack = self.prefixStack
for i in range(-1, (len(self.prefixStack)+1) * -1, -1):
if prefix in stack[i]:
return True
return False
def serialize(self, elem, closeElement=1, defaultUri=''):
# Optimization shortcuts
write = self.writelist.append
# Shortcut, check to see if elem is actually a chunk o' serialized XML
if isinstance(elem, SerializedXML):
write(elem)
return
# Shortcut, check to see if elem is actually a string (aka Cdata)
if isinstance(elem, StringType):
write(escapeToXml(elem))
return
# Further optimizations
name = elem.name
uri = elem.uri
defaultUri, currentDefaultUri = elem.defaultUri, defaultUri
for p, u in iteritems(elem.localPrefixes):
self.prefixes[u] = p
self.prefixStack.append(list(elem.localPrefixes.keys()))
# Inherit the default namespace
if defaultUri is None:
defaultUri = currentDefaultUri
if uri is None:
uri = defaultUri
prefix = None
if uri != defaultUri or uri in self.prefixes:
prefix = self.getPrefix(uri)
inScope = self.prefixInScope(prefix)
# Create the starttag
if not prefix:
write("<%s" % (name))
else:
write("<%s:%s" % (prefix, name))
if not inScope:
write(" xmlns:%s='%s'" % (prefix, uri))
self.prefixStack[-1].append(prefix)
inScope = True
if defaultUri != currentDefaultUri and \
(uri != defaultUri or not prefix or not inScope):
write(" xmlns='%s'" % (defaultUri))
for p, u in iteritems(elem.localPrefixes):
write(" xmlns:%s='%s'" % (p, u))
# Serialize attributes
for k,v in elem.attributes.items():
# If the attribute name is a tuple, it's a qualified attribute
if isinstance(k, tuple):
attr_uri, attr_name = k
attr_prefix = self.getPrefix(attr_uri)
if not self.prefixInScope(attr_prefix):
write(" xmlns:%s='%s'" % (attr_prefix, attr_uri))
self.prefixStack[-1].append(attr_prefix)
write(" %s:%s='%s'" % (attr_prefix, attr_name,
escapeToXml(v, 1)))
else:
write((" %s='%s'" % ( k, escapeToXml(v, 1))))
# Shortcut out if this is only going to return
# the element (i.e. no children)
if closeElement == 0:
write(">")
return
# Serialize children
if len(elem.children) > 0:
write(">")
for c in elem.children:
self.serialize(c, defaultUri=defaultUri)
# Add closing tag
if not prefix:
write("</%s>" % (name))
else:
write("</%s:%s>" % (prefix, name))
else:
write("/>")
self.prefixStack.pop()
SerializerClass = _ListSerializer
def escapeToXml(text, isattrib = 0):
""" Escape text to proper XML form, per section 2.3 in the XML specification.
@type text: C{str}
@param text: Text to escape
@type isattrib: C{bool}
@param isattrib: Triggers escaping of characters necessary for use as
attribute values
"""
text = text.replace("&", "&amp;")
text = text.replace("<", "&lt;")
text = text.replace(">", "&gt;")
if isattrib == 1:
text = text.replace("'", "&apos;")
text = text.replace("\"", "&quot;")
return text
def unescapeFromXml(text):
text = text.replace("&lt;", "<")
text = text.replace("&gt;", ">")
text = text.replace("&apos;", "'")
text = text.replace("&quot;", "\"")
text = text.replace("&amp;", "&")
return text
def generateOnlyInterface(list, int):
""" Filters items in a list by class
"""
for n in list:
if int.providedBy(n):
yield n
def generateElementsQNamed(list, name, uri):
""" Filters Element items in a list with matching name and URI. """
for n in list:
if IElement.providedBy(n) and n.name == name and n.uri == uri:
yield n
def generateElementsNamed(list, name):
""" Filters Element items in a list with matching name, regardless of URI.
"""
for n in list:
if IElement.providedBy(n) and n.name == name:
yield n
class SerializedXML(unicode):
""" Marker class for pre-serialized XML in the DOM. """
pass
class Namespace:
""" Convenience object for tracking namespace declarations. """
def __init__(self, uri):
self._uri = uri
def __getattr__(self, n):
return (self._uri, n)
def __getitem__(self, n):
return (self._uri, n)
class IElement(Interface):
"""
Interface to XML element nodes.
See L{Element} for a detailed example of its general use.
Warning: this Interface is not yet complete!
"""
uri = Attribute(""" Element's namespace URI """)
name = Attribute(""" Element's local name """)
defaultUri = Attribute(""" Default namespace URI of child elements """)
attributes = Attribute(""" Dictionary of element attributes """)
children = Attribute(""" List of child nodes """)
parent = Attribute(""" Reference to element's parent element """)
localPrefixes = Attribute(""" Dictionary of local prefixes """)
def toXml(prefixes=None, closeElement=1, defaultUri='',
prefixesInScope=None):
""" Serializes object to a (partial) XML document
@param prefixes: dictionary that maps namespace URIs to suggested
prefix names.
@type prefixes: L{dict}
@param closeElement: flag that determines whether to include the
closing tag of the element in the serialized string. A value of
C{0} only generates the element's start tag. A value of C{1} yields
a complete serialization.
@type closeElement: L{int}
@param defaultUri: Initial default namespace URI. This is most useful
for partial rendering, where the logical parent element (of which
the starttag was already serialized) declares a default namespace
that should be inherited.
@type defaultUri: L{unicode}
@param prefixesInScope: list of prefixes that are assumed to be
declared by ancestors.
@type prefixesInScope: C{list}
@return: (partial) serialized XML
@rtype: C{unicode}
"""
def addElement(name, defaultUri=None, content=None):
"""
Create an element and add as child.
The new element is added to this element as a child, and will have
this element as its parent.
@param name: element name. This can be either a L{unicode} object that
contains the local name, or a tuple of (uri, local_name) for a
fully qualified name. In the former case, the namespace URI is
inherited from this element.
@type name: L{unicode} or L{tuple} of (L{unicode}, L{unicode})
@param defaultUri: default namespace URI for child elements. If
L{None}, this is inherited from this element.
@type defaultUri: L{unicode}
@param content: text contained by the new element.
@type content: L{unicode}
@return: the created element
@rtype: object providing L{IElement}
"""
def addChild(node):
"""
Adds a node as child of this element.
The C{node} will be added to the list of childs of this element, and
will have this element set as its parent when C{node} provides
L{IElement}. If C{node} is a L{unicode} and the current last child is
character data (L{unicode}), the text from C{node} is appended to the
existing last child.
@param node: the child node.
@type node: L{unicode} or object implementing L{IElement}
"""
def addContent(text):
"""
Adds character data to this element.
If the current last child of this element is a string, the text will
be appended to that string. Otherwise, the text will be added as a new
child.
@param text: The character data to be added to this element.
@type text: L{unicode}
"""
@implementer(IElement)
class Element(object):
""" Represents an XML element node.
An Element contains a series of attributes (name/value pairs), content
(character data), and other child Element objects. When building a document
with markup (such as HTML or XML), use this object as the starting point.
Element objects fully support XML Namespaces. The fully qualified name of
the XML Element it represents is stored in the C{uri} and C{name}
attributes, where C{uri} holds the namespace URI. There is also a default
namespace, for child elements. This is stored in the C{defaultUri}
attribute. Note that C{''} means the empty namespace.
Serialization of Elements through C{toXml()} will use these attributes
for generating proper serialized XML. When both C{uri} and C{defaultUri}
are not None in the Element and all of its descendents, serialization
proceeds as expected:
>>> from twisted.words.xish import domish
>>> root = domish.Element(('myns', 'root'))
>>> root.addElement('child', content='test')
<twisted.words.xish.domish.Element object at 0x83002ac>
>>> root.toXml()
u"<root xmlns='myns'><child>test</child></root>"
For partial serialization, needed for streaming XML, a special value for
namespace URIs can be used: L{None}.
Using L{None} as the value for C{uri} means: this element is in whatever
namespace inherited by the closest logical ancestor when the complete XML
document has been serialized. The serialized start tag will have a
non-prefixed name, and no xmlns declaration will be generated.
Similarly, L{None} for C{defaultUri} means: the default namespace for my
child elements is inherited from the logical ancestors of this element,
when the complete XML document has been serialized.
To illustrate, an example from a Jabber stream. Assume the start tag of the
root element of the stream has already been serialized, along with several
complete child elements, and sent off, looking like this::
<stream:stream xmlns:stream='http://etherx.jabber.org/streams'
xmlns='jabber:client' to='example.com'>
...
Now suppose we want to send a complete element represented by an
object C{message} created like:
>>> message = domish.Element((None, 'message'))
>>> message['to'] = 'user@example.com'
>>> message.addElement('body', content='Hi!')
<twisted.words.xish.domish.Element object at 0x8276e8c>
>>> message.toXml()
u"<message to='user@example.com'><body>Hi!</body></message>"
As, you can see, this XML snippet has no xmlns declaration. When sent
off, it inherits the C{jabber:client} namespace from the root element.
Note that this renders the same as using C{''} instead of L{None}:
>>> presence = domish.Element(('', 'presence'))
>>> presence.toXml()
u"<presence/>"
However, if this object has a parent defined, the difference becomes
clear:
>>> child = message.addElement(('http://example.com/', 'envelope'))
>>> child.addChild(presence)
<twisted.words.xish.domish.Element object at 0x8276fac>
>>> message.toXml()
u"<message to='user@example.com'><body>Hi!</body><envelope xmlns='http://example.com/'><presence xmlns=''/></envelope></message>"
As, you can see, the <presence/> element is now in the empty namespace, not
in the default namespace of the parent or the streams'.
@type uri: C{unicode} or None
@ivar uri: URI of this Element's name
@type name: C{unicode}
@ivar name: Name of this Element
@type defaultUri: C{unicode} or None
@ivar defaultUri: URI this Element exists within
@type children: C{list}
@ivar children: List of child Elements and content
@type parent: L{Element}
@ivar parent: Reference to the parent Element, if any.
@type attributes: L{dict}
@ivar attributes: Dictionary of attributes associated with this Element.
@type localPrefixes: L{dict}
@ivar localPrefixes: Dictionary of namespace declarations on this
element. The key is the prefix to bind the
namespace uri to.
"""
_idCounter = 0
def __init__(self, qname, defaultUri=None, attribs=None,
localPrefixes=None):
"""
@param qname: Tuple of (uri, name)
@param defaultUri: The default URI of the element; defaults to the URI
specified in C{qname}
@param attribs: Dictionary of attributes
@param localPrefixes: Dictionary of namespace declarations on this
element. The key is the prefix to bind the
namespace uri to.
"""
self.localPrefixes = localPrefixes or {}
self.uri, self.name = qname
if defaultUri is None and \
self.uri not in itervalues(self.localPrefixes):
self.defaultUri = self.uri
else:
self.defaultUri = defaultUri
self.attributes = attribs or {}
self.children = []
self.parent = None
def __getattr__(self, key):
# Check child list for first Element with a name matching the key
for n in self.children:
if IElement.providedBy(n) and n.name == key:
return n
# Tweak the behaviour so that it's more friendly about not
# finding elements -- we need to document this somewhere :)
if key.startswith('_'):
raise AttributeError(key)
else:
return None
def __getitem__(self, key):
return self.attributes[self._dqa(key)]
def __delitem__(self, key):
del self.attributes[self._dqa(key)];
def __setitem__(self, key, value):
self.attributes[self._dqa(key)] = value
def __unicode__(self):
"""
Retrieve the first CData (content) node
"""
for n in self.children:
if isinstance(n, StringType):
return n
return u""
def __bytes__(self):
"""
Retrieve the first character data node as UTF-8 bytes.
"""
return unicode(self).encode('utf-8')
if _PY3:
__str__ = __unicode__
else:
__str__ = __bytes__
def _dqa(self, attr):
""" Dequalify an attribute key as needed """
if isinstance(attr, tuple) and not attr[0]:
return attr[1]
else:
return attr
def getAttribute(self, attribname, default = None):
""" Retrieve the value of attribname, if it exists """
return self.attributes.get(attribname, default)
def hasAttribute(self, attrib):
""" Determine if the specified attribute exists """
return self._dqa(attrib) in self.attributes
def compareAttribute(self, attrib, value):
""" Safely compare the value of an attribute against a provided value.
L{None}-safe.
"""
return self.attributes.get(self._dqa(attrib), None) == value
def swapAttributeValues(self, left, right):
""" Swap the values of two attribute. """
d = self.attributes
l = d[left]
d[left] = d[right]
d[right] = l
def addChild(self, node):
""" Add a child to this Element. """
if IElement.providedBy(node):
node.parent = self
self.children.append(node)
return node
def addContent(self, text):
""" Add some text data to this Element. """
text = _coercedUnicode(text)
c = self.children
if len(c) > 0 and isinstance(c[-1], unicode):
c[-1] = c[-1] + text
else:
c.append(text)
return c[-1]
def addElement(self, name, defaultUri = None, content = None):
if isinstance(name, tuple):
if defaultUri is None:
defaultUri = name[0]
child = Element(name, defaultUri)
else:
if defaultUri is None:
defaultUri = self.defaultUri
child = Element((defaultUri, name), defaultUri)
self.addChild(child)
if content:
child.addContent(content)
return child
def addRawXml(self, rawxmlstring):
""" Add a pre-serialized chunk o' XML as a child of this Element. """
self.children.append(SerializedXML(rawxmlstring))
def addUniqueId(self):
""" Add a unique (across a given Python session) id attribute to this
Element.
"""
self.attributes["id"] = "H_%d" % Element._idCounter
Element._idCounter = Element._idCounter + 1
def elements(self, uri=None, name=None):
"""
Iterate across all children of this Element that are Elements.
Returns a generator over the child elements. If both the C{uri} and
C{name} parameters are set, the returned generator will only yield
on elements matching the qualified name.
@param uri: Optional element URI.
@type uri: C{unicode}
@param name: Optional element name.
@type name: C{unicode}
@return: Iterator that yields objects implementing L{IElement}.
"""
if name is None:
return generateOnlyInterface(self.children, IElement)
else:
return generateElementsQNamed(self.children, name, uri)
def toXml(self, prefixes=None, closeElement=1, defaultUri='',
prefixesInScope=None):
""" Serialize this Element and all children to a string. """
s = SerializerClass(prefixes=prefixes, prefixesInScope=prefixesInScope)
s.serialize(self, closeElement=closeElement, defaultUri=defaultUri)
return s.getValue()
def firstChildElement(self):
for c in self.children:
if IElement.providedBy(c):
return c
return None
class ParserError(Exception):
""" Exception thrown when a parsing error occurs """
pass
def elementStream():
""" Preferred method to construct an ElementStream
Uses Expat-based stream if available, and falls back to Sux if necessary.
"""
try:
es = ExpatElementStream()
return es
except ImportError:
if SuxElementStream is None:
raise Exception("No parsers available :(")
es = SuxElementStream()
return es
try:
from twisted.web import sux
except:
SuxElementStream = None
else:
class SuxElementStream(sux.XMLParser):
def __init__(self):
self.connectionMade()
self.DocumentStartEvent = None
self.ElementEvent = None
self.DocumentEndEvent = None
self.currElem = None
self.rootElem = None
self.documentStarted = False
self.defaultNsStack = []
self.prefixStack = []
def parse(self, buffer):
try:
self.dataReceived(buffer)
except sux.ParseError as e:
raise ParserError(str(e))
def findUri(self, prefix):
# Walk prefix stack backwards, looking for the uri
# matching the specified prefix
stack = self.prefixStack
for i in range(-1, (len(self.prefixStack)+1) * -1, -1):
if prefix in stack[i]:
return stack[i][prefix]
return None
def gotTagStart(self, name, attributes):
defaultUri = None
localPrefixes = {}
attribs = {}
uri = None
# Pass 1 - Identify namespace decls
for k, v in list(attributes.items()):
if k.startswith("xmlns"):
x, p = _splitPrefix(k)
if (x is None): # I.e. default declaration
defaultUri = v
else:
localPrefixes[p] = v
del attributes[k]
# Push namespace decls onto prefix stack
self.prefixStack.append(localPrefixes)
# Determine default namespace for this element; if there
# is one
if defaultUri is None:
if len(self.defaultNsStack) > 0:
defaultUri = self.defaultNsStack[-1]
else:
defaultUri = ''
# Fix up name
prefix, name = _splitPrefix(name)
if prefix is None: # This element is in the default namespace
uri = defaultUri
else:
# Find the URI for the prefix
uri = self.findUri(prefix)
# Pass 2 - Fix up and escape attributes
for k, v in attributes.items():
p, n = _splitPrefix(k)
if p is None:
attribs[n] = v
else:
attribs[(self.findUri(p)), n] = unescapeFromXml(v)
# Construct the actual Element object
e = Element((uri, name), defaultUri, attribs, localPrefixes)
# Save current default namespace
self.defaultNsStack.append(defaultUri)
# Document already started
if self.documentStarted:
# Starting a new packet
if self.currElem is None:
self.currElem = e
# Adding to existing element
else:
self.currElem = self.currElem.addChild(e)
# New document
else:
self.rootElem = e
self.documentStarted = True
self.DocumentStartEvent(e)
def gotText(self, data):
if self.currElem != None:
if isinstance(data, bytes):
data = data.decode('ascii')
self.currElem.addContent(data)
def gotCData(self, data):
if self.currElem != None:
if isinstance(data, bytes):
data = data.decode('ascii')
self.currElem.addContent(data)
def gotComment(self, data):
# Ignore comments for the moment
pass
entities = { "amp" : "&",
"lt" : "<",
"gt" : ">",
"apos": "'",
"quot": "\"" }
def gotEntityReference(self, entityRef):
# If this is an entity we know about, add it as content
# to the current element
if entityRef in SuxElementStream.entities:
data = SuxElementStream.entities[entityRef]
if isinstance(data, bytes):
data = data.decode('ascii')
self.currElem.addContent(data)
def gotTagEnd(self, name):
# Ensure the document hasn't already ended
if self.rootElem is None:
# XXX: Write more legible explanation
raise ParserError("Element closed after end of document.")
# Fix up name
prefix, name = _splitPrefix(name)
if prefix is None:
uri = self.defaultNsStack[-1]
else:
uri = self.findUri(prefix)
# End of document
if self.currElem is None:
# Ensure element name and uri matches
if self.rootElem.name != name or self.rootElem.uri != uri:
raise ParserError("Mismatched root elements")
self.DocumentEndEvent()
self.rootElem = None
# Other elements
else:
# Ensure the tag being closed matches the name of the current
# element
if self.currElem.name != name or self.currElem.uri != uri:
# XXX: Write more legible explanation
raise ParserError("Malformed element close")
# Pop prefix and default NS stack
self.prefixStack.pop()
self.defaultNsStack.pop()
# Check for parent null parent of current elem;
# that's the top of the stack
if self.currElem.parent is None:
self.currElem.parent = self.rootElem
self.ElementEvent(self.currElem)
self.currElem = None
# Anything else is just some element wrapping up
else:
self.currElem = self.currElem.parent
class ExpatElementStream:
def __init__(self):
import pyexpat
self.DocumentStartEvent = None
self.ElementEvent = None
self.DocumentEndEvent = None
self.error = pyexpat.error
self.parser = pyexpat.ParserCreate("UTF-8", " ")
self.parser.StartElementHandler = self._onStartElement
self.parser.EndElementHandler = self._onEndElement
self.parser.CharacterDataHandler = self._onCdata
self.parser.StartNamespaceDeclHandler = self._onStartNamespace
self.parser.EndNamespaceDeclHandler = self._onEndNamespace
self.currElem = None
self.defaultNsStack = ['']
self.documentStarted = 0
self.localPrefixes = {}
def parse(self, buffer):
try:
self.parser.Parse(buffer)
except self.error as e:
raise ParserError(str(e))
def _onStartElement(self, name, attrs):
# Generate a qname tuple from the provided name. See
# http://docs.python.org/library/pyexpat.html#xml.parsers.expat.ParserCreate
# for an explanation of the formatting of name.
qname = name.rsplit(" ", 1)
if len(qname) == 1:
qname = ('', name)
# Process attributes
newAttrs = {}
toDelete = []
for k, v in attrs.items():
if " " in k:
aqname = k.rsplit(" ", 1)
newAttrs[(aqname[0], aqname[1])] = v
toDelete.append(k)
attrs.update(newAttrs)
for k in toDelete:
del attrs[k]
# Construct the new element
e = Element(qname, self.defaultNsStack[-1], attrs, self.localPrefixes)
self.localPrefixes = {}
# Document already started
if self.documentStarted == 1:
if self.currElem != None:
self.currElem.children.append(e)
e.parent = self.currElem
self.currElem = e
# New document
else:
self.documentStarted = 1
self.DocumentStartEvent(e)
def _onEndElement(self, _):
# Check for null current elem; end of doc
if self.currElem is None:
self.DocumentEndEvent()
# Check for parent that is None; that's
# the top of the stack
elif self.currElem.parent is None:
self.ElementEvent(self.currElem)
self.currElem = None
# Anything else is just some element in the current
# packet wrapping up
else:
self.currElem = self.currElem.parent
def _onCdata(self, data):
if self.currElem != None:
self.currElem.addContent(data)
def _onStartNamespace(self, prefix, uri):
# If this is the default namespace, put
# it on the stack
if prefix is None:
self.defaultNsStack.append(uri)
else:
self.localPrefixes[prefix] = uri
def _onEndNamespace(self, prefix):
# Remove last element on the stack
if prefix is None:
self.defaultNsStack.pop()
## class FileParser(ElementStream):
## def __init__(self):
## ElementStream.__init__(self)
## self.DocumentStartEvent = self.docStart
## self.ElementEvent = self.elem
## self.DocumentEndEvent = self.docEnd
## self.done = 0
## def docStart(self, elem):
## self.document = elem
## def elem(self, elem):
## self.document.addChild(elem)
## def docEnd(self):
## self.done = 1
## def parse(self, filename):
## with open(filename) as f:
## for l in f.readlines():
## self.parser.Parse(l)
## assert self.done == 1
## return self.document
## def parseFile(filename):
## return FileParser().parse(filename)

View file

@ -0,0 +1,375 @@
# -*- test-case-name: twisted.words.test.test_xishutil -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Event Dispatching and Callback utilities.
"""
from __future__ import absolute_import, division
from twisted.python import log
from twisted.python.compat import iteritems
from twisted.words.xish import xpath
class _MethodWrapper(object):
"""
Internal class for tracking method calls.
"""
def __init__(self, method, *args, **kwargs):
self.method = method
self.args = args
self.kwargs = kwargs
def __call__(self, *args, **kwargs):
nargs = self.args + args
nkwargs = self.kwargs.copy()
nkwargs.update(kwargs)
self.method(*nargs, **nkwargs)
class CallbackList:
"""
Container for callbacks.
Event queries are linked to lists of callables. When a matching event
occurs, these callables are called in sequence. One-time callbacks
are removed from the list after the first time the event was triggered.
Arguments to callbacks are split spread across two sets. The first set,
callback specific, is passed to C{addCallback} and is used for all
subsequent event triggers. The second set is passed to C{callback} and is
event specific. Positional arguments in the second set come after the
positional arguments of the first set. Keyword arguments in the second set
override those in the first set.
@ivar callbacks: The registered callbacks as mapping from the callable to a
tuple of a wrapper for that callable that keeps the
callback specific arguments and a boolean that signifies
if it is to be called only once.
@type callbacks: C{dict}
"""
def __init__(self):
self.callbacks = {}
def addCallback(self, onetime, method, *args, **kwargs):
"""
Add callback.
The arguments passed are used as callback specific arguments.
@param onetime: If C{True}, this callback is called at most once.
@type onetime: C{bool}
@param method: The callback callable to be added.
@param args: Positional arguments to the callable.
@type args: C{list}
@param kwargs: Keyword arguments to the callable.
@type kwargs: C{dict}
"""
if not method in self.callbacks:
self.callbacks[method] = (_MethodWrapper(method, *args, **kwargs),
onetime)
def removeCallback(self, method):
"""
Remove callback.
@param method: The callable to be removed.
"""
if method in self.callbacks:
del self.callbacks[method]
def callback(self, *args, **kwargs):
"""
Call all registered callbacks.
The passed arguments are event specific and augment and override
the callback specific arguments as described above.
@note: Exceptions raised by callbacks are trapped and logged. They will
not propagate up to make sure other callbacks will still be
called, and the event dispatching always succeeds.
@param args: Positional arguments to the callable.
@type args: C{list}
@param kwargs: Keyword arguments to the callable.
@type kwargs: C{dict}
"""
for key, (methodwrapper, onetime) in list(self.callbacks.items()):
try:
methodwrapper(*args, **kwargs)
except:
log.err()
if onetime:
del self.callbacks[key]
def isEmpty(self):
"""
Return if list of registered callbacks is empty.
@rtype: C{bool}
"""
return len(self.callbacks) == 0
class EventDispatcher:
"""
Event dispatching service.
The C{EventDispatcher} allows observers to be registered for certain events
that are dispatched. There are two types of events: XPath events and Named
events.
Every dispatch is triggered by calling L{dispatch} with a data object and,
for named events, the name of the event.
When an XPath type event is dispatched, the associated object is assumed to
be an L{Element<twisted.words.xish.domish.Element>} instance, which is
matched against all registered XPath queries. For every match, the
respective observer will be called with the data object.
A named event will simply call each registered observer for that particular
event name, with the data object. Unlike XPath type events, the data object
is not restricted to L{Element<twisted.words.xish.domish.Element>}, but can
be anything.
When registering observers, the event that is to be observed is specified
using an L{xpath.XPathQuery} instance or a string. In the latter case, the
string can also contain the string representation of an XPath expression.
To distinguish these from named events, each named event should start with
a special prefix that is stored in C{self.prefix}. It defaults to
C{//event/}.
Observers registered using L{addObserver} are persistent: after the
observer has been triggered by a dispatch, it remains registered for a
possible next dispatch. If instead L{addOnetimeObserver} was used to
observe an event, the observer is removed from the list of observers after
the first observed event.
Observers can also be prioritized, by providing an optional C{priority}
parameter to the L{addObserver} and L{addOnetimeObserver} methods. Higher
priority observers are then called before lower priority observers.
Finally, observers can be unregistered by using L{removeObserver}.
"""
def __init__(self, eventprefix="//event/"):
self.prefix = eventprefix
self._eventObservers = {}
self._xpathObservers = {}
self._dispatchDepth = 0 # Flag indicating levels of dispatching
# in progress
self._updateQueue = [] # Queued updates for observer ops
def _getEventAndObservers(self, event):
if isinstance(event, xpath.XPathQuery):
# Treat as xpath
observers = self._xpathObservers
else:
if self.prefix == event[:len(self.prefix)]:
# Treat as event
observers = self._eventObservers
else:
# Treat as xpath
event = xpath.internQuery(event)
observers = self._xpathObservers
return event, observers
def addOnetimeObserver(self, event, observerfn, priority=0, *args, **kwargs):
"""
Register a one-time observer for an event.
Like L{addObserver}, but is only triggered at most once. See there
for a description of the parameters.
"""
self._addObserver(True, event, observerfn, priority, *args, **kwargs)
def addObserver(self, event, observerfn, priority=0, *args, **kwargs):
"""
Register an observer for an event.
Each observer will be registered with a certain priority. Higher
priority observers get called before lower priority observers.
@param event: Name or XPath query for the event to be monitored.
@type event: C{str} or L{xpath.XPathQuery}.
@param observerfn: Function to be called when the specified event
has been triggered. This callable takes
one parameter: the data object that triggered
the event. When specified, the C{*args} and
C{**kwargs} parameters to addObserver are being used
as additional parameters to the registered observer
callable.
@param priority: (Optional) priority of this observer in relation to
other observer that match the same event. Defaults to
C{0}.
@type priority: C{int}
"""
self._addObserver(False, event, observerfn, priority, *args, **kwargs)
def _addObserver(self, onetime, event, observerfn, priority, *args, **kwargs):
# If this is happening in the middle of the dispatch, queue
# it up for processing after the dispatch completes
if self._dispatchDepth > 0:
self._updateQueue.append(lambda:self._addObserver(onetime, event, observerfn, priority, *args, **kwargs))
return
event, observers = self._getEventAndObservers(event)
if priority not in observers:
cbl = CallbackList()
observers[priority] = {event: cbl}
else:
priorityObservers = observers[priority]
if event not in priorityObservers:
cbl = CallbackList()
observers[priority][event] = cbl
else:
cbl = priorityObservers[event]
cbl.addCallback(onetime, observerfn, *args, **kwargs)
def removeObserver(self, event, observerfn):
"""
Remove callable as observer for an event.
The observer callable is removed for all priority levels for the
specified event.
@param event: Event for which the observer callable was registered.
@type event: C{str} or L{xpath.XPathQuery}
@param observerfn: Observer callable to be unregistered.
"""
# If this is happening in the middle of the dispatch, queue
# it up for processing after the dispatch completes
if self._dispatchDepth > 0:
self._updateQueue.append(lambda:self.removeObserver(event, observerfn))
return
event, observers = self._getEventAndObservers(event)
emptyLists = []
for priority, priorityObservers in iteritems(observers):
for query, callbacklist in iteritems(priorityObservers):
if event == query:
callbacklist.removeCallback(observerfn)
if callbacklist.isEmpty():
emptyLists.append((priority, query))
for priority, query in emptyLists:
del observers[priority][query]
def dispatch(self, obj, event=None):
"""
Dispatch an event.
When C{event} is L{None}, an XPath type event is triggered, and
C{obj} is assumed to be an instance of
L{Element<twisted.words.xish.domish.Element>}. Otherwise, C{event}
holds the name of the named event being triggered. In the latter case,
C{obj} can be anything.
@param obj: The object to be dispatched.
@param event: Optional event name.
@type event: C{str}
"""
foundTarget = False
self._dispatchDepth += 1
if event != None:
# Named event
observers = self._eventObservers
match = lambda query, obj: query == event
else:
# XPath event
observers = self._xpathObservers
match = lambda query, obj: query.matches(obj)
priorities = list(observers.keys())
priorities.sort()
priorities.reverse()
emptyLists = []
for priority in priorities:
for query, callbacklist in iteritems(observers[priority]):
if match(query, obj):
callbacklist.callback(obj)
foundTarget = True
if callbacklist.isEmpty():
emptyLists.append((priority, query))
for priority, query in emptyLists:
del observers[priority][query]
self._dispatchDepth -= 1
# If this is a dispatch within a dispatch, don't
# do anything with the updateQueue -- it needs to
# wait until we've back all the way out of the stack
if self._dispatchDepth == 0:
# Deal with pending update operations
for f in self._updateQueue:
f()
self._updateQueue = []
return foundTarget
class XmlPipe(object):
"""
XML stream pipe.
Connects two objects that communicate stanzas through an XML stream like
interface. Each of the ends of the pipe (sink and source) can be used to
send XML stanzas to the other side, or add observers to process XML stanzas
that were sent from the other side.
XML pipes are usually used in place of regular XML streams that are
transported over TCP. This is the reason for the use of the names source
and sink for both ends of the pipe. The source side corresponds with the
entity that initiated the TCP connection, whereas the sink corresponds with
the entity that accepts that connection. In this object, though, the source
and sink are treated equally.
Unlike Jabber
L{XmlStream<twisted.words.protocols.jabber.xmlstream.XmlStream>}s, the sink
and source objects are assumed to represent an eternal connected and
initialized XML stream. As such, events corresponding to connection,
disconnection, initialization and stream errors are not dispatched or
processed.
@since: 8.2
@ivar source: Source XML stream.
@ivar sink: Sink XML stream.
"""
def __init__(self):
self.source = EventDispatcher()
self.sink = EventDispatcher()
self.source.send = lambda obj: self.sink.dispatch(obj)
self.sink.send = lambda obj: self.source.dispatch(obj)

View file

@ -0,0 +1,279 @@
# -*- test-case-name: twisted.words.test.test_xmlstream -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
XML Stream processing.
An XML Stream is defined as a connection over which two XML documents are
exchanged during the lifetime of the connection, one for each direction. The
unit of interaction is a direct child element of the root element (stanza).
The most prominent use of XML Streams is Jabber, but this module is generically
usable. See Twisted Words for Jabber specific protocol support.
Maintainer: Ralph Meijer
@var STREAM_CONNECTED_EVENT: This event signals that the connection has been
established.
@type STREAM_CONNECTED_EVENT: L{str}.
@var STREAM_END_EVENT: This event signals that the connection has been closed.
@type STREAM_END_EVENT: L{str}.
@var STREAM_ERROR_EVENT: This event signals that a parse error occurred.
@type STREAM_ERROR_EVENT: L{str}.
@var STREAM_START_EVENT: This event signals that the root element of the XML
Stream has been received.
For XMPP, this would be the C{<stream:stream ...>} opening tag.
@type STREAM_START_EVENT: L{str}.
"""
from __future__ import absolute_import, division
from twisted.python import failure
from twisted.python.compat import intern, unicode
from twisted.internet import protocol
from twisted.words.xish import domish, utility
STREAM_CONNECTED_EVENT = intern("//event/stream/connected")
STREAM_START_EVENT = intern("//event/stream/start")
STREAM_END_EVENT = intern("//event/stream/end")
STREAM_ERROR_EVENT = intern("//event/stream/error")
class XmlStream(protocol.Protocol, utility.EventDispatcher):
""" Generic Streaming XML protocol handler.
This protocol handler will parse incoming data as XML and dispatch events
accordingly. Incoming stanzas can be handled by registering observers using
XPath-like expressions that are matched against each stanza. See
L{utility.EventDispatcher} for details.
"""
def __init__(self):
utility.EventDispatcher.__init__(self)
self.stream = None
self.rawDataOutFn = None
self.rawDataInFn = None
def _initializeStream(self):
""" Sets up XML Parser. """
self.stream = domish.elementStream()
self.stream.DocumentStartEvent = self.onDocumentStart
self.stream.ElementEvent = self.onElement
self.stream.DocumentEndEvent = self.onDocumentEnd
### --------------------------------------------------------------
###
### Protocol events
###
### --------------------------------------------------------------
def connectionMade(self):
""" Called when a connection is made.
Sets up the XML parser and dispatches the L{STREAM_CONNECTED_EVENT}
event indicating the connection has been established.
"""
self._initializeStream()
self.dispatch(self, STREAM_CONNECTED_EVENT)
def dataReceived(self, data):
""" Called whenever data is received.
Passes the data to the XML parser. This can result in calls to the
DOM handlers. If a parse error occurs, the L{STREAM_ERROR_EVENT} event
is called to allow for cleanup actions, followed by dropping the
connection.
"""
try:
if self.rawDataInFn:
self.rawDataInFn(data)
self.stream.parse(data)
except domish.ParserError:
self.dispatch(failure.Failure(), STREAM_ERROR_EVENT)
self.transport.loseConnection()
def connectionLost(self, reason):
""" Called when the connection is shut down.
Dispatches the L{STREAM_END_EVENT}.
"""
self.dispatch(reason, STREAM_END_EVENT)
self.stream = None
### --------------------------------------------------------------
###
### DOM events
###
### --------------------------------------------------------------
def onDocumentStart(self, rootElement):
""" Called whenever the start tag of a root element has been received.
Dispatches the L{STREAM_START_EVENT}.
"""
self.dispatch(self, STREAM_START_EVENT)
def onElement(self, element):
""" Called whenever a direct child element of the root element has
been received.
Dispatches the received element.
"""
self.dispatch(element)
def onDocumentEnd(self):
""" Called whenever the end tag of the root element has been received.
Closes the connection. This causes C{connectionLost} being called.
"""
self.transport.loseConnection()
def setDispatchFn(self, fn):
""" Set another function to handle elements. """
self.stream.ElementEvent = fn
def resetDispatchFn(self):
""" Set the default function (C{onElement}) to handle elements. """
self.stream.ElementEvent = self.onElement
def send(self, obj):
""" Send data over the stream.
Sends the given C{obj} over the connection. C{obj} may be instances of
L{domish.Element}, C{unicode} and C{str}. The first two will be
properly serialized and/or encoded. C{str} objects must be in UTF-8
encoding.
Note: because it is easy to make mistakes in maintaining a properly
encoded C{str} object, it is advised to use C{unicode} objects
everywhere when dealing with XML Streams.
@param obj: Object to be sent over the stream.
@type obj: L{domish.Element}, L{domish} or C{str}
"""
if domish.IElement.providedBy(obj):
obj = obj.toXml()
if isinstance(obj, unicode):
obj = obj.encode('utf-8')
if self.rawDataOutFn:
self.rawDataOutFn(obj)
self.transport.write(obj)
class BootstrapMixin(object):
"""
XmlStream factory mixin to install bootstrap event observers.
This mixin is for factories providing
L{IProtocolFactory<twisted.internet.interfaces.IProtocolFactory>} to make
sure bootstrap event observers are set up on protocols, before incoming
data is processed. Such protocols typically derive from
L{utility.EventDispatcher}, like L{XmlStream}.
You can set up bootstrap event observers using C{addBootstrap}. The
C{event} and C{fn} parameters correspond with the C{event} and
C{observerfn} arguments to L{utility.EventDispatcher.addObserver}.
@since: 8.2.
@ivar bootstraps: The list of registered bootstrap event observers.
@type bootstrap: C{list}
"""
def __init__(self):
self.bootstraps = []
def installBootstraps(self, dispatcher):
"""
Install registered bootstrap observers.
@param dispatcher: Event dispatcher to add the observers to.
@type dispatcher: L{utility.EventDispatcher}
"""
for event, fn in self.bootstraps:
dispatcher.addObserver(event, fn)
def addBootstrap(self, event, fn):
"""
Add a bootstrap event handler.
@param event: The event to register an observer for.
@type event: C{str} or L{xpath.XPathQuery}
@param fn: The observer callable to be registered.
"""
self.bootstraps.append((event, fn))
def removeBootstrap(self, event, fn):
"""
Remove a bootstrap event handler.
@param event: The event the observer is registered for.
@type event: C{str} or L{xpath.XPathQuery}
@param fn: The registered observer callable.
"""
self.bootstraps.remove((event, fn))
class XmlStreamFactoryMixin(BootstrapMixin):
"""
XmlStream factory mixin that takes care of event handlers.
All positional and keyword arguments passed to create this factory are
passed on as-is to the protocol.
@ivar args: Positional arguments passed to the protocol upon instantiation.
@type args: C{tuple}.
@ivar kwargs: Keyword arguments passed to the protocol upon instantiation.
@type kwargs: C{dict}.
"""
def __init__(self, *args, **kwargs):
BootstrapMixin.__init__(self)
self.args = args
self.kwargs = kwargs
def buildProtocol(self, addr):
"""
Create an instance of XmlStream.
The returned instance will have bootstrap event observers registered
and will proceed to handle input on an incoming connection.
"""
xs = self.protocol(*self.args, **self.kwargs)
xs.factory = self
self.installBootstraps(xs)
return xs
class XmlStreamFactory(XmlStreamFactoryMixin,
protocol.ReconnectingClientFactory):
"""
Factory for XmlStream protocol objects as a reconnection client.
"""
protocol = XmlStream
def buildProtocol(self, addr):
"""
Create a protocol instance.
Overrides L{XmlStreamFactoryMixin.buildProtocol} to work with
a L{ReconnectingClientFactory}. As this is called upon having an
connection established, we are resetting the delay for reconnection
attempts when the connection is lost again.
"""
self.resetDelay()
return XmlStreamFactoryMixin.buildProtocol(self, addr)

View file

@ -0,0 +1,337 @@
# -*- test-case-name: twisted.words.test.test_xpath -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
XPath query support.
This module provides L{XPathQuery} to match
L{domish.Element<twisted.words.xish.domish.Element>} instances against
XPath-like expressions.
"""
from __future__ import absolute_import, division
from io import StringIO
from twisted.python.compat import StringType, unicode
class LiteralValue(unicode):
def value(self, elem):
return self
class IndexValue:
def __init__(self, index):
self.index = int(index) - 1
def value(self, elem):
return elem.children[self.index]
class AttribValue:
def __init__(self, attribname):
self.attribname = attribname
if self.attribname == "xmlns":
self.value = self.value_ns
def value_ns(self, elem):
return elem.uri
def value(self, elem):
if self.attribname in elem.attributes:
return elem.attributes[self.attribname]
else:
return None
class CompareValue:
def __init__(self, lhs, op, rhs):
self.lhs = lhs
self.rhs = rhs
if op == "=":
self.value = self._compareEqual
else:
self.value = self._compareNotEqual
def _compareEqual(self, elem):
return self.lhs.value(elem) == self.rhs.value(elem)
def _compareNotEqual(self, elem):
return self.lhs.value(elem) != self.rhs.value(elem)
class BooleanValue:
"""
Provide boolean XPath expression operators.
@ivar lhs: Left hand side expression of the operator.
@ivar op: The operator. One of C{'and'}, C{'or'}.
@ivar rhs: Right hand side expression of the operator.
@ivar value: Reference to the method that will calculate the value of
this expression given an element.
"""
def __init__(self, lhs, op, rhs):
self.lhs = lhs
self.rhs = rhs
if op == "and":
self.value = self._booleanAnd
else:
self.value = self._booleanOr
def _booleanAnd(self, elem):
"""
Calculate boolean and of the given expressions given an element.
@param elem: The element to calculate the value of the expression from.
"""
return self.lhs.value(elem) and self.rhs.value(elem)
def _booleanOr(self, elem):
"""
Calculate boolean or of the given expressions given an element.
@param elem: The element to calculate the value of the expression from.
"""
return self.lhs.value(elem) or self.rhs.value(elem)
def Function(fname):
"""
Internal method which selects the function object
"""
klassname = "_%s_Function" % fname
c = globals()[klassname]()
return c
class _not_Function:
def __init__(self):
self.baseValue = None
def setParams(self, baseValue):
self.baseValue = baseValue
def value(self, elem):
return not self.baseValue.value(elem)
class _text_Function:
def setParams(self):
pass
def value(self, elem):
return unicode(elem)
class _Location:
def __init__(self):
self.predicates = []
self.elementName = None
self.childLocation = None
def matchesPredicates(self, elem):
if self.elementName != None and self.elementName != elem.name:
return 0
for p in self.predicates:
if not p.value(elem):
return 0
return 1
def matches(self, elem):
if not self.matchesPredicates(elem):
return 0
if self.childLocation != None:
for c in elem.elements():
if self.childLocation.matches(c):
return 1
else:
return 1
return 0
def queryForString(self, elem, resultbuf):
if not self.matchesPredicates(elem):
return
if self.childLocation != None:
for c in elem.elements():
self.childLocation.queryForString(c, resultbuf)
else:
resultbuf.write(unicode(elem))
def queryForNodes(self, elem, resultlist):
if not self.matchesPredicates(elem):
return
if self.childLocation != None:
for c in elem.elements():
self.childLocation.queryForNodes(c, resultlist)
else:
resultlist.append(elem)
def queryForStringList(self, elem, resultlist):
if not self.matchesPredicates(elem):
return
if self.childLocation != None:
for c in elem.elements():
self.childLocation.queryForStringList(c, resultlist)
else:
for c in elem.children:
if isinstance(c, StringType):
resultlist.append(c)
class _AnyLocation:
def __init__(self):
self.predicates = []
self.elementName = None
self.childLocation = None
def matchesPredicates(self, elem):
for p in self.predicates:
if not p.value(elem):
return 0
return 1
def listParents(self, elem, parentlist):
if elem.parent != None:
self.listParents(elem.parent, parentlist)
parentlist.append(elem.name)
def isRootMatch(self, elem):
if (self.elementName == None or self.elementName == elem.name) and \
self.matchesPredicates(elem):
if self.childLocation != None:
for c in elem.elements():
if self.childLocation.matches(c):
return True
else:
return True
return False
def findFirstRootMatch(self, elem):
if (self.elementName == None or self.elementName == elem.name) and \
self.matchesPredicates(elem):
# Thus far, the name matches and the predicates match,
# now check into the children and find the first one
# that matches the rest of the structure
# the rest of the structure
if self.childLocation != None:
for c in elem.elements():
if self.childLocation.matches(c):
return c
return None
else:
# No children locations; this is a match!
return elem
else:
# Ok, predicates or name didn't match, so we need to start
# down each child and treat it as the root and try
# again
for c in elem.elements():
if self.matches(c):
return c
# No children matched...
return None
def matches(self, elem):
if self.isRootMatch(elem):
return True
else:
# Ok, initial element isn't an exact match, walk
# down each child and treat it as the root and try
# again
for c in elem.elements():
if self.matches(c):
return True
# No children matched...
return False
def queryForString(self, elem, resultbuf):
raise NotImplementedError(
"queryForString is not implemented for any location")
def queryForNodes(self, elem, resultlist):
# First check to see if _this_ element is a root
if self.isRootMatch(elem):
resultlist.append(elem)
# Now check each child
for c in elem.elements():
self.queryForNodes(c, resultlist)
def queryForStringList(self, elem, resultlist):
if self.isRootMatch(elem):
for c in elem.children:
if isinstance(c, StringType):
resultlist.append(c)
for c in elem.elements():
self.queryForStringList(c, resultlist)
class XPathQuery:
def __init__(self, queryStr):
self.queryStr = queryStr
# Prevent a circular import issue, as xpathparser imports this module.
from twisted.words.xish.xpathparser import (XPathParser,
XPathParserScanner)
parser = XPathParser(XPathParserScanner(queryStr))
self.baseLocation = getattr(parser, 'XPATH')()
def __hash__(self):
return self.queryStr.__hash__()
def matches(self, elem):
return self.baseLocation.matches(elem)
def queryForString(self, elem):
result = StringIO()
self.baseLocation.queryForString(elem, result)
return result.getvalue()
def queryForNodes(self, elem):
result = []
self.baseLocation.queryForNodes(elem, result)
if len(result) == 0:
return None
else:
return result
def queryForStringList(self, elem):
result = []
self.baseLocation.queryForStringList(elem, result)
if len(result) == 0:
return None
else:
return result
__internedQueries = {}
def internQuery(queryString):
if queryString not in __internedQueries:
__internedQueries[queryString] = XPathQuery(queryString)
return __internedQueries[queryString]
def matches(xpathstr, elem):
return internQuery(xpathstr).matches(elem)
def queryForStringList(xpathstr, elem):
return internQuery(xpathstr).queryForStringList(elem)
def queryForString(xpathstr, elem):
return internQuery(xpathstr).queryForString(elem)
def queryForNodes(xpathstr, elem):
return internQuery(xpathstr).queryForNodes(elem)

View file

@ -0,0 +1,524 @@
# -*- test-case-name: twisted.words.test.test_xpath -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# pylint: disable=W9401,W9402
# DO NOT EDIT xpathparser.py!
#
# It is generated from xpathparser.g using Yapps. Make needed changes there.
# This also means that the generated Python may not conform to Twisted's coding
# standards, so it is wrapped in exec to prevent automated checkers from
# complaining.
# HOWTO Generate me:
#
# 1.) Grab a copy of yapps2:
# https://github.com/smurfix/yapps
#
# Note: Do NOT use the package in debian/ubuntu as it has incompatible
# modifications. The original at http://theory.stanford.edu/~amitp/yapps/
# hasn't been touched since 2003 and has not been updated to work with
# Python 3.
#
# 2.) Generate the grammar:
#
# yapps2 xpathparser.g xpathparser.py.proto
#
# 3.) Edit the output to depend on the embedded runtime, and remove extraneous
# imports:
#
# sed -e '/^# Begin/,${/^[^ ].*mport/d}' -e 's/runtime\.//g' \
# -e "s/^\(from __future\)/exec(r'''\n\1/" -e"\$a''')"
# xpathparser.py.proto > xpathparser.py
"""
XPath Parser.
Besides the parser code produced by Yapps, this module also defines the
parse-time exception classes, a scanner class, a base class for parsers
produced by Yapps, and a context class that keeps track of the parse stack.
These have been copied from the Yapps runtime module.
"""
from __future__ import print_function
import sys, re
MIN_WINDOW=4096
# File lookup window
class SyntaxError(Exception):
"""When we run into an unexpected token, this is the exception to use"""
def __init__(self, pos=None, msg="Bad Token", context=None):
Exception.__init__(self)
self.pos = pos
self.msg = msg
self.context = context
def __str__(self):
if not self.pos: return 'SyntaxError'
else: return 'SyntaxError@%s(%s)' % (repr(self.pos), self.msg)
class NoMoreTokens(Exception):
"""Another exception object, for when we run out of tokens"""
pass
class Token(object):
"""Yapps token.
This is a container for a scanned token.
"""
def __init__(self, type,value, pos=None):
"""Initialize a token."""
self.type = type
self.value = value
self.pos = pos
def __repr__(self):
output = '<%s: %s' % (self.type, repr(self.value))
if self.pos:
output += " @ "
if self.pos[0]:
output += "%s:" % self.pos[0]
if self.pos[1]:
output += "%d" % self.pos[1]
if self.pos[2] is not None:
output += ".%d" % self.pos[2]
output += ">"
return output
in_name=0
class Scanner(object):
"""Yapps scanner.
The Yapps scanner can work in context sensitive or context
insensitive modes. The token(i) method is used to retrieve the
i-th token. It takes a restrict set that limits the set of tokens
it is allowed to return. In context sensitive mode, this restrict
set guides the scanner. In context insensitive mode, there is no
restriction (the set is always the full set of tokens).
"""
def __init__(self, patterns, ignore, input="",
file=None,filename=None,stacked=False):
"""Initialize the scanner.
Parameters:
patterns : [(terminal, uncompiled regex), ...] or None
ignore : {terminal:None, ...}
input : string
If patterns is None, we assume that the subclass has
defined self.patterns : [(terminal, compiled regex), ...].
Note that the patterns parameter expects uncompiled regexes,
whereas the self.patterns field expects compiled regexes.
The 'ignore' value is either None or a callable, which is called
with the scanner and the to-be-ignored match object; this can
be used for include file or comment handling.
"""
if not filename:
global in_name
filename="<f.%d>" % in_name
in_name += 1
self.input = input
self.ignore = ignore
self.file = file
self.filename = filename
self.pos = 0
self.del_pos = 0 # skipped
self.line = 1
self.del_line = 0 # skipped
self.col = 0
self.tokens = []
self.stack = None
self.stacked = stacked
self.last_read_token = None
self.last_token = None
self.last_types = None
if patterns is not None:
# Compile the regex strings into regex objects
self.patterns = []
for terminal, regex in patterns:
self.patterns.append( (terminal, re.compile(regex)) )
def stack_input(self, input="", file=None, filename=None):
"""Temporarily parse from a second file."""
# Already reading from somewhere else: Go on top of that, please.
if self.stack:
# autogenerate a recursion-level-identifying filename
if not filename:
filename = 1
else:
try:
filename += 1
except TypeError:
pass
# now pass off to the include file
self.stack.stack_input(input,file,filename)
else:
try:
filename += 0
except TypeError:
pass
else:
filename = "<str_%d>" % filename
# self.stack = object.__new__(self.__class__)
# Scanner.__init__(self.stack,self.patterns,self.ignore,input,file,filename, stacked=True)
# Note that the pattern+ignore are added by the generated
# scanner code
self.stack = self.__class__(input,file,filename, stacked=True)
def get_pos(self):
"""Return a file/line/char tuple."""
if self.stack: return self.stack.get_pos()
return (self.filename, self.line+self.del_line, self.col)
# def __repr__(self):
# """Print the last few tokens that have been scanned in"""
# output = ''
# for t in self.tokens:
# output += '%s\n' % (repr(t),)
# return output
def print_line_with_pointer(self, pos, length=0, out=sys.stderr):
"""Print the line of 'text' that includes position 'p',
along with a second line with a single caret (^) at position p"""
file,line,p = pos
if file != self.filename:
if self.stack: return self.stack.print_line_with_pointer(pos,length=length,out=out)
print >>out, "(%s: not in input buffer)" % file
return
text = self.input
p += length-1 # starts at pos 1
origline=line
line -= self.del_line
spos=0
if line > 0:
while 1:
line = line - 1
try:
cr = text.index("\n",spos)
except ValueError:
if line:
text = ""
break
if line == 0:
text = text[spos:cr]
break
spos = cr+1
else:
print >>out, "(%s:%d not in input buffer)" % (file,origline)
return
# Now try printing part of the line
text = text[max(p-80, 0):p+80]
p = p - max(p-80, 0)
# Strip to the left
i = text[:p].rfind('\n')
j = text[:p].rfind('\r')
if i < 0 or (0 <= j < i): i = j
if 0 <= i < p:
p = p - i - 1
text = text[i+1:]
# Strip to the right
i = text.find('\n', p)
j = text.find('\r', p)
if i < 0 or (0 <= j < i): i = j
if i >= 0:
text = text[:i]
# Now shorten the text
while len(text) > 70 and p > 60:
# Cut off 10 chars
text = "..." + text[10:]
p = p - 7
# Now print the string, along with an indicator
print >>out, '> ',text
print >>out, '> ',' '*p + '^'
def grab_input(self):
"""Get more input if possible."""
if not self.file: return
if len(self.input) - self.pos >= MIN_WINDOW: return
data = self.file.read(MIN_WINDOW)
if data is None or data == "":
self.file = None
# Drop bytes from the start, if necessary.
if self.pos > 2*MIN_WINDOW:
self.del_pos += MIN_WINDOW
self.del_line += self.input[:MIN_WINDOW].count("\n")
self.pos -= MIN_WINDOW
self.input = self.input[MIN_WINDOW:] + data
else:
self.input = self.input + data
def getchar(self):
"""Return the next character."""
self.grab_input()
c = self.input[self.pos]
self.pos += 1
return c
def token(self, restrict, context=None):
"""Scan for another token."""
while 1:
if self.stack:
try:
return self.stack.token(restrict, context)
except StopIteration:
self.stack = None
# Keep looking for a token, ignoring any in self.ignore
self.grab_input()
# special handling for end-of-file
if self.stacked and self.pos==len(self.input):
raise StopIteration
# Search the patterns for the longest match, with earlier
# tokens in the list having preference
best_match = -1
best_pat = '(error)'
best_m = None
for p, regexp in self.patterns:
# First check to see if we're ignoring this token
if restrict and p not in restrict and p not in self.ignore:
continue
m = regexp.match(self.input, self.pos)
if m and m.end()-m.start() > best_match:
# We got a match that's better than the previous one
best_pat = p
best_match = m.end()-m.start()
best_m = m
# If we didn't find anything, raise an error
if best_pat == '(error)' and best_match < 0:
msg = 'Bad Token'
if restrict:
msg = 'Trying to find one of '+', '.join(restrict)
raise SyntaxError(self.get_pos(), msg, context=context)
ignore = best_pat in self.ignore
value = self.input[self.pos:self.pos+best_match]
if not ignore:
tok=Token(type=best_pat, value=value, pos=self.get_pos())
self.pos += best_match
npos = value.rfind("\n")
if npos > -1:
self.col = best_match-npos
self.line += value.count("\n")
else:
self.col += best_match
# If we found something that isn't to be ignored, return it
if not ignore:
if len(self.tokens) >= 10:
del self.tokens[0]
self.tokens.append(tok)
self.last_read_token = tok
# print repr(tok)
return tok
else:
ignore = self.ignore[best_pat]
if ignore:
ignore(self, best_m)
def peek(self, *types, **kw):
"""Returns the token type for lookahead; if there are any args
then the list of args is the set of token types to allow"""
context = kw.get("context",None)
if self.last_token is None:
self.last_types = types
self.last_token = self.token(types,context)
elif self.last_types:
for t in types:
if t not in self.last_types:
raise NotImplementedError("Unimplemented: restriction set changed")
return self.last_token.type
def scan(self, type, **kw):
"""Returns the matched text, and moves to the next token"""
context = kw.get("context",None)
if self.last_token is None:
tok = self.token([type],context)
else:
if self.last_types and type not in self.last_types:
raise NotImplementedError("Unimplemented: restriction set changed")
tok = self.last_token
self.last_token = None
if tok.type != type:
if not self.last_types: self.last_types=[]
raise SyntaxError(tok.pos, 'Trying to find '+type+': '+ ', '.join(self.last_types)+", got "+tok.type, context=context)
return tok.value
class Parser(object):
"""Base class for Yapps-generated parsers.
"""
def __init__(self, scanner):
self._scanner = scanner
def _stack(self, input="",file=None,filename=None):
"""Temporarily read from someplace else"""
self._scanner.stack_input(input,file,filename)
self._tok = None
def _peek(self, *types, **kw):
"""Returns the token type for lookahead; if there are any args
then the list of args is the set of token types to allow"""
return self._scanner.peek(*types, **kw)
def _scan(self, type, **kw):
"""Returns the matched text, and moves to the next token"""
return self._scanner.scan(type, **kw)
class Context(object):
"""Class to represent the parser's call stack.
Every rule creates a Context that links to its parent rule. The
contexts can be used for debugging.
"""
def __init__(self, parent, scanner, rule, args=()):
"""Create a new context.
Args:
parent: Context object or None
scanner: Scanner object
rule: string (name of the rule)
args: tuple listing parameters to the rule
"""
self.parent = parent
self.scanner = scanner
self.rule = rule
self.args = args
while scanner.stack: scanner = scanner.stack
self.token = scanner.last_read_token
def __str__(self):
output = ''
if self.parent: output = str(self.parent) + ' > '
output += self.rule
return output
def print_error(err, scanner, max_ctx=None):
"""Print error messages, the parser stack, and the input text -- for human-readable error messages."""
# NOTE: this function assumes 80 columns :-(
# Figure out the line number
pos = err.pos
if not pos:
pos = scanner.get_pos()
file_name, line_number, column_number = pos
print('%s:%d:%d: %s' % (file_name, line_number, column_number, err.msg), file=sys.stderr)
scanner.print_line_with_pointer(pos)
context = err.context
token = None
while context:
print('while parsing %s%s:' % (context.rule, tuple(context.args)), file=sys.stderr)
if context.token:
token = context.token
if token:
scanner.print_line_with_pointer(token.pos, length=len(token.value))
context = context.parent
if max_ctx:
max_ctx = max_ctx-1
if not max_ctx:
break
def wrap_error_reporter(parser, rule, *args,**kw):
try:
return getattr(parser, rule)(*args,**kw)
except SyntaxError as e:
print_error(e, parser._scanner)
except NoMoreTokens:
print('Could not complete parsing; stopped around here:', file=sys.stderr)
print(parser._scanner, file=sys.stderr)
from twisted.words.xish.xpath import AttribValue, BooleanValue, CompareValue
from twisted.words.xish.xpath import Function, IndexValue, LiteralValue
from twisted.words.xish.xpath import _AnyLocation, _Location
%%
parser XPathParser:
ignore: "\\s+"
token INDEX: "[0-9]+"
token WILDCARD: "\*"
token IDENTIFIER: "[a-zA-Z][a-zA-Z0-9_\-]*"
token ATTRIBUTE: "\@[a-zA-Z][a-zA-Z0-9_\-]*"
token FUNCNAME: "[a-zA-Z][a-zA-Z0-9_]*"
token CMP_EQ: "\="
token CMP_NE: "\!\="
token STR_DQ: '"([^"]|(\\"))*?"'
token STR_SQ: "'([^']|(\\'))*?'"
token OP_AND: "and"
token OP_OR: "or"
token END: "$"
rule XPATH: PATH {{ result = PATH; current = result }}
( PATH {{ current.childLocation = PATH; current = current.childLocation }} ) * END
{{ return result }}
rule PATH: ("/" {{ result = _Location() }} | "//" {{ result = _AnyLocation() }} )
( IDENTIFIER {{ result.elementName = IDENTIFIER }} | WILDCARD {{ result.elementName = None }} )
( "\[" PREDICATE {{ result.predicates.append(PREDICATE) }} "\]")*
{{ return result }}
rule PREDICATE: EXPR {{ return EXPR }} |
INDEX {{ return IndexValue(INDEX) }}
rule EXPR: FACTOR {{ e = FACTOR }}
( BOOLOP FACTOR {{ e = BooleanValue(e, BOOLOP, FACTOR) }} )*
{{ return e }}
rule BOOLOP: ( OP_AND {{ return OP_AND }} | OP_OR {{ return OP_OR }} )
rule FACTOR: TERM {{ return TERM }}
| "\(" EXPR "\)" {{ return EXPR }}
rule TERM: VALUE {{ t = VALUE }}
[ CMP VALUE {{ t = CompareValue(t, CMP, VALUE) }} ]
{{ return t }}
rule VALUE: "@" IDENTIFIER {{ return AttribValue(IDENTIFIER) }} |
FUNCNAME {{ f = Function(FUNCNAME); args = [] }}
"\(" [ VALUE {{ args.append(VALUE) }}
(
"," VALUE {{ args.append(VALUE) }}
)*
] "\)" {{ f.setParams(*args); return f }} |
STR {{ return LiteralValue(STR[1:len(STR)-1]) }}
rule CMP: (CMP_EQ {{ return CMP_EQ }} | CMP_NE {{ return CMP_NE }})
rule STR: (STR_DQ {{ return STR_DQ }} | STR_SQ {{ return STR_SQ }})

View file

@ -0,0 +1,650 @@
# -*- test-case-name: twisted.words.test.test_xpath -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# pylint: disable=W9401,W9402
# DO NOT EDIT xpathparser.py!
#
# It is generated from xpathparser.g using Yapps. Make needed changes there.
# This also means that the generated Python may not conform to Twisted's coding
# standards, so it is wrapped in exec to prevent automated checkers from
# complaining.
# HOWTO Generate me:
#
# 1.) Grab a copy of yapps2:
# https://github.com/smurfix/yapps
#
# Note: Do NOT use the package in debian/ubuntu as it has incompatible
# modifications. The original at http://theory.stanford.edu/~amitp/yapps/
# hasn't been touched since 2003 and has not been updated to work with
# Python 3.
#
# 2.) Generate the grammar:
#
# yapps2 xpathparser.g xpathparser.py.proto
#
# 3.) Edit the output to depend on the embedded runtime, and remove extraneous
# imports:
#
# sed -e '/^# Begin/,${/^[^ ].*mport/d}' -e '/^[^#]/s/runtime\.//g' \
# -e "s/^\(from __future\)/exec(r'''\n\1/" -e"\$a''')"
# xpathparser.py.proto > xpathparser.py
"""
XPath Parser.
Besides the parser code produced by Yapps, this module also defines the
parse-time exception classes, a scanner class, a base class for parsers
produced by Yapps, and a context class that keeps track of the parse stack.
These have been copied from the Yapps runtime module.
"""
exec(r'''
from __future__ import print_function
import sys, re
MIN_WINDOW=4096
# File lookup window
class SyntaxError(Exception):
"""When we run into an unexpected token, this is the exception to use"""
def __init__(self, pos=None, msg="Bad Token", context=None):
Exception.__init__(self)
self.pos = pos
self.msg = msg
self.context = context
def __str__(self):
if not self.pos: return 'SyntaxError'
else: return 'SyntaxError@%s(%s)' % (repr(self.pos), self.msg)
class NoMoreTokens(Exception):
"""Another exception object, for when we run out of tokens"""
pass
class Token(object):
"""Yapps token.
This is a container for a scanned token.
"""
def __init__(self, type,value, pos=None):
"""Initialize a token."""
self.type = type
self.value = value
self.pos = pos
def __repr__(self):
output = '<%s: %s' % (self.type, repr(self.value))
if self.pos:
output += " @ "
if self.pos[0]:
output += "%s:" % self.pos[0]
if self.pos[1]:
output += "%d" % self.pos[1]
if self.pos[2] is not None:
output += ".%d" % self.pos[2]
output += ">"
return output
in_name=0
class Scanner(object):
"""Yapps scanner.
The Yapps scanner can work in context sensitive or context
insensitive modes. The token(i) method is used to retrieve the
i-th token. It takes a restrict set that limits the set of tokens
it is allowed to return. In context sensitive mode, this restrict
set guides the scanner. In context insensitive mode, there is no
restriction (the set is always the full set of tokens).
"""
def __init__(self, patterns, ignore, input="",
file=None,filename=None,stacked=False):
"""Initialize the scanner.
Parameters:
patterns : [(terminal, uncompiled regex), ...] or None
ignore : {terminal:None, ...}
input : string
If patterns is None, we assume that the subclass has
defined self.patterns : [(terminal, compiled regex), ...].
Note that the patterns parameter expects uncompiled regexes,
whereas the self.patterns field expects compiled regexes.
The 'ignore' value is either None or a callable, which is called
with the scanner and the to-be-ignored match object; this can
be used for include file or comment handling.
"""
if not filename:
global in_name
filename="<f.%d>" % in_name
in_name += 1
self.input = input
self.ignore = ignore
self.file = file
self.filename = filename
self.pos = 0
self.del_pos = 0 # skipped
self.line = 1
self.del_line = 0 # skipped
self.col = 0
self.tokens = []
self.stack = None
self.stacked = stacked
self.last_read_token = None
self.last_token = None
self.last_types = None
if patterns is not None:
# Compile the regex strings into regex objects
self.patterns = []
for terminal, regex in patterns:
self.patterns.append( (terminal, re.compile(regex)) )
def stack_input(self, input="", file=None, filename=None):
"""Temporarily parse from a second file."""
# Already reading from somewhere else: Go on top of that, please.
if self.stack:
# autogenerate a recursion-level-identifying filename
if not filename:
filename = 1
else:
try:
filename += 1
except TypeError:
pass
# now pass off to the include file
self.stack.stack_input(input,file,filename)
else:
try:
filename += 0
except TypeError:
pass
else:
filename = "<str_%d>" % filename
# self.stack = object.__new__(self.__class__)
# Scanner.__init__(self.stack,self.patterns,self.ignore,input,file,filename, stacked=True)
# Note that the pattern+ignore are added by the generated
# scanner code
self.stack = self.__class__(input,file,filename, stacked=True)
def get_pos(self):
"""Return a file/line/char tuple."""
if self.stack: return self.stack.get_pos()
return (self.filename, self.line+self.del_line, self.col)
# def __repr__(self):
# """Print the last few tokens that have been scanned in"""
# output = ''
# for t in self.tokens:
# output += '%s\n' % (repr(t),)
# return output
def print_line_with_pointer(self, pos, length=0, out=sys.stderr):
"""Print the line of 'text' that includes position 'p',
along with a second line with a single caret (^) at position p"""
file,line,p = pos
if file != self.filename:
if self.stack: return self.stack.print_line_with_pointer(pos,length=length,out=out)
print >>out, "(%s: not in input buffer)" % file
return
text = self.input
p += length-1 # starts at pos 1
origline=line
line -= self.del_line
spos=0
if line > 0:
while 1:
line = line - 1
try:
cr = text.index("\n",spos)
except ValueError:
if line:
text = ""
break
if line == 0:
text = text[spos:cr]
break
spos = cr+1
else:
print >>out, "(%s:%d not in input buffer)" % (file,origline)
return
# Now try printing part of the line
text = text[max(p-80, 0):p+80]
p = p - max(p-80, 0)
# Strip to the left
i = text[:p].rfind('\n')
j = text[:p].rfind('\r')
if i < 0 or (0 <= j < i): i = j
if 0 <= i < p:
p = p - i - 1
text = text[i+1:]
# Strip to the right
i = text.find('\n', p)
j = text.find('\r', p)
if i < 0 or (0 <= j < i): i = j
if i >= 0:
text = text[:i]
# Now shorten the text
while len(text) > 70 and p > 60:
# Cut off 10 chars
text = "..." + text[10:]
p = p - 7
# Now print the string, along with an indicator
print >>out, '> ',text
print >>out, '> ',' '*p + '^'
def grab_input(self):
"""Get more input if possible."""
if not self.file: return
if len(self.input) - self.pos >= MIN_WINDOW: return
data = self.file.read(MIN_WINDOW)
if data is None or data == "":
self.file = None
# Drop bytes from the start, if necessary.
if self.pos > 2*MIN_WINDOW:
self.del_pos += MIN_WINDOW
self.del_line += self.input[:MIN_WINDOW].count("\n")
self.pos -= MIN_WINDOW
self.input = self.input[MIN_WINDOW:] + data
else:
self.input = self.input + data
def getchar(self):
"""Return the next character."""
self.grab_input()
c = self.input[self.pos]
self.pos += 1
return c
def token(self, restrict, context=None):
"""Scan for another token."""
while 1:
if self.stack:
try:
return self.stack.token(restrict, context)
except StopIteration:
self.stack = None
# Keep looking for a token, ignoring any in self.ignore
self.grab_input()
# special handling for end-of-file
if self.stacked and self.pos==len(self.input):
raise StopIteration
# Search the patterns for the longest match, with earlier
# tokens in the list having preference
best_match = -1
best_pat = '(error)'
best_m = None
for p, regexp in self.patterns:
# First check to see if we're ignoring this token
if restrict and p not in restrict and p not in self.ignore:
continue
m = regexp.match(self.input, self.pos)
if m and m.end()-m.start() > best_match:
# We got a match that's better than the previous one
best_pat = p
best_match = m.end()-m.start()
best_m = m
# If we didn't find anything, raise an error
if best_pat == '(error)' and best_match < 0:
msg = 'Bad Token'
if restrict:
msg = 'Trying to find one of '+', '.join(restrict)
raise SyntaxError(self.get_pos(), msg, context=context)
ignore = best_pat in self.ignore
value = self.input[self.pos:self.pos+best_match]
if not ignore:
tok=Token(type=best_pat, value=value, pos=self.get_pos())
self.pos += best_match
npos = value.rfind("\n")
if npos > -1:
self.col = best_match-npos
self.line += value.count("\n")
else:
self.col += best_match
# If we found something that isn't to be ignored, return it
if not ignore:
if len(self.tokens) >= 10:
del self.tokens[0]
self.tokens.append(tok)
self.last_read_token = tok
# print repr(tok)
return tok
else:
ignore = self.ignore[best_pat]
if ignore:
ignore(self, best_m)
def peek(self, *types, **kw):
"""Returns the token type for lookahead; if there are any args
then the list of args is the set of token types to allow"""
context = kw.get("context",None)
if self.last_token is None:
self.last_types = types
self.last_token = self.token(types,context)
elif self.last_types:
for t in types:
if t not in self.last_types:
raise NotImplementedError("Unimplemented: restriction set changed")
return self.last_token.type
def scan(self, type, **kw):
"""Returns the matched text, and moves to the next token"""
context = kw.get("context",None)
if self.last_token is None:
tok = self.token([type],context)
else:
if self.last_types and type not in self.last_types:
raise NotImplementedError("Unimplemented: restriction set changed")
tok = self.last_token
self.last_token = None
if tok.type != type:
if not self.last_types: self.last_types=[]
raise SyntaxError(tok.pos, 'Trying to find '+type+': '+ ', '.join(self.last_types)+", got "+tok.type, context=context)
return tok.value
class Parser(object):
"""Base class for Yapps-generated parsers.
"""
def __init__(self, scanner):
self._scanner = scanner
def _stack(self, input="",file=None,filename=None):
"""Temporarily read from someplace else"""
self._scanner.stack_input(input,file,filename)
self._tok = None
def _peek(self, *types, **kw):
"""Returns the token type for lookahead; if there are any args
then the list of args is the set of token types to allow"""
return self._scanner.peek(*types, **kw)
def _scan(self, type, **kw):
"""Returns the matched text, and moves to the next token"""
return self._scanner.scan(type, **kw)
class Context(object):
"""Class to represent the parser's call stack.
Every rule creates a Context that links to its parent rule. The
contexts can be used for debugging.
"""
def __init__(self, parent, scanner, rule, args=()):
"""Create a new context.
Args:
parent: Context object or None
scanner: Scanner object
rule: string (name of the rule)
args: tuple listing parameters to the rule
"""
self.parent = parent
self.scanner = scanner
self.rule = rule
self.args = args
while scanner.stack: scanner = scanner.stack
self.token = scanner.last_read_token
def __str__(self):
output = ''
if self.parent: output = str(self.parent) + ' > '
output += self.rule
return output
def print_error(err, scanner, max_ctx=None):
"""Print error messages, the parser stack, and the input text -- for human-readable error messages."""
# NOTE: this function assumes 80 columns :-(
# Figure out the line number
pos = err.pos
if not pos:
pos = scanner.get_pos()
file_name, line_number, column_number = pos
print('%s:%d:%d: %s' % (file_name, line_number, column_number, err.msg), file=sys.stderr)
scanner.print_line_with_pointer(pos)
context = err.context
token = None
while context:
print('while parsing %s%s:' % (context.rule, tuple(context.args)), file=sys.stderr)
if context.token:
token = context.token
if token:
scanner.print_line_with_pointer(token.pos, length=len(token.value))
context = context.parent
if max_ctx:
max_ctx = max_ctx-1
if not max_ctx:
break
def wrap_error_reporter(parser, rule, *args,**kw):
try:
return getattr(parser, rule)(*args,**kw)
except SyntaxError as e:
print_error(e, parser._scanner)
except NoMoreTokens:
print('Could not complete parsing; stopped around here:', file=sys.stderr)
print(parser._scanner, file=sys.stderr)
from twisted.words.xish.xpath import AttribValue, BooleanValue, CompareValue
from twisted.words.xish.xpath import Function, IndexValue, LiteralValue
from twisted.words.xish.xpath import _AnyLocation, _Location
# Begin -- grammar generated by Yapps
class XPathParserScanner(Scanner):
patterns = [
('","', re.compile(',')),
('"@"', re.compile('@')),
('"\\)"', re.compile('\\)')),
('"\\("', re.compile('\\(')),
('"\\]"', re.compile('\\]')),
('"\\["', re.compile('\\[')),
('"//"', re.compile('//')),
('"/"', re.compile('/')),
('\\s+', re.compile('\\s+')),
('INDEX', re.compile('[0-9]+')),
('WILDCARD', re.compile('\\*')),
('IDENTIFIER', re.compile('[a-zA-Z][a-zA-Z0-9_\\-]*')),
('ATTRIBUTE', re.compile('\\@[a-zA-Z][a-zA-Z0-9_\\-]*')),
('FUNCNAME', re.compile('[a-zA-Z][a-zA-Z0-9_]*')),
('CMP_EQ', re.compile('\\=')),
('CMP_NE', re.compile('\\!\\=')),
('STR_DQ', re.compile('"([^"]|(\\"))*?"')),
('STR_SQ', re.compile("'([^']|(\\'))*?'")),
('OP_AND', re.compile('and')),
('OP_OR', re.compile('or')),
('END', re.compile('$')),
]
def __init__(self, str,*args,**kw):
Scanner.__init__(self,None,{'\\s+':None,},str,*args,**kw)
class XPathParser(Parser):
Context = Context
def XPATH(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'XPATH', [])
PATH = self.PATH(_context)
result = PATH; current = result
while self._peek('END', '"/"', '"//"', context=_context) != 'END':
PATH = self.PATH(_context)
current.childLocation = PATH; current = current.childLocation
END = self._scan('END', context=_context)
return result
def PATH(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'PATH', [])
_token = self._peek('"/"', '"//"', context=_context)
if _token == '"/"':
self._scan('"/"', context=_context)
result = _Location()
else: # == '"//"'
self._scan('"//"', context=_context)
result = _AnyLocation()
_token = self._peek('IDENTIFIER', 'WILDCARD', context=_context)
if _token == 'IDENTIFIER':
IDENTIFIER = self._scan('IDENTIFIER', context=_context)
result.elementName = IDENTIFIER
else: # == 'WILDCARD'
WILDCARD = self._scan('WILDCARD', context=_context)
result.elementName = None
while self._peek('"\\["', 'END', '"/"', '"//"', context=_context) == '"\\["':
self._scan('"\\["', context=_context)
PREDICATE = self.PREDICATE(_context)
result.predicates.append(PREDICATE)
self._scan('"\\]"', context=_context)
return result
def PREDICATE(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'PREDICATE', [])
_token = self._peek('INDEX', '"\\("', '"@"', 'FUNCNAME', 'STR_DQ', 'STR_SQ', context=_context)
if _token != 'INDEX':
EXPR = self.EXPR(_context)
return EXPR
else: # == 'INDEX'
INDEX = self._scan('INDEX', context=_context)
return IndexValue(INDEX)
def EXPR(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'EXPR', [])
FACTOR = self.FACTOR(_context)
e = FACTOR
while self._peek('OP_AND', 'OP_OR', '"\\)"', '"\\]"', context=_context) in ['OP_AND', 'OP_OR']:
BOOLOP = self.BOOLOP(_context)
FACTOR = self.FACTOR(_context)
e = BooleanValue(e, BOOLOP, FACTOR)
return e
def BOOLOP(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'BOOLOP', [])
_token = self._peek('OP_AND', 'OP_OR', context=_context)
if _token == 'OP_AND':
OP_AND = self._scan('OP_AND', context=_context)
return OP_AND
else: # == 'OP_OR'
OP_OR = self._scan('OP_OR', context=_context)
return OP_OR
def FACTOR(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'FACTOR', [])
_token = self._peek('"\\("', '"@"', 'FUNCNAME', 'STR_DQ', 'STR_SQ', context=_context)
if _token != '"\\("':
TERM = self.TERM(_context)
return TERM
else: # == '"\\("'
self._scan('"\\("', context=_context)
EXPR = self.EXPR(_context)
self._scan('"\\)"', context=_context)
return EXPR
def TERM(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'TERM', [])
VALUE = self.VALUE(_context)
t = VALUE
if self._peek('CMP_EQ', 'CMP_NE', 'OP_AND', 'OP_OR', '"\\)"', '"\\]"', context=_context) in ['CMP_EQ', 'CMP_NE']:
CMP = self.CMP(_context)
VALUE = self.VALUE(_context)
t = CompareValue(t, CMP, VALUE)
return t
def VALUE(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'VALUE', [])
_token = self._peek('"@"', 'FUNCNAME', 'STR_DQ', 'STR_SQ', context=_context)
if _token == '"@"':
self._scan('"@"', context=_context)
IDENTIFIER = self._scan('IDENTIFIER', context=_context)
return AttribValue(IDENTIFIER)
elif _token == 'FUNCNAME':
FUNCNAME = self._scan('FUNCNAME', context=_context)
f = Function(FUNCNAME); args = []
self._scan('"\\("', context=_context)
if self._peek('"\\)"', '"@"', 'FUNCNAME', '","', 'STR_DQ', 'STR_SQ', context=_context) not in ['"\\)"', '","']:
VALUE = self.VALUE(_context)
args.append(VALUE)
while self._peek('","', '"\\)"', context=_context) == '","':
self._scan('","', context=_context)
VALUE = self.VALUE(_context)
args.append(VALUE)
self._scan('"\\)"', context=_context)
f.setParams(*args); return f
else: # in ['STR_DQ', 'STR_SQ']
STR = self.STR(_context)
return LiteralValue(STR[1:len(STR)-1])
def CMP(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'CMP', [])
_token = self._peek('CMP_EQ', 'CMP_NE', context=_context)
if _token == 'CMP_EQ':
CMP_EQ = self._scan('CMP_EQ', context=_context)
return CMP_EQ
else: # == 'CMP_NE'
CMP_NE = self._scan('CMP_NE', context=_context)
return CMP_NE
def STR(self, _parent=None):
_context = self.Context(_parent, self._scanner, 'STR', [])
_token = self._peek('STR_DQ', 'STR_SQ', context=_context)
if _token == 'STR_DQ':
STR_DQ = self._scan('STR_DQ', context=_context)
return STR_DQ
else: # == 'STR_SQ'
STR_SQ = self._scan('STR_SQ', context=_context)
return STR_SQ
def parse(rule, text):
P = XPathParser(XPathParserScanner(text))
return wrap_error_reporter(P, rule)
if __name__ == '__main__':
from sys import argv, stdin
if len(argv) >= 2:
if len(argv) >= 3:
f = open(argv[2],'r')
else:
f = stdin
print(parse(argv[1], f.read()))
else: print ('Args: <rule> [<filename>]', file=sys.stderr)
# End -- grammar generated by Yapps
''')

View file

@ -0,0 +1,30 @@
# -*- test-case-name: twisted.words.test.test_xmpproutertap -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from twisted.application import strports
from twisted.python import usage
from twisted.words.protocols.jabber import component
class Options(usage.Options):
optParameters = [
('port', None, 'tcp:5347:interface=127.0.0.1',
'Port components connect to'),
('secret', None, 'secret', 'Router secret'),
]
optFlags = [
('verbose', 'v', 'Log traffic'),
]
def makeService(config):
router = component.Router()
factory = component.XMPPComponentServerFactory(router, config['secret'])
if config['verbose']:
factory.logTraffic = True
return strports.service(config['port'], factory)