Ausgabe der neuen DB Einträge

This commit is contained in:
hubobel 2022-01-02 21:50:48 +01:00
parent bad48e1627
commit cfbbb9ee3d
2399 changed files with 843193 additions and 43 deletions

View file

@ -0,0 +1,23 @@
# -*- test-case-name: twisted -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted: The Framework Of Your Internet.
"""
# setup version
from twisted._version import __version__ as version
__version__ = version.short()
from incremental import Version
from twisted.python.deprecate import deprecatedModuleAttribute
deprecatedModuleAttribute(
Version('Twisted', 20, 3, 0),
"morituri nolumus mori",
"twisted",
"news"
)

View file

@ -0,0 +1,16 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# Make the twisted module executable with the default behaviour of
# running twist.
# This is not a docstring to avoid changing the string output of twist.
from __future__ import division, absolute_import
import sys
from pkg_resources import load_entry_point
if __name__ == '__main__':
sys.exit(
load_entry_point('Twisted', 'console_scripts', 'twist')()
)

View file

@ -0,0 +1,25 @@
# -*- test-case-name: twisted.test.test_paths -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted integration with operating system threads.
"""
from __future__ import absolute_import, division, print_function
from ._threadworker import ThreadWorker, LockWorker
from ._ithreads import IWorker, AlreadyQuit
from ._team import Team
from ._memory import createMemoryWorker
from ._pool import pool
__all__ = [
"ThreadWorker",
"LockWorker",
"IWorker",
"AlreadyQuit",
"Team",
"createMemoryWorker",
"pool",
]

View file

@ -0,0 +1,46 @@
# -*- test-case-name: twisted._threads.test.test_convenience -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Common functionality used within the implementation of various workers.
"""
from __future__ import absolute_import, division, print_function
from ._ithreads import AlreadyQuit
class Quit(object):
"""
A flag representing whether a worker has been quit.
@ivar isSet: Whether this flag is set.
@type isSet: L{bool}
"""
def __init__(self):
"""
Create a L{Quit} un-set.
"""
self.isSet = False
def set(self):
"""
Set the flag if it has not been set.
@raise AlreadyQuit: If it has been set.
"""
self.check()
self.isSet = True
def check(self):
"""
Check if the flag has been set.
@raise AlreadyQuit: If it has been set.
"""
if self.isSet:
raise AlreadyQuit()

View file

@ -0,0 +1,61 @@
# -*- test-case-name: twisted._threads.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Interfaces related to threads.
"""
from __future__ import absolute_import, division, print_function
from zope.interface import Interface
class AlreadyQuit(Exception):
"""
This worker worker is dead and cannot execute more instructions.
"""
class IWorker(Interface):
"""
A worker that can perform some work concurrently.
All methods on this interface must be thread-safe.
"""
def do(task):
"""
Perform the given task.
As an interface, this method makes no specific claims about concurrent
execution. An L{IWorker}'s C{do} implementation may defer execution
for later on the same thread, immediately on a different thread, or
some combination of the two. It is valid for a C{do} method to
schedule C{task} in such a way that it may never be executed.
It is important for some implementations to provide specific properties
with respect to where C{task} is executed, of course, and client code
may rely on a more specific implementation of C{do} than L{IWorker}.
@param task: a task to call in a thread or other concurrent context.
@type task: 0-argument callable
@raise AlreadyQuit: if C{quit} has been called.
"""
def quit():
"""
Free any resources associated with this L{IWorker} and cause it to
reject all future work.
@raise: L{AlreadyQuit} if this method has already been called.
"""
class IExclusiveWorker(IWorker):
"""
Like L{IWorker}, but with the additional guarantee that the callables
passed to C{do} will not be called exclusively with each other.
"""

View file

@ -0,0 +1,71 @@
# -*- test-case-name: twisted._threads.test.test_memory -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of an in-memory worker that defers execution.
"""
from __future__ import absolute_import, division, print_function
from zope.interface import implementer
from . import IWorker
from ._convenience import Quit
NoMoreWork = object()
@implementer(IWorker)
class MemoryWorker(object):
"""
An L{IWorker} that queues work for later performance.
@ivar _quit: a flag indicating
@type _quit: L{Quit}
"""
def __init__(self, pending=list):
"""
Create a L{MemoryWorker}.
"""
self._quit = Quit()
self._pending = pending()
def do(self, work):
"""
Queue some work for to perform later; see L{createMemoryWorker}.
@param work: The work to perform.
"""
self._quit.check()
self._pending.append(work)
def quit(self):
"""
Quit this worker.
"""
self._quit.set()
self._pending.append(NoMoreWork)
def createMemoryWorker():
"""
Create an L{IWorker} that does nothing but defer work, to be performed
later.
@return: a worker that will enqueue work to perform later, and a callable
that will perform one element of that work.
@rtype: 2-L{tuple} of (L{IWorker}, L{callable})
"""
def perform():
if not worker._pending:
return False
if worker._pending[0] is NoMoreWork:
return False
worker._pending.pop(0)()
return True
worker = MemoryWorker()
return (worker, perform)

View file

@ -0,0 +1,69 @@
# -*- test-case-name: twisted._threads.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Top level thread pool interface, used to implement
L{twisted.python.threadpool}.
"""
from __future__ import absolute_import, division, print_function
from threading import Thread, Lock, local as LocalStorage
try:
from Queue import Queue
except ImportError:
from queue import Queue
from twisted.python.log import err
from ._threadworker import LockWorker
from ._team import Team
from ._threadworker import ThreadWorker
def pool(currentLimit, threadFactory=Thread):
"""
Construct a L{Team} that spawns threads as a thread pool, with the given
limiting function.
@note: Future maintainers: while the public API for the eventual move to
twisted.threads should look I{something} like this, and while this
function is necessary to implement the API described by
L{twisted.python.threadpool}, I am starting to think the idea of a hard
upper limit on threadpool size is just bad (turning memory performance
issues into correctness issues well before we run into memory
pressure), and instead we should build something with reactor
integration for slowly releasing idle threads when they're not needed
and I{rate} limiting the creation of new threads rather than just
hard-capping it.
@param currentLimit: a callable that returns the current limit on the
number of workers that the returned L{Team} should create; if it
already has more workers than that value, no new workers will be
created.
@type currentLimit: 0-argument callable returning L{int}
@param reactor: If passed, the L{IReactorFromThreads} / L{IReactorCore} to
be used to coordinate actions on the L{Team} itself. Otherwise, a
L{LockWorker} will be used.
@return: a new L{Team}.
"""
def startThread(target):
return threadFactory(target=target).start()
def limitedWorkerCreator():
stats = team.statistics()
if stats.busyWorkerCount + stats.idleWorkerCount >= currentLimit():
return None
return ThreadWorker(startThread, Queue())
team = Team(coordinator=LockWorker(Lock(), LocalStorage()),
createWorker=limitedWorkerCreator,
logException=err)
return team

View file

@ -0,0 +1,231 @@
# -*- test-case-name: twisted._threads.test.test_team -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of a L{Team} of workers; a thread-pool that can allocate work to
workers.
"""
from __future__ import absolute_import, division, print_function
from collections import deque
from zope.interface import implementer
from . import IWorker
from ._convenience import Quit
class Statistics(object):
"""
Statistics about a L{Team}'s current activity.
@ivar idleWorkerCount: The number of idle workers.
@type idleWorkerCount: L{int}
@ivar busyWorkerCount: The number of busy workers.
@type busyWorkerCount: L{int}
@ivar backloggedWorkCount: The number of work items passed to L{Team.do}
which have not yet been sent to a worker to be performed because not
enough workers are available.
@type backloggedWorkCount: L{int}
"""
def __init__(self, idleWorkerCount, busyWorkerCount,
backloggedWorkCount):
self.idleWorkerCount = idleWorkerCount
self.busyWorkerCount = busyWorkerCount
self.backloggedWorkCount = backloggedWorkCount
@implementer(IWorker)
class Team(object):
"""
A composite L{IWorker} implementation.
@ivar _quit: A L{Quit} flag indicating whether this L{Team} has been quit
yet. This may be set by an arbitrary thread since L{Team.quit} may be
called from anywhere.
@ivar _coordinator: the L{IExclusiveWorker} coordinating access to this
L{Team}'s internal resources.
@ivar _createWorker: a callable that will create new workers.
@ivar _logException: a 0-argument callable called in an exception context
when there is an unhandled error from a task passed to L{Team.do}
@ivar _idle: a L{set} of idle workers.
@ivar _busyCount: the number of workers currently busy.
@ivar _pending: a C{deque} of tasks - that is, 0-argument callables passed
to L{Team.do} - that are outstanding.
@ivar _shouldQuitCoordinator: A flag indicating that the coordinator should
be quit at the next available opportunity. Unlike L{Team._quit}, this
flag is only set by the coordinator.
@ivar _toShrink: the number of workers to shrink this L{Team} by at the
next available opportunity; set in the coordinator.
"""
def __init__(self, coordinator, createWorker, logException):
"""
@param coordinator: an L{IExclusiveWorker} which will coordinate access
to resources on this L{Team}; that is to say, an
L{IExclusiveWorker} whose C{do} method ensures that its given work
will be executed in a mutually exclusive context, not in parallel
with other work enqueued by C{do} (although possibly in parallel
with the caller).
@param createWorker: A 0-argument callable that will create an
L{IWorker} to perform work.
@param logException: A 0-argument callable called in an exception
context when the work passed to C{do} raises an exception.
"""
self._quit = Quit()
self._coordinator = coordinator
self._createWorker = createWorker
self._logException = logException
# Don't touch these except from the coordinator.
self._idle = set()
self._busyCount = 0
self._pending = deque()
self._shouldQuitCoordinator = False
self._toShrink = 0
def statistics(self):
"""
Gather information on the current status of this L{Team}.
@return: a L{Statistics} describing the current state of this L{Team}.
"""
return Statistics(len(self._idle), self._busyCount, len(self._pending))
def grow(self, n):
"""
Increase the the number of idle workers by C{n}.
@param n: The number of new idle workers to create.
@type n: L{int}
"""
self._quit.check()
@self._coordinator.do
def createOneWorker():
for x in range(n):
worker = self._createWorker()
if worker is None:
return
self._recycleWorker(worker)
def shrink(self, n=None):
"""
Decrease the number of idle workers by C{n}.
@param n: The number of idle workers to shut down, or L{None} (or
unspecified) to shut down all workers.
@type n: L{int} or L{None}
"""
self._quit.check()
self._coordinator.do(lambda: self._quitIdlers(n))
def _quitIdlers(self, n=None):
"""
The implmentation of C{shrink}, performed by the coordinator worker.
@param n: see L{Team.shrink}
"""
if n is None:
n = len(self._idle) + self._busyCount
for x in range(n):
if self._idle:
self._idle.pop().quit()
else:
self._toShrink += 1
if self._shouldQuitCoordinator and self._busyCount == 0:
self._coordinator.quit()
def do(self, task):
"""
Perform some work in a worker created by C{createWorker}.
@param task: the callable to run
"""
self._quit.check()
self._coordinator.do(lambda: self._coordinateThisTask(task))
def _coordinateThisTask(self, task):
"""
Select a worker to dispatch to, either an idle one or a new one, and
perform it.
This method should run on the coordinator worker.
@param task: the task to dispatch
@type task: 0-argument callable
"""
worker = (self._idle.pop() if self._idle
else self._createWorker())
if worker is None:
# The createWorker method may return None if we're out of resources
# to create workers.
self._pending.append(task)
return
self._busyCount += 1
@worker.do
def doWork():
try:
task()
except:
self._logException()
@self._coordinator.do
def idleAndPending():
self._busyCount -= 1
self._recycleWorker(worker)
def _recycleWorker(self, worker):
"""
Called only from coordinator.
Recycle the given worker into the idle pool.
@param worker: a worker created by C{createWorker} and now idle.
@type worker: L{IWorker}
"""
self._idle.add(worker)
if self._pending:
# Re-try the first enqueued thing.
# (Explicitly do _not_ honor _quit.)
self._coordinateThisTask(self._pending.popleft())
elif self._shouldQuitCoordinator:
self._quitIdlers()
elif self._toShrink > 0:
self._toShrink -= 1
self._idle.remove(worker)
worker.quit()
def quit(self):
"""
Stop doing work and shut down all idle workers.
"""
self._quit.set()
# In case all the workers are idle when we do this.
@self._coordinator.do
def startFinishing():
self._shouldQuitCoordinator = True
self._quitIdlers()

View file

@ -0,0 +1,123 @@
# -*- test-case-name: twisted._threads.test.test_threadworker -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of an L{IWorker} based on native threads and queues.
"""
from __future__ import absolute_import, division, print_function
from zope.interface import implementer
from ._ithreads import IExclusiveWorker
from ._convenience import Quit
_stop = object()
@implementer(IExclusiveWorker)
class ThreadWorker(object):
"""
An L{IExclusiveWorker} implemented based on a single thread and a queue.
This worker ensures exclusivity (i.e. it is an L{IExclusiveWorker} and not
an L{IWorker}) by performing all of the work passed to C{do} on the I{same}
thread.
"""
def __init__(self, startThread, queue):
"""
Create a L{ThreadWorker} with a function to start a thread and a queue
to use to communicate with that thread.
@param startThread: a callable that takes a callable to run in another
thread.
@type startThread: callable taking a 0-argument callable and returning
nothing.
@param queue: A L{Queue} to use to give tasks to the thread created by
C{startThread}.
@param queue: L{Queue}
"""
self._q = queue
self._hasQuit = Quit()
def work():
for task in iter(queue.get, _stop):
task()
startThread(work)
def do(self, task):
"""
Perform the given task on the thread owned by this L{ThreadWorker}.
@param task: the function to call on a thread.
"""
self._hasQuit.check()
self._q.put(task)
def quit(self):
"""
Reject all future work and stop the thread started by C{__init__}.
"""
# Reject all future work. Set this _before_ enqueueing _stop, so
# that no work is ever enqueued _after_ _stop.
self._hasQuit.set()
self._q.put(_stop)
@implementer(IExclusiveWorker)
class LockWorker(object):
"""
An L{IWorker} implemented based on a mutual-exclusion lock.
"""
def __init__(self, lock, local):
"""
@param lock: A mutual-exclusion lock, with C{acquire} and C{release}
methods.
@type lock: L{threading.Lock}
@param local: Local storage.
@type local: L{threading.local}
"""
self._quit = Quit()
self._lock = lock
self._local = local
def do(self, work):
"""
Do the given work on this thread, with the mutex acquired. If this is
called re-entrantly, return and wait for the outer invocation to do the
work.
@param work: the work to do with the lock held.
"""
lock = self._lock
local = self._local
self._quit.check()
working = getattr(local, "working", None)
if working is None:
working = local.working = []
working.append(work)
lock.acquire()
try:
while working:
working.pop(0)()
finally:
lock.release()
local.working = None
else:
working.append(work)
def quit(self):
"""
Quit this L{LockWorker}.
"""
self._quit.set()
self._lock = None

View file

@ -0,0 +1,9 @@
# -*- test-case-name: twisted._threads.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads}.
"""
from __future__ import absolute_import, division, print_function

View file

@ -0,0 +1,61 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test cases for convenience functionality in L{twisted._threads._convenience}.
"""
from __future__ import absolute_import, division, print_function
from twisted.trial.unittest import SynchronousTestCase
from .._convenience import Quit
from .._ithreads import AlreadyQuit
class QuitTests(SynchronousTestCase):
"""
Tests for L{Quit}
"""
def test_isInitiallySet(self):
"""
L{Quit.isSet} starts as L{False}.
"""
quit = Quit()
self.assertEqual(quit.isSet, False)
def test_setSetsSet(self):
"""
L{Quit.set} sets L{Quit.isSet} to L{True}.
"""
quit = Quit()
quit.set()
self.assertEqual(quit.isSet, True)
def test_checkDoesNothing(self):
"""
L{Quit.check} initially does nothing and returns L{None}.
"""
quit = Quit()
self.assertIs(quit.check(), None)
def test_checkAfterSetRaises(self):
"""
L{Quit.check} raises L{AlreadyQuit} if L{Quit.set} has been called.
"""
quit = Quit()
quit.set()
self.assertRaises(AlreadyQuit, quit.check)
def test_setTwiceRaises(self):
"""
L{Quit.set} raises L{AlreadyQuit} if it has been called previously.
"""
quit = Quit()
quit.set()
self.assertRaises(AlreadyQuit, quit.set)

View file

@ -0,0 +1,65 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads._memory}.
"""
from __future__ import absolute_import, division, print_function
from zope.interface.verify import verifyObject
from twisted.trial.unittest import SynchronousTestCase
from .. import AlreadyQuit, IWorker, createMemoryWorker
class MemoryWorkerTests(SynchronousTestCase):
"""
Tests for L{MemoryWorker}.
"""
def test_createWorkerAndPerform(self):
"""
L{createMemoryWorker} creates an L{IWorker} and a callable that can
perform work on it. The performer returns C{True} if it accomplished
useful work.
"""
worker, performer = createMemoryWorker()
verifyObject(IWorker, worker)
done = []
worker.do(lambda: done.append(3))
worker.do(lambda: done.append(4))
self.assertEqual(done, [])
self.assertEqual(performer(), True)
self.assertEqual(done, [3])
self.assertEqual(performer(), True)
self.assertEqual(done, [3, 4])
def test_quitQuits(self):
"""
Calling C{quit} on the worker returned by L{createMemoryWorker} causes
its C{do} and C{quit} methods to raise L{AlreadyQuit}; its C{perform}
callable will start raising L{AlreadyQuit} when the work already
provided to C{do} has been exhausted.
"""
worker, performer = createMemoryWorker()
done = []
def moreWork():
done.append(7)
worker.do(moreWork)
worker.quit()
self.assertRaises(AlreadyQuit, worker.do, moreWork)
self.assertRaises(AlreadyQuit, worker.quit)
performer()
self.assertEqual(done, [7])
self.assertEqual(performer(), False)
def test_performWhenNothingToDoYet(self):
"""
The C{perform} callable returned by L{createMemoryWorker} will return
no result when there's no work to do yet. Since there is no work to
do, the performer returns C{False}.
"""
worker, performer = createMemoryWorker()
self.assertEqual(performer(), False)

View file

@ -0,0 +1,290 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads._team}.
"""
from __future__ import absolute_import, division, print_function
from twisted.trial.unittest import SynchronousTestCase
from twisted.python.context import call, get
from twisted.python.components import proxyForInterface
from twisted.python.failure import Failure
from .. import IWorker, Team, createMemoryWorker, AlreadyQuit
class ContextualWorker(proxyForInterface(IWorker, "_realWorker")):
"""
A worker implementation that supplies a context.
"""
def __init__(self, realWorker, **ctx):
"""
Create with a real worker and a context.
"""
self._realWorker = realWorker
self._context = ctx
def do(self, work):
"""
Perform the given work with the context given to __init__.
@param work: the work to pass on to the real worker.
"""
super(ContextualWorker, self).do(lambda: call(self._context, work))
class TeamTests(SynchronousTestCase):
"""
Tests for L{Team}
"""
def setUp(self):
"""
Set up a L{Team} with inspectable, synchronous workers that can be
single-stepped.
"""
coordinator, self.coordinateOnce = createMemoryWorker()
self.coordinator = ContextualWorker(coordinator, worker="coordinator")
self.workerPerformers = []
self.allWorkersEver = []
self.allUnquitWorkers = []
self.activePerformers = []
self.noMoreWorkers = lambda: False
def createWorker():
if self.noMoreWorkers():
return None
worker, performer = createMemoryWorker()
self.workerPerformers.append(performer)
self.activePerformers.append(performer)
cw = ContextualWorker(worker, worker=len(self.workerPerformers))
self.allWorkersEver.append(cw)
self.allUnquitWorkers.append(cw)
realQuit = cw.quit
def quitAndRemove():
realQuit()
self.allUnquitWorkers.remove(cw)
self.activePerformers.remove(performer)
cw.quit = quitAndRemove
return cw
self.failures = []
def logException():
self.failures.append(Failure())
self.team = Team(coordinator, createWorker, logException)
def coordinate(self):
"""
Perform all work currently scheduled in the coordinator.
@return: whether any coordination work was performed; if the
coordinator was idle when this was called, return L{False}
(otherwise L{True}).
@rtype: L{bool}
"""
did = False
while self.coordinateOnce():
did = True
return did
def performAllOutstandingWork(self):
"""
Perform all work on the coordinator and worker performers that needs to
be done.
"""
continuing = True
while continuing:
continuing = self.coordinate()
for performer in self.workerPerformers:
if performer in self.activePerformers:
performer()
continuing = continuing or self.coordinate()
def test_doDoesWorkInWorker(self):
"""
L{Team.do} does the work in a worker created by the createWorker
callable.
"""
def something():
something.who = get("worker")
self.team.do(something)
self.coordinate()
self.assertEqual(self.team.statistics().busyWorkerCount, 1)
self.performAllOutstandingWork()
self.assertEqual(something.who, 1)
self.assertEqual(self.team.statistics().busyWorkerCount, 0)
def test_initialStatistics(self):
"""
L{Team.statistics} returns an object with idleWorkerCount,
busyWorkerCount, and backloggedWorkCount integer attributes.
"""
stats = self.team.statistics()
self.assertEqual(stats.idleWorkerCount, 0)
self.assertEqual(stats.busyWorkerCount, 0)
self.assertEqual(stats.backloggedWorkCount, 0)
def test_growCreatesIdleWorkers(self):
"""
L{Team.grow} increases the number of available idle workers.
"""
self.team.grow(5)
self.performAllOutstandingWork()
self.assertEqual(len(self.workerPerformers), 5)
def test_growCreateLimit(self):
"""
L{Team.grow} increases the number of available idle workers until the
C{createWorker} callable starts returning None.
"""
self.noMoreWorkers = lambda: len(self.allWorkersEver) >= 3
self.team.grow(5)
self.performAllOutstandingWork()
self.assertEqual(len(self.allWorkersEver), 3)
self.assertEqual(self.team.statistics().idleWorkerCount, 3)
def test_shrinkQuitsWorkers(self):
"""
L{Team.shrink} will quit the given number of workers.
"""
self.team.grow(5)
self.performAllOutstandingWork()
self.team.shrink(3)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 2)
def test_shrinkToZero(self):
"""
L{Team.shrink} with no arguments will stop all outstanding workers.
"""
self.team.grow(10)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 10)
self.team.shrink()
self.assertEqual(len(self.allUnquitWorkers), 10)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 0)
def test_moreWorkWhenNoWorkersAvailable(self):
"""
When no additional workers are available, the given work is backlogged,
and then performed later when the work was.
"""
self.team.grow(3)
self.coordinate()
def something():
something.times += 1
something.times = 0
self.assertEqual(self.team.statistics().idleWorkerCount, 3)
for i in range(3):
self.team.do(something)
# Make progress on the coordinator but do _not_ actually complete the
# work, yet.
self.coordinate()
self.assertEqual(self.team.statistics().idleWorkerCount, 0)
self.noMoreWorkers = lambda: True
self.team.do(something)
self.coordinate()
self.assertEqual(self.team.statistics().idleWorkerCount, 0)
self.assertEqual(self.team.statistics().backloggedWorkCount, 1)
self.performAllOutstandingWork()
self.assertEqual(self.team.statistics().backloggedWorkCount, 0)
self.assertEqual(something.times, 4)
def test_exceptionInTask(self):
"""
When an exception is raised in a task passed to L{Team.do}, the
C{logException} given to the L{Team} at construction is invoked in the
exception context.
"""
self.team.do(lambda: 1/0)
self.performAllOutstandingWork()
self.assertEqual(len(self.failures), 1)
self.assertEqual(self.failures[0].type, ZeroDivisionError)
def test_quit(self):
"""
L{Team.quit} causes future invocations of L{Team.do} and L{Team.quit}
to raise L{AlreadyQuit}.
"""
self.team.quit()
self.assertRaises(AlreadyQuit, self.team.quit)
self.assertRaises(AlreadyQuit, self.team.do, list)
def test_quitQuits(self):
"""
L{Team.quit} causes all idle workers, as well as the coordinator
worker, to quit.
"""
for x in range(10):
self.team.do(list)
self.performAllOutstandingWork()
self.team.quit()
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 0)
self.assertRaises(AlreadyQuit, self.coordinator.quit)
def test_quitQuitsLaterWhenBusy(self):
"""
L{Team.quit} causes all busy workers to be quit once they've finished
the work they've been given.
"""
self.team.grow(10)
for x in range(5):
self.team.do(list)
self.coordinate()
self.team.quit()
self.coordinate()
self.assertEqual(len(self.allUnquitWorkers), 5)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 0)
self.assertRaises(AlreadyQuit, self.coordinator.quit)
def test_quitConcurrentWithWorkHappening(self):
"""
If work happens after L{Team.quit} sets its C{Quit} flag, but before
any other work takes place, the L{Team} should still exit gracefully.
"""
self.team.do(list)
originalSet = self.team._quit.set
def performWorkConcurrently():
originalSet()
self.performAllOutstandingWork()
self.team._quit.set = performWorkConcurrently
self.team.quit()
self.assertRaises(AlreadyQuit, self.team.quit)
self.assertRaises(AlreadyQuit, self.team.do, list)
def test_shrinkWhenBusy(self):
"""
L{Team.shrink} will wait for busy workers to finish being busy and then
quit them.
"""
for x in range(10):
self.team.do(list)
self.coordinate()
self.assertEqual(len(self.allUnquitWorkers), 10)
# There should be 10 busy workers at this point.
self.team.shrink(7)
self.performAllOutstandingWork()
self.assertEqual(len(self.allUnquitWorkers), 3)

View file

@ -0,0 +1,308 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted._threads._threadworker}.
"""
from __future__ import absolute_import, division, print_function
import gc
import weakref
from twisted.trial.unittest import SynchronousTestCase
from threading import ThreadError, local
from .. import ThreadWorker, LockWorker, AlreadyQuit
class FakeQueueEmpty(Exception):
"""
L{FakeQueue}'s C{get} has exhausted the queue.
"""
class WouldDeadlock(Exception):
"""
If this were a real lock, you'd be deadlocked because the lock would be
double-acquired.
"""
class FakeThread(object):
"""
A fake L{threading.Thread}.
@ivar target: A target function to run.
@type target: L{callable}
@ivar started: Has this thread been started?
@type started: L{bool}
"""
def __init__(self, target):
"""
Create a L{FakeThread} with a target.
"""
self.target = target
self.started = False
def start(self):
"""
Set the "started" flag.
"""
self.started = True
class FakeQueue(object):
"""
A fake L{Queue} implementing C{put} and C{get}.
@ivar items: A lit of items placed by C{put} but not yet retrieved by
C{get}.
@type items: L{list}
"""
def __init__(self):
"""
Create a L{FakeQueue}.
"""
self.items = []
def put(self, item):
"""
Put an item into the queue for later retrieval by L{FakeQueue.get}.
@param item: any object
"""
self.items.append(item)
def get(self):
"""
Get an item.
@return: an item previously put by C{put}.
"""
if not self.items:
raise FakeQueueEmpty()
return self.items.pop(0)
class FakeLock(object):
"""
A stand-in for L{threading.Lock}.
@ivar acquired: Whether this lock is presently acquired.
"""
def __init__(self):
"""
Create a lock in the un-acquired state.
"""
self.acquired = False
def acquire(self):
"""
Acquire the lock. Raise an exception if the lock is already acquired.
"""
if self.acquired:
raise WouldDeadlock()
self.acquired = True
def release(self):
"""
Release the lock. Raise an exception if the lock is not presently
acquired.
"""
if not self.acquired:
raise ThreadError()
self.acquired = False
class ThreadWorkerTests(SynchronousTestCase):
"""
Tests for L{ThreadWorker}.
"""
def setUp(self):
"""
Create a worker with fake threads.
"""
self.fakeThreads = []
self.fakeQueue = FakeQueue()
def startThread(target):
newThread = FakeThread(target=target)
newThread.start()
self.fakeThreads.append(newThread)
return newThread
self.worker = ThreadWorker(startThread, self.fakeQueue)
def test_startsThreadAndPerformsWork(self):
"""
L{ThreadWorker} calls its C{createThread} callable to create a thread,
its C{createQueue} callable to create a queue, and then the thread's
target pulls work from that queue.
"""
self.assertEqual(len(self.fakeThreads), 1)
self.assertEqual(self.fakeThreads[0].started, True)
def doIt():
doIt.done = True
doIt.done = False
self.worker.do(doIt)
self.assertEqual(doIt.done, False)
self.assertRaises(FakeQueueEmpty, self.fakeThreads[0].target)
self.assertEqual(doIt.done, True)
def test_quitPreventsFutureCalls(self):
"""
L{ThreadWorker.quit} causes future calls to L{ThreadWorker.do} and
L{ThreadWorker.quit} to raise L{AlreadyQuit}.
"""
self.worker.quit()
self.assertRaises(AlreadyQuit, self.worker.quit)
self.assertRaises(AlreadyQuit, self.worker.do, list)
class LockWorkerTests(SynchronousTestCase):
"""
Tests for L{LockWorker}.
"""
def test_fakeDeadlock(self):
"""
The L{FakeLock} test fixture will alert us if there's a potential
deadlock.
"""
lock = FakeLock()
lock.acquire()
self.assertRaises(WouldDeadlock, lock.acquire)
def test_fakeDoubleRelease(self):
"""
The L{FakeLock} test fixture will alert us if there's a potential
double-release.
"""
lock = FakeLock()
self.assertRaises(ThreadError, lock.release)
lock.acquire()
self.assertEqual(None, lock.release())
self.assertRaises(ThreadError, lock.release)
def test_doExecutesImmediatelyWithLock(self):
"""
L{LockWorker.do} immediately performs the work it's given, while the
lock is acquired.
"""
storage = local()
lock = FakeLock()
worker = LockWorker(lock, storage)
def work():
work.done = True
work.acquired = lock.acquired
work.done = False
worker.do(work)
self.assertEqual(work.done, True)
self.assertEqual(work.acquired, True)
self.assertEqual(lock.acquired, False)
def test_doUnwindsReentrancy(self):
"""
If L{LockWorker.do} is called recursively, it postpones the inner call
until the outer one is complete.
"""
lock = FakeLock()
worker = LockWorker(lock, local())
levels = []
acquired = []
def work():
work.level += 1
levels.append(work.level)
acquired.append(lock.acquired)
if len(levels) < 2:
worker.do(work)
work.level -= 1
work.level = 0
worker.do(work)
self.assertEqual(levels, [1, 1])
self.assertEqual(acquired, [True, True])
def test_quit(self):
"""
L{LockWorker.quit} frees the resources associated with its lock and
causes further calls to C{do} and C{quit} to fail.
"""
lock = FakeLock()
ref = weakref.ref(lock)
worker = LockWorker(lock, local())
lock = None
self.assertIsNot(ref(), None)
worker.quit()
gc.collect()
self.assertIs(ref(), None)
self.assertRaises(AlreadyQuit, worker.quit)
self.assertRaises(AlreadyQuit, worker.do, list)
def test_quitWhileWorking(self):
"""
If L{LockWorker.quit} is invoked during a call to L{LockWorker.do}, all
recursive work scheduled with L{LockWorker.do} will be completed and
the lock will be released.
"""
lock = FakeLock()
ref = weakref.ref(lock)
worker = LockWorker(lock, local())
def phase1():
worker.do(phase2)
worker.quit()
self.assertRaises(AlreadyQuit, worker.do, list)
phase1.complete = True
phase1.complete = False
def phase2():
phase2.complete = True
phase2.acquired = lock.acquired
phase2.complete = False
worker.do(phase1)
self.assertEqual(phase1.complete, True)
self.assertEqual(phase2.complete, True)
self.assertEqual(lock.acquired, False)
lock = None
gc.collect()
self.assertIs(ref(), None)
def test_quitWhileGettingLock(self):
"""
If L{LockWorker.do} is called concurrently with L{LockWorker.quit}, and
C{quit} wins the race before C{do} gets the lock attribute, then
L{AlreadyQuit} will be raised.
"""
class RacyLockWorker(LockWorker):
def _lock_get(self):
self.quit()
return self.__dict__['_lock']
def _lock_set(self, value):
self.__dict__['_lock'] = value
_lock = property(_lock_get, _lock_set)
worker = RacyLockWorker(FakeLock(), local())
self.assertRaises(AlreadyQuit, worker.do, list)

View file

@ -0,0 +1,11 @@
"""
Provides Twisted version information.
"""
# This file is auto-generated! Do not edit!
# Use `python -m incremental.update Twisted` to change this file.
from incremental import Version
__version__ = Version('Twisted', 20, 3, 0)
__all__ = ["__version__"]

View file

@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Configuration objects for Twisted Applications.
"""

View file

@ -0,0 +1,708 @@
# -*- test-case-name: twisted.test.test_application,twisted.test.test_twistd -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import absolute_import, division, print_function
import sys
import os
import pdb
import getpass
import traceback
import signal
import warnings
from operator import attrgetter
from twisted import copyright, plugin, logger
from twisted.application import service, reactors
from twisted.internet import defer
from twisted.persisted import sob
from twisted.python import runtime, log, usage, failure, util, logfile
from twisted.python._oldstyle import _oldStyle
from twisted.python.reflect import (qual, namedAny, namedModule)
from twisted.internet.interfaces import _ISupportsExitSignalCapturing
# Expose the new implementation of installReactor at the old location.
from twisted.application.reactors import installReactor
from twisted.application.reactors import NoSuchReactor
class _BasicProfiler(object):
"""
@ivar saveStats: if C{True}, save the stats information instead of the
human readable format
@type saveStats: C{bool}
@ivar profileOutput: the name of the file use to print profile data.
@type profileOutput: C{str}
"""
def __init__(self, profileOutput, saveStats):
self.profileOutput = profileOutput
self.saveStats = saveStats
def _reportImportError(self, module, e):
"""
Helper method to report an import error with a profile module. This
has to be explicit because some of these modules are removed by
distributions due to them being non-free.
"""
s = "Failed to import module %s: %s" % (module, e)
s += """
This is most likely caused by your operating system not including
the module due to it being non-free. Either do not use the option
--profile, or install the module; your operating system vendor
may provide it in a separate package.
"""
raise SystemExit(s)
class ProfileRunner(_BasicProfiler):
"""
Runner for the standard profile module.
"""
def run(self, reactor):
"""
Run reactor under the standard profiler.
"""
try:
import profile
except ImportError as e:
self._reportImportError("profile", e)
p = profile.Profile()
p.runcall(reactor.run)
if self.saveStats:
p.dump_stats(self.profileOutput)
else:
tmp, sys.stdout = sys.stdout, open(self.profileOutput, 'a')
try:
p.print_stats()
finally:
sys.stdout, tmp = tmp, sys.stdout
tmp.close()
class CProfileRunner(_BasicProfiler):
"""
Runner for the cProfile module.
"""
def run(self, reactor):
"""
Run reactor under the cProfile profiler.
"""
try:
import cProfile
import pstats
except ImportError as e:
self._reportImportError("cProfile", e)
p = cProfile.Profile()
p.runcall(reactor.run)
if self.saveStats:
p.dump_stats(self.profileOutput)
else:
with open(self.profileOutput, 'w') as stream:
s = pstats.Stats(p, stream=stream)
s.strip_dirs()
s.sort_stats(-1)
s.print_stats()
class AppProfiler(object):
"""
Class which selects a specific profile runner based on configuration
options.
@ivar profiler: the name of the selected profiler.
@type profiler: C{str}
"""
profilers = {"profile": ProfileRunner, "cprofile": CProfileRunner}
def __init__(self, options):
saveStats = options.get("savestats", False)
profileOutput = options.get("profile", None)
self.profiler = options.get("profiler", "cprofile").lower()
if self.profiler in self.profilers:
profiler = self.profilers[self.profiler](profileOutput, saveStats)
self.run = profiler.run
else:
raise SystemExit("Unsupported profiler name: %s" %
(self.profiler,))
class AppLogger(object):
"""
An L{AppLogger} attaches the configured log observer specified on the
commandline to a L{ServerOptions} object, a custom L{logger.ILogObserver},
or a legacy custom {log.ILogObserver}.
@ivar _logfilename: The name of the file to which to log, if other than the
default.
@type _logfilename: C{str}
@ivar _observerFactory: Callable object that will create a log observer, or
None.
@ivar _observer: log observer added at C{start} and removed at C{stop}.
@type _observer: a callable that implements L{logger.ILogObserver} or
L{log.ILogObserver}.
"""
_observer = None
def __init__(self, options):
"""
Initialize an L{AppLogger} with a L{ServerOptions}.
"""
self._logfilename = options.get("logfile", "")
self._observerFactory = options.get("logger") or None
def start(self, application):
"""
Initialize the global logging system for the given application.
If a custom logger was specified on the command line it will be used.
If not, and an L{logger.ILogObserver} or legacy L{log.ILogObserver}
component has been set on C{application}, then it will be used as the
log observer. Otherwise a log observer will be created based on the
command line options for built-in loggers (e.g. C{--logfile}).
@param application: The application on which to check for an
L{logger.ILogObserver} or legacy L{log.ILogObserver}.
@type application: L{twisted.python.components.Componentized}
"""
if self._observerFactory is not None:
observer = self._observerFactory()
else:
observer = application.getComponent(logger.ILogObserver, None)
if observer is None:
# If there's no new ILogObserver, try the legacy one
observer = application.getComponent(log.ILogObserver, None)
if observer is None:
observer = self._getLogObserver()
self._observer = observer
if logger.ILogObserver.providedBy(self._observer):
observers = [self._observer]
elif log.ILogObserver.providedBy(self._observer):
observers = [logger.LegacyLogObserverWrapper(self._observer)]
else:
warnings.warn(
("Passing a logger factory which makes log observers which do "
"not implement twisted.logger.ILogObserver or "
"twisted.python.log.ILogObserver to "
"twisted.application.app.AppLogger was deprecated in "
"Twisted 16.2. Please use a factory that produces "
"twisted.logger.ILogObserver (or the legacy "
"twisted.python.log.ILogObserver) implementing objects "
"instead."),
DeprecationWarning,
stacklevel=2)
observers = [logger.LegacyLogObserverWrapper(self._observer)]
logger.globalLogBeginner.beginLoggingTo(observers)
self._initialLog()
def _initialLog(self):
"""
Print twistd start log message.
"""
from twisted.internet import reactor
logger._loggerFor(self).info(
"twistd {version} ({exe} {pyVersion}) starting up.",
version=copyright.version, exe=sys.executable,
pyVersion=runtime.shortPythonVersion())
logger._loggerFor(self).info('reactor class: {reactor}.',
reactor=qual(reactor.__class__))
def _getLogObserver(self):
"""
Create a log observer to be added to the logging system before running
this application.
"""
if self._logfilename == '-' or not self._logfilename:
logFile = sys.stdout
else:
logFile = logfile.LogFile.fromFullPath(self._logfilename)
return logger.textFileLogObserver(logFile)
def stop(self):
"""
Remove all log observers previously set up by L{AppLogger.start}.
"""
logger._loggerFor(self).info("Server Shut Down.")
if self._observer is not None:
logger.globalLogPublisher.removeObserver(self._observer)
self._observer = None
def fixPdb():
def do_stop(self, arg):
self.clear_all_breaks()
self.set_continue()
from twisted.internet import reactor
reactor.callLater(0, reactor.stop)
return 1
def help_stop(self):
print("stop - Continue execution, then cleanly shutdown the twisted "
"reactor.")
def set_quit(self):
os._exit(0)
pdb.Pdb.set_quit = set_quit
pdb.Pdb.do_stop = do_stop
pdb.Pdb.help_stop = help_stop
def runReactorWithLogging(config, oldstdout, oldstderr, profiler=None,
reactor=None):
"""
Start the reactor, using profiling if specified by the configuration, and
log any error happening in the process.
@param config: configuration of the twistd application.
@type config: L{ServerOptions}
@param oldstdout: initial value of C{sys.stdout}.
@type oldstdout: C{file}
@param oldstderr: initial value of C{sys.stderr}.
@type oldstderr: C{file}
@param profiler: object used to run the reactor with profiling.
@type profiler: L{AppProfiler}
@param reactor: The reactor to use. If L{None}, the global reactor will
be used.
"""
if reactor is None:
from twisted.internet import reactor
try:
if config['profile']:
if profiler is not None:
profiler.run(reactor)
elif config['debug']:
sys.stdout = oldstdout
sys.stderr = oldstderr
if runtime.platformType == 'posix':
signal.signal(signal.SIGUSR2, lambda *args: pdb.set_trace())
signal.signal(signal.SIGINT, lambda *args: pdb.set_trace())
fixPdb()
pdb.runcall(reactor.run)
else:
reactor.run()
except:
close = False
if config['nodaemon']:
file = oldstdout
else:
file = open("TWISTD-CRASH.log", "a")
close = True
try:
traceback.print_exc(file=file)
file.flush()
finally:
if close:
file.close()
def getPassphrase(needed):
if needed:
return getpass.getpass('Passphrase: ')
else:
return None
def getSavePassphrase(needed):
if needed:
return util.getPassword("Encryption passphrase: ")
else:
return None
class ApplicationRunner(object):
"""
An object which helps running an application based on a config object.
Subclass me and implement preApplication and postApplication
methods. postApplication generally will want to run the reactor
after starting the application.
@ivar config: The config object, which provides a dict-like interface.
@ivar application: Available in postApplication, but not
preApplication. This is the application object.
@ivar profilerFactory: Factory for creating a profiler object, able to
profile the application if options are set accordingly.
@ivar profiler: Instance provided by C{profilerFactory}.
@ivar loggerFactory: Factory for creating object responsible for logging.
@ivar logger: Instance provided by C{loggerFactory}.
"""
profilerFactory = AppProfiler
loggerFactory = AppLogger
def __init__(self, config):
self.config = config
self.profiler = self.profilerFactory(config)
self.logger = self.loggerFactory(config)
def run(self):
"""
Run the application.
"""
self.preApplication()
self.application = self.createOrGetApplication()
self.logger.start(self.application)
self.postApplication()
self.logger.stop()
def startReactor(self, reactor, oldstdout, oldstderr):
"""
Run the reactor with the given configuration. Subclasses should
probably call this from C{postApplication}.
@see: L{runReactorWithLogging}
"""
if reactor is None:
from twisted.internet import reactor
runReactorWithLogging(
self.config, oldstdout, oldstderr, self.profiler, reactor)
if _ISupportsExitSignalCapturing.providedBy(reactor):
self._exitSignal = reactor._exitSignal
else:
self._exitSignal = None
def preApplication(self):
"""
Override in subclass.
This should set up any state necessary before loading and
running the Application.
"""
raise NotImplementedError()
def postApplication(self):
"""
Override in subclass.
This will be called after the application has been loaded (so
the C{application} attribute will be set). Generally this
should start the application and run the reactor.
"""
raise NotImplementedError()
def createOrGetApplication(self):
"""
Create or load an Application based on the parameters found in the
given L{ServerOptions} instance.
If a subcommand was used, the L{service.IServiceMaker} that it
represents will be used to construct a service to be added to
a newly-created Application.
Otherwise, an application will be loaded based on parameters in
the config.
"""
if self.config.subCommand:
# If a subcommand was given, it's our responsibility to create
# the application, instead of load it from a file.
# loadedPlugins is set up by the ServerOptions.subCommands
# property, which is iterated somewhere in the bowels of
# usage.Options.
plg = self.config.loadedPlugins[self.config.subCommand]
ser = plg.makeService(self.config.subOptions)
application = service.Application(plg.tapname)
ser.setServiceParent(application)
else:
passphrase = getPassphrase(self.config['encrypted'])
application = getApplication(self.config, passphrase)
return application
def getApplication(config, passphrase):
s = [(config[t], t)
for t in ['python', 'source', 'file'] if config[t]][0]
filename, style = s[0], {'file': 'pickle'}.get(s[1], s[1])
try:
log.msg("Loading %s..." % filename)
application = service.loadApplication(filename, style, passphrase)
log.msg("Loaded.")
except Exception as e:
s = "Failed to load application: %s" % e
if isinstance(e, KeyError) and e.args[0] == "application":
s += """
Could not find 'application' in the file. To use 'twistd -y', your .tac
file must create a suitable object (e.g., by calling service.Application())
and store it in a variable named 'application'. twistd loads your .tac file
and scans the global variables for one of this name.
Please read the 'Using Application' HOWTO for details.
"""
traceback.print_exc(file=log.logfile)
log.msg(s)
log.deferr()
sys.exit('\n' + s + '\n')
return application
def _reactorAction():
return usage.CompleteList([r.shortName for r in
reactors.getReactorTypes()])
@_oldStyle
class ReactorSelectionMixin:
"""
Provides options for selecting a reactor to install.
If a reactor is installed, the short name which was used to locate it is
saved as the value for the C{"reactor"} key.
"""
compData = usage.Completions(
optActions={"reactor": _reactorAction})
messageOutput = sys.stdout
_getReactorTypes = staticmethod(reactors.getReactorTypes)
def opt_help_reactors(self):
"""
Display a list of possibly available reactor names.
"""
rcts = sorted(self._getReactorTypes(), key=attrgetter('shortName'))
notWorkingReactors = ""
for r in rcts:
try:
namedModule(r.moduleName)
self.messageOutput.write(' %-4s\t%s\n' %
(r.shortName, r.description))
except ImportError as e:
notWorkingReactors += (' !%-4s\t%s (%s)\n' %
(r.shortName, r.description, e.args[0]))
if notWorkingReactors:
self.messageOutput.write('\n')
self.messageOutput.write(' reactors not available '
'on this platform:\n\n')
self.messageOutput.write(notWorkingReactors)
raise SystemExit(0)
def opt_reactor(self, shortName):
"""
Which reactor to use (see --help-reactors for a list of possibilities)
"""
# Actually actually actually install the reactor right at this very
# moment, before any other code (for example, a sub-command plugin)
# runs and accidentally imports and installs the default reactor.
#
# This could probably be improved somehow.
try:
installReactor(shortName)
except NoSuchReactor:
msg = ("The specified reactor does not exist: '%s'.\n"
"See the list of available reactors with "
"--help-reactors" % (shortName,))
raise usage.UsageError(msg)
except Exception as e:
msg = ("The specified reactor cannot be used, failed with error: "
"%s.\nSee the list of available reactors with "
"--help-reactors" % (e,))
raise usage.UsageError(msg)
else:
self["reactor"] = shortName
opt_r = opt_reactor
class ServerOptions(usage.Options, ReactorSelectionMixin):
longdesc = ("twistd reads a twisted.application.service.Application out "
"of a file and runs it.")
optFlags = [['savestats', None,
"save the Stats object rather than the text output of "
"the profiler."],
['no_save', 'o', "do not save state on shutdown"],
['encrypted', 'e',
"The specified tap/aos file is encrypted."]]
optParameters = [['logfile', 'l', None,
"log to a specified file, - for stdout"],
['logger', None, None,
"A fully-qualified name to a log observer factory to "
"use for the initial log observer. Takes precedence "
"over --logfile and --syslog (when available)."],
['profile', 'p', None,
"Run in profile mode, dumping results to specified "
"file."],
['profiler', None, "cprofile",
"Name of the profiler to use (%s)." %
", ".join(AppProfiler.profilers)],
['file', 'f', 'twistd.tap',
"read the given .tap file"],
['python', 'y', None,
"read an application from within a Python file "
"(implies -o)"],
['source', 's', None,
"Read an application from a .tas file (AOT format)."],
['rundir', 'd', '.',
'Change to a supplied directory before running']]
compData = usage.Completions(
mutuallyExclusive=[("file", "python", "source")],
optActions={"file": usage.CompleteFiles("*.tap"),
"python": usage.CompleteFiles("*.(tac|py)"),
"source": usage.CompleteFiles("*.tas"),
"rundir": usage.CompleteDirs()}
)
_getPlugins = staticmethod(plugin.getPlugins)
def __init__(self, *a, **kw):
self['debug'] = False
if 'stdout' in kw:
self.stdout = kw['stdout']
else:
self.stdout = sys.stdout
usage.Options.__init__(self)
def opt_debug(self):
"""
Run the application in the Python Debugger (implies nodaemon),
sending SIGUSR2 will drop into debugger
"""
defer.setDebugging(True)
failure.startDebugMode()
self['debug'] = True
opt_b = opt_debug
def opt_spew(self):
"""
Print an insanely verbose log of everything that happens.
Useful when debugging freezes or locks in complex code.
"""
sys.settrace(util.spewer)
try:
import threading
except ImportError:
return
threading.settrace(util.spewer)
def parseOptions(self, options=None):
if options is None:
options = sys.argv[1:] or ["--help"]
usage.Options.parseOptions(self, options)
def postOptions(self):
if self.subCommand or self['python']:
self['no_save'] = True
if self['logger'] is not None:
try:
self['logger'] = namedAny(self['logger'])
except Exception as e:
raise usage.UsageError("Logger '%s' could not be imported: %s"
% (self['logger'], e))
def subCommands(self):
plugins = self._getPlugins(service.IServiceMaker)
self.loadedPlugins = {}
for plug in sorted(plugins, key=attrgetter('tapname')):
self.loadedPlugins[plug.tapname] = plug
yield (plug.tapname,
None,
# Avoid resolving the options attribute right away, in case
# it's a property with a non-trivial getter (eg, one which
# imports modules).
lambda plug=plug: plug.options(),
plug.description)
subCommands = property(subCommands)
def run(runApp, ServerOptions):
config = ServerOptions()
try:
config.parseOptions()
except usage.error as ue:
print(config)
print("%s: %s" % (sys.argv[0], ue))
else:
runApp(config)
def convertStyle(filein, typein, passphrase, fileout, typeout, encrypt):
application = service.loadApplication(filein, typein, passphrase)
sob.IPersistable(application).setStyle(typeout)
passphrase = getSavePassphrase(encrypt)
if passphrase:
fileout = None
sob.IPersistable(application).save(filename=fileout, passphrase=passphrase)
def startApplication(application, save):
from twisted.internet import reactor
service.IService(application).startService()
if save:
p = sob.IPersistable(application)
reactor.addSystemEventTrigger('after', 'shutdown', p.save, 'shutdown')
reactor.addSystemEventTrigger('before', 'shutdown',
service.IService(application).stopService)
def _exitWithSignal(sig):
"""
Force the application to terminate with the specified signal by replacing
the signal handler with the default and sending the signal to ourselves.
@param sig: Signal to use to terminate the process with C{os.kill}.
@type sig: C{int}
"""
signal.signal(sig, signal.SIG_DFL)
os.kill(os.getpid(), sig)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,85 @@
# -*- test-case-name: twisted.test.test_application -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Plugin-based system for enumerating available reactors and installing one of
them.
"""
from __future__ import absolute_import, division
from zope.interface import Interface, Attribute, implementer
from twisted.plugin import IPlugin, getPlugins
from twisted.python.reflect import namedAny
class IReactorInstaller(Interface):
"""
Definition of a reactor which can probably be installed.
"""
shortName = Attribute("""
A brief string giving the user-facing name of this reactor.
""")
description = Attribute("""
A longer string giving a user-facing description of this reactor.
""")
def install():
"""
Install this reactor.
"""
# TODO - A method which provides a best-guess as to whether this reactor
# can actually be used in the execution environment.
class NoSuchReactor(KeyError):
"""
Raised when an attempt is made to install a reactor which cannot be found.
"""
@implementer(IPlugin, IReactorInstaller)
class Reactor(object):
"""
@ivar moduleName: The fully-qualified Python name of the module of which
the install callable is an attribute.
"""
def __init__(self, shortName, moduleName, description):
self.shortName = shortName
self.moduleName = moduleName
self.description = description
def install(self):
namedAny(self.moduleName).install()
def getReactorTypes():
"""
Return an iterator of L{IReactorInstaller} plugins.
"""
return getPlugins(IReactorInstaller)
def installReactor(shortName):
"""
Install the reactor with the given C{shortName} attribute.
@raise NoSuchReactor: If no reactor is found with a matching C{shortName}.
@raise: anything that the specified reactor can raise when installed.
"""
for installer in getReactorTypes():
if installer.shortName == shortName:
installer.install()
from twisted.internet import reactor
return reactor
raise NoSuchReactor(shortName)

View file

@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.runner.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Facilities for running a Twisted application.
"""

View file

@ -0,0 +1,138 @@
# -*- test-case-name: twisted.application.runner.test.test_exit -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
System exit support.
"""
from sys import stdout, stderr, exit as sysexit
from constantly import Values, ValueConstant
def exit(status, message=None):
"""
Exit the python interpreter with the given status and an optional message.
@param status: An exit status.
@type status: L{int} or L{ValueConstant} from L{ExitStatus}.
@param message: An options message to print.
@type status: L{str}
"""
if isinstance(status, ValueConstant):
code = status.value
else:
code = int(status)
if message:
if code == 0:
out = stdout
else:
out = stderr
out.write(message)
out.write("\n")
sysexit(code)
try:
import posix as Status
except ImportError:
class Status(object):
"""
Object to hang C{EX_*} values off of as a substitute for L{posix}.
"""
EX__BASE = 64
EX_OK = 0
EX_USAGE = EX__BASE
EX_DATAERR = EX__BASE + 1
EX_NOINPUT = EX__BASE + 2
EX_NOUSER = EX__BASE + 3
EX_NOHOST = EX__BASE + 4
EX_UNAVAILABLE = EX__BASE + 5
EX_SOFTWARE = EX__BASE + 6
EX_OSERR = EX__BASE + 7
EX_OSFILE = EX__BASE + 8
EX_CANTCREAT = EX__BASE + 9
EX_IOERR = EX__BASE + 10
EX_TEMPFAIL = EX__BASE + 11
EX_PROTOCOL = EX__BASE + 12
EX_NOPERM = EX__BASE + 13
EX_CONFIG = EX__BASE + 14
class ExitStatus(Values):
"""
Standard exit status codes for system programs.
@cvar EX_OK: Successful termination.
@type EX_OK: L{ValueConstant}
@cvar EX_USAGE: Command line usage error.
@type EX_USAGE: L{ValueConstant}
@cvar EX_DATAERR: Data format error.
@type EX_DATAERR: L{ValueConstant}
@cvar EX_NOINPUT: Cannot open input.
@type EX_NOINPUT: L{ValueConstant}
@cvar EX_NOUSER: Addressee unknown.
@type EX_NOUSER: L{ValueConstant}
@cvar EX_NOHOST: Host name unknown.
@type EX_NOHOST: L{ValueConstant}
@cvar EX_UNAVAILABLE: Service unavailable.
@type EX_UNAVAILABLE: L{ValueConstant}
@cvar EX_SOFTWARE: Internal software error.
@type EX_SOFTWARE: L{ValueConstant}
@cvar EX_OSERR: System error (e.g., can't fork).
@type EX_OSERR: L{ValueConstant}
@cvar EX_OSFILE: Critical OS file missing.
@type EX_OSFILE: L{ValueConstant}
@cvar EX_CANTCREAT: Can't create (user) output file.
@type EX_CANTCREAT: L{ValueConstant}
@cvar EX_IOERR: Input/output error.
@type EX_IOERR: L{ValueConstant}
@cvar EX_TEMPFAIL: Temporary failure; the user is invited to retry.
@type EX_TEMPFAIL: L{ValueConstant}
@cvar EX_PROTOCOL: Remote error in protocol.
@type EX_PROTOCOL: L{ValueConstant}
@cvar EX_NOPERM: Permission denied.
@type EX_NOPERM: L{ValueConstant}
@cvar EX_CONFIG: Configuration error.
@type EX_CONFIG: L{ValueConstant}
"""
EX_OK = ValueConstant(Status.EX_OK)
EX_USAGE = ValueConstant(Status.EX_USAGE)
EX_DATAERR = ValueConstant(Status.EX_DATAERR)
EX_NOINPUT = ValueConstant(Status.EX_NOINPUT)
EX_NOUSER = ValueConstant(Status.EX_NOUSER)
EX_NOHOST = ValueConstant(Status.EX_NOHOST)
EX_UNAVAILABLE = ValueConstant(Status.EX_UNAVAILABLE)
EX_SOFTWARE = ValueConstant(Status.EX_SOFTWARE)
EX_OSERR = ValueConstant(Status.EX_OSERR)
EX_OSFILE = ValueConstant(Status.EX_OSFILE)
EX_CANTCREAT = ValueConstant(Status.EX_CANTCREAT)
EX_IOERR = ValueConstant(Status.EX_IOERR)
EX_TEMPFAIL = ValueConstant(Status.EX_TEMPFAIL)
EX_PROTOCOL = ValueConstant(Status.EX_PROTOCOL)
EX_NOPERM = ValueConstant(Status.EX_NOPERM)
EX_CONFIG = ValueConstant(Status.EX_CONFIG)

View file

@ -0,0 +1,303 @@
# -*- test-case-name: twisted.application.runner.test.test_pidfile -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
PID file.
"""
import errno
from os import getpid, kill, name as SYSTEM_NAME
from zope.interface import Interface, implementer
from twisted.logger import Logger
class IPIDFile(Interface):
"""
Manages a file that remembers a process ID.
"""
def read():
"""
Read the process ID stored in this PID file.
@return: The contained process ID.
@rtype: L{int}
@raise NoPIDFound: If this PID file does not exist.
@raise EnvironmentError: If this PID file cannot be read.
@raise ValueError: If this PID file's content is invalid.
"""
def writeRunningPID():
"""
Store the PID of the current process in this PID file.
@raise EnvironmentError: If this PID file cannot be written.
"""
def remove():
"""
Remove this PID file.
@raise EnvironmentError: If this PID file cannot be removed.
"""
def isRunning():
"""
Determine whether there is a running process corresponding to the PID
in this PID file.
@return: True if this PID file contains a PID and a process with that
PID is currently running; false otherwise.
@rtype: L{bool}
@raise EnvironmentError: If this PID file cannot be read.
@raise InvalidPIDFileError: If this PID file's content is invalid.
@raise StalePIDFileError: If this PID file's content refers to a PID
for which there is no corresponding running process.
"""
def __enter__():
"""
Enter a context using this PIDFile.
Writes the PID file with the PID of the running process.
@raise AlreadyRunningError: A process corresponding to the PID in this
PID file is already running.
"""
def __exit__(excType, excValue, traceback):
"""
Exit a context using this PIDFile.
Removes the PID file.
"""
@implementer(IPIDFile)
class PIDFile(object):
"""
Concrete implementation of L{IPIDFile} based on C{IFilePath}.
This implementation is presently not supported on non-POSIX platforms.
Specifically, calling L{PIDFile.isRunning} will raise
L{NotImplementedError}.
"""
_log = Logger()
@staticmethod
def _format(pid):
"""
Format a PID file's content.
@param pid: A process ID.
@type pid: int
@return: Formatted PID file contents.
@rtype: L{bytes}
"""
return u"{}\n".format(int(pid)).encode("utf-8")
def __init__(self, filePath):
"""
@param filePath: The path to the PID file on disk.
@type filePath: L{IFilePath}
"""
self.filePath = filePath
def read(self):
pidString = b""
try:
with self.filePath.open() as fh:
for pidString in fh:
break
except OSError as e:
if e.errno == errno.ENOENT: # No such file
raise NoPIDFound("PID file does not exist")
raise
try:
return int(pidString)
except ValueError:
raise InvalidPIDFileError(
"non-integer PID value in PID file: {!r}".format(pidString)
)
def _write(self, pid):
"""
Store a PID in this PID file.
@param pid: A PID to store.
@type pid: L{int}
@raise EnvironmentError: If this PID file cannot be written.
"""
self.filePath.setContent(self._format(pid=pid))
def writeRunningPID(self):
self._write(getpid())
def remove(self):
self.filePath.remove()
def isRunning(self):
try:
pid = self.read()
except NoPIDFound:
return False
if SYSTEM_NAME == "posix":
return self._pidIsRunningPOSIX(pid)
else:
raise NotImplementedError(
"isRunning is not implemented on {}".format(SYSTEM_NAME)
)
@staticmethod
def _pidIsRunningPOSIX(pid):
"""
POSIX implementation for running process check.
Determine whether there is a running process corresponding to the given
PID.
@return: True if the given PID is currently running; false otherwise.
@rtype: L{bool}
@raise EnvironmentError: If this PID file cannot be read.
@raise InvalidPIDFileError: If this PID file's content is invalid.
@raise StalePIDFileError: If this PID file's content refers to a PID
for which there is no corresponding running process.
"""
try:
kill(pid, 0)
except OSError as e:
if e.errno == errno.ESRCH: # No such process
raise StalePIDFileError(
"PID file refers to non-existing process"
)
elif e.errno == errno.EPERM: # Not permitted to kill
return True
else:
raise
else:
return True
def __enter__(self):
try:
if self.isRunning():
raise AlreadyRunningError()
except StalePIDFileError:
self._log.info("Replacing stale PID file: {log_source}")
self.writeRunningPID()
return self
def __exit__(self, excType, excValue, traceback):
self.remove()
@implementer(IPIDFile)
class NonePIDFile(object):
"""
PID file implementation that does nothing.
This is meant to be used as a "active None" object in place of a PID file
when no PID file is desired.
"""
def __init__(self):
pass
def read(self):
raise NoPIDFound("PID file does not exist")
def _write(self, pid):
"""
Store a PID in this PID file.
@param pid: A PID to store.
@type pid: L{int}
@raise EnvironmentError: If this PID file cannot be written.
@note: This implementation always raises an L{EnvironmentError}.
"""
raise OSError(errno.EPERM, "Operation not permitted")
def writeRunningPID(self):
self._write(0)
def remove(self):
raise OSError(errno.ENOENT, "No such file or directory")
def isRunning(self):
return False
def __enter__(self):
return self
def __exit__(self, excType, excValue, traceback):
pass
nonePIDFile = NonePIDFile()
class AlreadyRunningError(Exception):
"""
Process is already running.
"""
class InvalidPIDFileError(Exception):
"""
PID file contents are invalid.
"""
class StalePIDFileError(Exception):
"""
PID file contents are valid, but there is no process with the referenced
PID.
"""
class NoPIDFound(Exception):
"""
No PID found in PID file.
"""

View file

@ -0,0 +1,185 @@
# -*- test-case-name: twisted.application.runner.test.test_runner -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted application runner.
"""
from sys import stderr
from signal import SIGTERM
from os import kill
from attr import attrib, attrs, Factory
from twisted.logger import (
globalLogBeginner, textFileLogObserver,
FilteringLogObserver, LogLevelFilterPredicate,
LogLevel, Logger,
)
from ._exit import exit, ExitStatus
from ._pidfile import nonePIDFile, AlreadyRunningError, InvalidPIDFileError
@attrs(frozen=True)
class Runner(object):
"""
Twisted application runner.
@cvar _log: The logger attached to this class.
@type _log: L{Logger}
@ivar _reactor: The reactor to start and run the application in.
@type _reactor: L{IReactorCore}
@ivar _pidFile: The file to store the running process ID in.
@type _pidFile: L{IPIDFile}
@ivar _kill: Whether this runner should kill an existing running
instance of the application.
@type _kill: L{bool}
@ivar _defaultLogLevel: The default log level to start the logging
system with.
@type _defaultLogLevel: L{constantly.NamedConstant} from L{LogLevel}
@ivar _logFile: A file stream to write logging output to.
@type _logFile: writable file-like object
@ivar _fileLogObserverFactory: A factory for the file log observer to
use when starting the logging system.
@type _pidFile: callable that takes a single writable file-like object
argument and returns a L{twisted.logger.FileLogObserver}
@ivar _whenRunning: Hook to call after the reactor is running;
this is where the application code that relies on the reactor gets
called.
@type _whenRunning: callable that takes the keyword arguments specified
by C{whenRunningArguments}
@ivar _whenRunningArguments: Keyword arguments to pass to
C{whenRunning} when it is called.
@type _whenRunningArguments: L{dict}
@ivar _reactorExited: Hook to call after the reactor exits.
@type _reactorExited: callable that takes the keyword arguments
specified by C{reactorExitedArguments}
@ivar _reactorExitedArguments: Keyword arguments to pass to
C{reactorExited} when it is called.
@type _reactorExitedArguments: L{dict}
"""
_log = Logger()
_reactor = attrib()
_pidFile = attrib(default=nonePIDFile)
_kill = attrib(default=False)
_defaultLogLevel = attrib(default=LogLevel.info)
_logFile = attrib(default=stderr)
_fileLogObserverFactory = attrib(default=textFileLogObserver)
_whenRunning = attrib(default=lambda **_: None)
_whenRunningArguments = attrib(default=Factory(dict))
_reactorExited = attrib(default=lambda **_: None)
_reactorExitedArguments = attrib(default=Factory(dict))
def run(self):
"""
Run this command.
"""
pidFile = self._pidFile
self.killIfRequested()
try:
with pidFile:
self.startLogging()
self.startReactor()
self.reactorExited()
except AlreadyRunningError:
exit(ExitStatus.EX_CONFIG, "Already running.")
return # When testing, patched exit doesn't exit
def killIfRequested(self):
"""
If C{self._kill} is true, attempt to kill a running instance of the
application.
"""
pidFile = self._pidFile
if self._kill:
if pidFile is nonePIDFile:
exit(ExitStatus.EX_USAGE, "No PID file specified.")
return # When testing, patched exit doesn't exit
try:
pid = pidFile.read()
except EnvironmentError:
exit(ExitStatus.EX_IOERR, "Unable to read PID file.")
return # When testing, patched exit doesn't exit
except InvalidPIDFileError:
exit(ExitStatus.EX_DATAERR, "Invalid PID file.")
return # When testing, patched exit doesn't exit
self.startLogging()
self._log.info("Terminating process: {pid}", pid=pid)
kill(pid, SIGTERM)
exit(ExitStatus.EX_OK)
return # When testing, patched exit doesn't exit
def startLogging(self):
"""
Start the L{twisted.logger} logging system.
"""
logFile = self._logFile
fileLogObserverFactory = self._fileLogObserverFactory
fileLogObserver = fileLogObserverFactory(logFile)
logLevelPredicate = LogLevelFilterPredicate(
defaultLogLevel=self._defaultLogLevel
)
filteringObserver = FilteringLogObserver(
fileLogObserver, [logLevelPredicate]
)
globalLogBeginner.beginLoggingTo([filteringObserver])
def startReactor(self):
"""
Register C{self._whenRunning} with the reactor so that it is called
once the reactor is running, then start the reactor.
"""
self._reactor.callWhenRunning(self.whenRunning)
self._log.info("Starting reactor...")
self._reactor.run()
def whenRunning(self):
"""
Call C{self._whenRunning} with C{self._whenRunningArguments}.
@note: This method is called after the reactor starts running.
"""
self._whenRunning(**self._whenRunningArguments)
def reactorExited(self):
"""
Call C{self._reactorExited} with C{self._reactorExitedArguments}.
@note: This method is called after the reactor exits.
"""
self._reactorExited(**self._reactorExitedArguments)

View file

@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.runner.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner}.
"""

View file

@ -0,0 +1,104 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner._exit}.
"""
from twisted.python.compat import NativeStringIO
from ...runner import _exit
from .._exit import exit, ExitStatus
import twisted.trial.unittest
class ExitTests(twisted.trial.unittest.TestCase):
"""
Tests for L{exit}.
"""
def setUp(self):
self.exit = DummyExit()
self.patch(_exit, "sysexit", self.exit)
def test_exitStatusInt(self):
"""
L{exit} given an L{int} status code will pass it to L{sys.exit}.
"""
status = 1234
exit(status)
self.assertEqual(self.exit.arg, status)
def test_exitStatusStringNotInt(self):
"""
L{exit} given a L{str} status code that isn't a string integer raises
L{ValueError}.
"""
self.assertRaises(ValueError, exit, "foo")
def test_exitStatusStringInt(self):
"""
L{exit} given a L{str} status code that is a string integer passes the
corresponding L{int} to L{sys.exit}.
"""
exit("1234")
self.assertEqual(self.exit.arg, 1234)
def test_exitConstant(self):
"""
L{exit} given a L{ValueConstant} status code passes the corresponding
value to L{sys.exit}.
"""
status = ExitStatus.EX_CONFIG
exit(status)
self.assertEqual(self.exit.arg, status.value)
def test_exitMessageZero(self):
"""
L{exit} given a status code of zero (C{0}) writes the given message to
standard output.
"""
out = NativeStringIO()
self.patch(_exit, "stdout", out)
message = "Hello, world."
exit(0, message)
self.assertEqual(out.getvalue(), message + "\n")
def test_exitMessageNonZero(self):
"""
L{exit} given a non-zero status code writes the given message to
standard error.
"""
out = NativeStringIO()
self.patch(_exit, "stderr", out)
message = "Hello, world."
exit(64, message)
self.assertEqual(out.getvalue(), message + "\n")
class DummyExit(object):
"""
Stub for L{sys.exit} that remembers whether it's been called and, if it
has, what argument it was given.
"""
def __init__(self):
self.exited = False
def __call__(self, arg=None):
assert not self.exited
self.arg = arg
self.exited = True

View file

@ -0,0 +1,476 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner._pidfile}.
"""
from functools import wraps
import errno
from os import getpid, name as SYSTEM_NAME
from io import BytesIO
from zope.interface import implementer
from zope.interface.verify import verifyObject
from twisted.python.filepath import IFilePath
from twisted.python.runtime import platform
from ...runner import _pidfile
from .._pidfile import (
IPIDFile, PIDFile, NonePIDFile,
AlreadyRunningError, InvalidPIDFileError, StalePIDFileError,
NoPIDFound,
)
import twisted.trial.unittest
from twisted.trial.unittest import SkipTest
def ifPlatformSupported(f):
"""
Decorator for tests that are not expected to work on all platforms.
Calling L{PIDFile.isRunning} currently raises L{NotImplementedError} on
non-POSIX platforms.
On an unsupported platform, we expect to see any test that calls
L{PIDFile.isRunning} to raise either L{NotImplementedError}, L{SkipTest},
or C{self.failureException}.
(C{self.failureException} may occur in a test that checks for a specific
exception but it gets NotImplementedError instead.)
@param f: The test method to decorate.
@type f: method
@return: The wrapped callable.
"""
@wraps(f)
def wrapper(self, *args, **kwargs):
supported = platform.getType() == "posix"
if supported:
return f(self, *args, **kwargs)
else:
e = self.assertRaises(
(NotImplementedError, SkipTest, self.failureException),
f, self, *args, **kwargs
)
if isinstance(e, NotImplementedError):
self.assertTrue(
str(e).startswith("isRunning is not implemented on ")
)
return wrapper
class PIDFileTests(twisted.trial.unittest.TestCase):
"""
Tests for L{PIDFile}.
"""
def test_interface(self):
"""
L{PIDFile} conforms to L{IPIDFile}.
"""
pidFile = PIDFile(DummyFilePath())
verifyObject(IPIDFile, pidFile)
def test_formatWithPID(self):
"""
L{PIDFile._format} returns the expected format when given a PID.
"""
self.assertEqual(PIDFile._format(pid=1337), b"1337\n")
def test_readWithPID(self):
"""
L{PIDFile.read} returns the PID from the given file path.
"""
pid = 1337
pidFile = PIDFile(DummyFilePath(PIDFile._format(pid=pid)))
self.assertEqual(pid, pidFile.read())
def test_readEmptyPID(self):
"""
L{PIDFile.read} raises L{InvalidPIDFileError} when given an empty file
path.
"""
pidValue = b""
pidFile = PIDFile(DummyFilePath(b""))
e = self.assertRaises(InvalidPIDFileError, pidFile.read)
self.assertEqual(
str(e),
"non-integer PID value in PID file: {!r}".format(pidValue)
)
def test_readWithBogusPID(self):
"""
L{PIDFile.read} raises L{InvalidPIDFileError} when given an empty file
path.
"""
pidValue = b"$foo!"
pidFile = PIDFile(DummyFilePath(pidValue))
e = self.assertRaises(InvalidPIDFileError, pidFile.read)
self.assertEqual(
str(e),
"non-integer PID value in PID file: {!r}".format(pidValue)
)
def test_readDoesntExist(self):
"""
L{PIDFile.read} raises L{NoPIDFound} when given a non-existing file
path.
"""
pidFile = PIDFile(DummyFilePath())
e = self.assertRaises(NoPIDFound, pidFile.read)
self.assertEqual(str(e), "PID file does not exist")
def test_readOpenRaisesOSErrorNotENOENT(self):
"""
L{PIDFile.read} re-raises L{OSError} if the associated C{errno} is
anything other than L{errno.ENOENT}.
"""
def oops(mode="r"):
raise OSError(errno.EIO, "I/O error")
self.patch(DummyFilePath, "open", oops)
pidFile = PIDFile(DummyFilePath())
error = self.assertRaises(OSError, pidFile.read)
self.assertEqual(error.errno, errno.EIO)
def test_writePID(self):
"""
L{PIDFile._write} stores the given PID.
"""
pid = 1995
pidFile = PIDFile(DummyFilePath())
pidFile._write(pid)
self.assertEqual(pidFile.read(), pid)
def test_writePIDInvalid(self):
"""
L{PIDFile._write} raises L{ValueError} when given an invalid PID.
"""
pidFile = PIDFile(DummyFilePath())
self.assertRaises(ValueError, pidFile._write, u"burp")
def test_writeRunningPID(self):
"""
L{PIDFile.writeRunningPID} stores the PID for the current process.
"""
pidFile = PIDFile(DummyFilePath())
pidFile.writeRunningPID()
self.assertEqual(pidFile.read(), getpid())
def test_remove(self):
"""
L{PIDFile.remove} removes the PID file.
"""
pidFile = PIDFile(DummyFilePath(b""))
self.assertTrue(pidFile.filePath.exists())
pidFile.remove()
self.assertFalse(pidFile.filePath.exists())
@ifPlatformSupported
def test_isRunningDoesExist(self):
"""
L{PIDFile.isRunning} returns true for a process that does exist.
"""
pidFile = PIDFile(DummyFilePath())
pidFile._write(1337)
def kill(pid, signal):
return # Don't actually kill anything
self.patch(_pidfile, "kill", kill)
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningThis(self):
"""
L{PIDFile.isRunning} returns true for this process (which is running).
@note: This differs from L{PIDFileTests.test_isRunningDoesExist} in
that it actually invokes the C{kill} system call, which is useful for
testing of our chosen method for probing the existence of a process.
"""
pidFile = PIDFile(DummyFilePath())
pidFile.writeRunningPID()
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningDoesNotExist(self):
"""
L{PIDFile.isRunning} raises L{StalePIDFileError} for a process that
does not exist (errno=ESRCH).
"""
pidFile = PIDFile(DummyFilePath())
pidFile._write(1337)
def kill(pid, signal):
raise OSError(errno.ESRCH, "No such process")
self.patch(_pidfile, "kill", kill)
self.assertRaises(StalePIDFileError, pidFile.isRunning)
@ifPlatformSupported
def test_isRunningNotAllowed(self):
"""
L{PIDFile.isRunning} returns true for a process that we are not allowed
to kill (errno=EPERM).
"""
pidFile = PIDFile(DummyFilePath())
pidFile._write(1337)
def kill(pid, signal):
raise OSError(errno.EPERM, "Operation not permitted")
self.patch(_pidfile, "kill", kill)
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningInit(self):
"""
L{PIDFile.isRunning} returns true for a process that we are not allowed
to kill (errno=EPERM).
@note: This differs from L{PIDFileTests.test_isRunningNotAllowed} in
that it actually invokes the C{kill} system call, which is useful for
testing of our chosen method for probing the existence of a process
that we are not allowed to kill.
@note: In this case, we try killing C{init}, which is process #1 on
POSIX systems, so this test is not portable. C{init} should always be
running and should not be killable by non-root users.
"""
if SYSTEM_NAME != "posix":
raise SkipTest("This test assumes POSIX")
pidFile = PIDFile(DummyFilePath())
pidFile._write(1) # PID 1 is init on POSIX systems
self.assertTrue(pidFile.isRunning())
@ifPlatformSupported
def test_isRunningUnknownErrno(self):
"""
L{PIDFile.isRunning} re-raises L{OSError} if the attached C{errno}
value from L{os.kill} is not an expected one.
"""
pidFile = PIDFile(DummyFilePath())
pidFile.writeRunningPID()
def kill(pid, signal):
raise OSError(errno.EEXIST, "File exists")
self.patch(_pidfile, "kill", kill)
self.assertRaises(OSError, pidFile.isRunning)
def test_isRunningNoPIDFile(self):
"""
L{PIDFile.isRunning} returns false if the PID file doesn't exist.
"""
pidFile = PIDFile(DummyFilePath())
self.assertFalse(pidFile.isRunning())
def test_contextManager(self):
"""
When used as a context manager, a L{PIDFile} will store the current pid
on entry, then removes the PID file on exit.
"""
pidFile = PIDFile(DummyFilePath())
self.assertFalse(pidFile.filePath.exists())
with pidFile:
self.assertTrue(pidFile.filePath.exists())
self.assertEqual(pidFile.read(), getpid())
self.assertFalse(pidFile.filePath.exists())
@ifPlatformSupported
def test_contextManagerDoesntExist(self):
"""
When used as a context manager, a L{PIDFile} will replace the
underlying PIDFile rather than raising L{AlreadyRunningError} if the
contained PID file exists but refers to a non-running PID.
"""
pidFile = PIDFile(DummyFilePath())
pidFile._write(1337)
def kill(pid, signal):
raise OSError(errno.ESRCH, "No such process")
self.patch(_pidfile, "kill", kill)
e = self.assertRaises(StalePIDFileError, pidFile.isRunning)
self.assertEqual(str(e), "PID file refers to non-existing process")
with pidFile:
self.assertEqual(pidFile.read(), getpid())
@ifPlatformSupported
def test_contextManagerAlreadyRunning(self):
"""
When used as a context manager, a L{PIDFile} will raise
L{AlreadyRunningError} if the there is already a running process with
the contained PID.
"""
pidFile = PIDFile(DummyFilePath())
pidFile._write(1337)
def kill(pid, signal):
return # Don't actually kill anything
self.patch(_pidfile, "kill", kill)
self.assertTrue(pidFile.isRunning())
self.assertRaises(AlreadyRunningError, pidFile.__enter__)
class NonePIDFileTests(twisted.trial.unittest.TestCase):
"""
Tests for L{NonePIDFile}.
"""
def test_interface(self):
"""
L{NonePIDFile} conforms to L{IPIDFile}.
"""
pidFile = NonePIDFile()
verifyObject(IPIDFile, pidFile)
def test_read(self):
"""
L{NonePIDFile.read} raises L{NoPIDFound}.
"""
pidFile = NonePIDFile()
e = self.assertRaises(NoPIDFound, pidFile.read)
self.assertEqual(str(e), "PID file does not exist")
def test_write(self):
"""
L{NonePIDFile._write} raises L{OSError} with an errno of L{errno.EPERM}.
"""
pidFile = NonePIDFile()
error = self.assertRaises(OSError, pidFile._write, 0)
self.assertEqual(error.errno, errno.EPERM)
def test_writeRunningPID(self):
"""
L{NonePIDFile.writeRunningPID} raises L{OSError} with an errno of
L{errno.EPERM}.
"""
pidFile = NonePIDFile()
error = self.assertRaises(OSError, pidFile.writeRunningPID)
self.assertEqual(error.errno, errno.EPERM)
def test_remove(self):
"""
L{NonePIDFile.remove} raises L{OSError} with an errno of L{errno.EPERM}.
"""
pidFile = NonePIDFile()
error = self.assertRaises(OSError, pidFile.remove)
self.assertEqual(error.errno, errno.ENOENT)
def test_isRunning(self):
"""
L{NonePIDFile.isRunning} returns L{False}.
"""
pidFile = NonePIDFile()
self.assertEqual(pidFile.isRunning(), False)
def test_contextManager(self):
"""
When used as a context manager, a L{NonePIDFile} doesn't raise, despite
not existing.
"""
pidFile = NonePIDFile()
with pidFile:
pass
@implementer(IFilePath)
class DummyFilePath(object):
"""
In-memory L{IFilePath}.
"""
def __init__(self, content=None):
self.setContent(content)
def open(self, mode="r"):
if not self._exists:
raise OSError(errno.ENOENT, "No such file or directory")
return BytesIO(self.getContent())
def setContent(self, content):
self._exists = content is not None
self._content = content
def getContent(self):
return self._content
def remove(self):
self.setContent(None)
def exists(self):
return self._exists

View file

@ -0,0 +1,460 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.runner._runner}.
"""
from signal import SIGTERM
from io import BytesIO
import errno
from attr import attrib, attrs, Factory
from twisted.logger import (
LogLevel, LogPublisher, LogBeginner,
FileLogObserver, FilteringLogObserver, LogLevelFilterPredicate,
)
from twisted.internet.testing import MemoryReactor
from ...runner import _runner
from .._exit import ExitStatus
from .._pidfile import PIDFile, NonePIDFile
from .._runner import Runner
from .test_pidfile import DummyFilePath
import twisted.trial.unittest
class RunnerTests(twisted.trial.unittest.TestCase):
"""
Tests for L{Runner}.
"""
def setUp(self):
# Patch exit and kill so we can capture usage and prevent actual exits
# and kills.
self.exit = DummyExit()
self.kill = DummyKill()
self.patch(_runner, "exit", self.exit)
self.patch(_runner, "kill", self.kill)
# Patch getpid so we get a known result
self.pid = 1337
self.pidFileContent = u"{}\n".format(self.pid).encode("utf-8")
# Patch globalLogBeginner so that we aren't trying to install multiple
# global log observers.
self.stdout = BytesIO()
self.stderr = BytesIO()
self.stdio = DummyStandardIO(self.stdout, self.stderr)
self.warnings = DummyWarningsModule()
self.globalLogPublisher = LogPublisher()
self.globalLogBeginner = LogBeginner(
self.globalLogPublisher,
self.stdio.stderr, self.stdio,
self.warnings,
)
self.patch(_runner, "stderr", self.stderr)
self.patch(_runner, "globalLogBeginner", self.globalLogBeginner)
def test_runInOrder(self):
"""
L{Runner.run} calls the expected methods in order.
"""
runner = DummyRunner(reactor=MemoryReactor())
runner.run()
self.assertEqual(
runner.calledMethods,
[
"killIfRequested",
"startLogging",
"startReactor",
"reactorExited",
]
)
def test_runUsesPIDFile(self):
"""
L{Runner.run} uses the provided PID file.
"""
pidFile = DummyPIDFile()
runner = Runner(reactor=MemoryReactor(), pidFile=pidFile)
self.assertFalse(pidFile.entered)
self.assertFalse(pidFile.exited)
runner.run()
self.assertTrue(pidFile.entered)
self.assertTrue(pidFile.exited)
def test_runAlreadyRunning(self):
"""
L{Runner.run} exits with L{ExitStatus.EX_USAGE} and the expected
message if a process is already running that corresponds to the given
PID file.
"""
pidFile = PIDFile(DummyFilePath(self.pidFileContent))
pidFile.isRunning = lambda: True
runner = Runner(reactor=MemoryReactor(), pidFile=pidFile)
runner.run()
self.assertEqual(self.exit.status, ExitStatus.EX_CONFIG)
self.assertEqual(self.exit.message, "Already running.")
def test_killNotRequested(self):
"""
L{Runner.killIfRequested} when C{kill} is false doesn't exit and
doesn't indiscriminately murder anyone.
"""
runner = Runner(reactor=MemoryReactor())
runner.killIfRequested()
self.assertEqual(self.kill.calls, [])
self.assertFalse(self.exit.exited)
def test_killRequestedWithoutPIDFile(self):
"""
L{Runner.killIfRequested} when C{kill} is true but C{pidFile} is
L{nonePIDFile} exits with L{ExitStatus.EX_USAGE} and the expected
message; and also doesn't indiscriminately murder anyone.
"""
runner = Runner(reactor=MemoryReactor(), kill=True)
runner.killIfRequested()
self.assertEqual(self.kill.calls, [])
self.assertEqual(self.exit.status, ExitStatus.EX_USAGE)
self.assertEqual(self.exit.message, "No PID file specified.")
def test_killRequestedWithPIDFile(self):
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
performs a targeted killing of the appropriate process.
"""
pidFile = PIDFile(DummyFilePath(self.pidFileContent))
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.kill.calls, [(self.pid, SIGTERM)])
self.assertEqual(self.exit.status, ExitStatus.EX_OK)
self.assertIdentical(self.exit.message, None)
def test_killRequestedWithPIDFileCantRead(self):
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
that it can't read exits with L{ExitStatus.EX_IOERR}.
"""
pidFile = PIDFile(DummyFilePath(None))
def read():
raise OSError(errno.EACCES, "Permission denied")
pidFile.read = read
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.exit.status, ExitStatus.EX_IOERR)
self.assertEqual(self.exit.message, "Unable to read PID file.")
def test_killRequestedWithPIDFileEmpty(self):
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
containing no value exits with L{ExitStatus.EX_DATAERR}.
"""
pidFile = PIDFile(DummyFilePath(b""))
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.exit.status, ExitStatus.EX_DATAERR)
self.assertEqual(self.exit.message, "Invalid PID file.")
def test_killRequestedWithPIDFileNotAnInt(self):
"""
L{Runner.killIfRequested} when C{kill} is true and given a C{pidFile}
containing a non-integer value exits with L{ExitStatus.EX_DATAERR}.
"""
pidFile = PIDFile(DummyFilePath(b"** totally not a number, dude **"))
runner = Runner(reactor=MemoryReactor(), kill=True, pidFile=pidFile)
runner.killIfRequested()
self.assertEqual(self.exit.status, ExitStatus.EX_DATAERR)
self.assertEqual(self.exit.message, "Invalid PID file.")
def test_startLogging(self):
"""
L{Runner.startLogging} sets up a filtering observer with a log level
predicate set to the given log level that contains a file observer of
the given type which writes to the given file.
"""
logFile = BytesIO()
# Patch the log beginner so that we don't try to start the already
# running (started by trial) logging system.
class LogBeginner(object):
def beginLoggingTo(self, observers):
LogBeginner.observers = observers
self.patch(_runner, "globalLogBeginner", LogBeginner())
# Patch FilteringLogObserver so we can capture its arguments
class MockFilteringLogObserver(FilteringLogObserver):
def __init__(
self, observer, predicates,
negativeObserver=lambda event: None
):
MockFilteringLogObserver.observer = observer
MockFilteringLogObserver.predicates = predicates
FilteringLogObserver.__init__(
self, observer, predicates, negativeObserver
)
self.patch(_runner, "FilteringLogObserver", MockFilteringLogObserver)
# Patch FileLogObserver so we can capture its arguments
class MockFileLogObserver(FileLogObserver):
def __init__(self, outFile):
MockFileLogObserver.outFile = outFile
FileLogObserver.__init__(self, outFile, str)
# Start logging
runner = Runner(
reactor=MemoryReactor(),
defaultLogLevel=LogLevel.critical,
logFile=logFile,
fileLogObserverFactory=MockFileLogObserver,
)
runner.startLogging()
# Check for a filtering observer
self.assertEqual(len(LogBeginner.observers), 1)
self.assertIsInstance(LogBeginner.observers[0], FilteringLogObserver)
# Check log level predicate with the correct default log level
self.assertEqual(len(MockFilteringLogObserver.predicates), 1)
self.assertIsInstance(
MockFilteringLogObserver.predicates[0],
LogLevelFilterPredicate
)
self.assertIdentical(
MockFilteringLogObserver.predicates[0].defaultLogLevel,
LogLevel.critical
)
# Check for a file observer attached to the filtering observer
self.assertIsInstance(
MockFilteringLogObserver.observer, MockFileLogObserver
)
# Check for the file we gave it
self.assertIdentical(
MockFilteringLogObserver.observer.outFile, logFile
)
def test_startReactorWithReactor(self):
"""
L{Runner.startReactor} with the C{reactor} argument runs the given
reactor.
"""
reactor = MemoryReactor()
runner = Runner(reactor=reactor)
runner.startReactor()
self.assertTrue(reactor.hasRun)
def test_startReactorWhenRunning(self):
"""
L{Runner.startReactor} ensures that C{whenRunning} is called with
C{whenRunningArguments} when the reactor is running.
"""
self._testHook("whenRunning", "startReactor")
def test_whenRunningWithArguments(self):
"""
L{Runner.whenRunning} calls C{whenRunning} with
C{whenRunningArguments}.
"""
self._testHook("whenRunning")
def test_reactorExitedWithArguments(self):
"""
L{Runner.whenRunning} calls C{reactorExited} with
C{reactorExitedArguments}.
"""
self._testHook("reactorExited")
def _testHook(self, methodName, callerName=None):
"""
Verify that the named hook is run with the expected arguments as
specified by the arguments used to create the L{Runner}, when the
specified caller is invoked.
@param methodName: The name of the hook to verify.
@type methodName: L{str}
@param callerName: The name of the method that is expected to cause the
hook to be called.
If C{None}, use the L{Runner} method with the same name as the
hook.
@type callerName: L{str}
"""
if callerName is None:
callerName = methodName
arguments = dict(a=object(), b=object(), c=object())
argumentsSeen = []
def hook(**arguments):
argumentsSeen.append(arguments)
runnerArguments = {
methodName: hook,
"{}Arguments".format(methodName): arguments.copy(),
}
runner = Runner(reactor=MemoryReactor(), **runnerArguments)
hookCaller = getattr(runner, callerName)
hookCaller()
self.assertEqual(len(argumentsSeen), 1)
self.assertEqual(argumentsSeen[0], arguments)
@attrs(frozen=True)
class DummyRunner(Runner):
"""
Stub for L{Runner}.
Keep track of calls to some methods without actually doing anything.
"""
calledMethods = attrib(default=Factory(list))
def killIfRequested(self):
self.calledMethods.append("killIfRequested")
def startLogging(self):
self.calledMethods.append("startLogging")
def startReactor(self):
self.calledMethods.append("startReactor")
def reactorExited(self):
self.calledMethods.append("reactorExited")
class DummyPIDFile(NonePIDFile):
"""
Stub for L{PIDFile}.
Tracks context manager entry/exit without doing anything.
"""
def __init__(self):
NonePIDFile.__init__(self)
self.entered = False
self.exited = False
def __enter__(self):
self.entered = True
return self
def __exit__(self, excType, excValue, traceback):
self.exited = True
class DummyExit(object):
"""
Stub for L{exit} that remembers whether it's been called and, if it has,
what arguments it was given.
"""
def __init__(self):
self.exited = False
def __call__(self, status, message=None):
assert not self.exited
self.status = status
self.message = message
self.exited = True
class DummyKill(object):
"""
Stub for L{os.kill} that remembers whether it's been called and, if it has,
what arguments it was given.
"""
def __init__(self):
self.calls = []
def __call__(self, pid, sig):
self.calls.append((pid, sig))
class DummyStandardIO(object):
"""
Stub for L{sys} which provides L{BytesIO} streams as stdout and stderr.
"""
def __init__(self, stdout, stderr):
self.stdout = stdout
self.stderr = stderr
class DummyWarningsModule(object):
"""
Stub for L{warnings} which provides a C{showwarning} method that is a no-op.
"""
def showwarning(*args, **kwargs):
"""
Do nothing.
@param args: ignored.
@param kwargs: ignored.
"""

View file

@ -0,0 +1,424 @@
# -*- test-case-name: twisted.application.test.test_service -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Service architecture for Twisted.
Services are arranged in a hierarchy. At the leafs of the hierarchy,
the services which actually interact with the outside world are started.
Services can be named or anonymous -- usually, they will be named if
there is need to access them through the hierarchy (from a parent or
a sibling).
Maintainer: Moshe Zadka
"""
from __future__ import absolute_import, division
from zope.interface import implementer, Interface, Attribute
from twisted.persisted import sob
from twisted.python.reflect import namedAny
from twisted.python import components
from twisted.python._oldstyle import _oldStyle
from twisted.internet import defer
from twisted.plugin import IPlugin
class IServiceMaker(Interface):
"""
An object which can be used to construct services in a flexible
way.
This interface should most often be implemented along with
L{twisted.plugin.IPlugin}, and will most often be used by the
'twistd' command.
"""
tapname = Attribute(
"A short string naming this Twisted plugin, for example 'web' or "
"'pencil'. This name will be used as the subcommand of 'twistd'.")
description = Attribute(
"A brief summary of the features provided by this "
"Twisted application plugin.")
options = Attribute(
"A C{twisted.python.usage.Options} subclass defining the "
"configuration options for this application.")
def makeService(options):
"""
Create and return an object providing
L{twisted.application.service.IService}.
@param options: A mapping (typically a C{dict} or
L{twisted.python.usage.Options} instance) of configuration
options to desired configuration values.
"""
@implementer(IPlugin, IServiceMaker)
class ServiceMaker(object):
"""
Utility class to simplify the definition of L{IServiceMaker} plugins.
"""
def __init__(self, name, module, description, tapname):
self.name = name
self.module = module
self.description = description
self.tapname = tapname
def options():
def get(self):
return namedAny(self.module).Options
return get,
options = property(*options())
def makeService():
def get(self):
return namedAny(self.module).makeService
return get,
makeService = property(*makeService())
class IService(Interface):
"""
A service.
Run start-up and shut-down code at the appropriate times.
"""
name = Attribute(
"A C{str} which is the name of the service or C{None}.")
running = Attribute(
"A C{boolean} which indicates whether the service is running.")
parent = Attribute(
"An C{IServiceCollection} which is the parent or C{None}.")
def setName(name):
"""
Set the name of the service.
@type name: C{str}
@raise RuntimeError: Raised if the service already has a parent.
"""
def setServiceParent(parent):
"""
Set the parent of the service. This method is responsible for setting
the C{parent} attribute on this service (the child service).
@type parent: L{IServiceCollection}
@raise RuntimeError: Raised if the service already has a parent
or if the service has a name and the parent already has a child
by that name.
"""
def disownServiceParent():
"""
Use this API to remove an L{IService} from an L{IServiceCollection}.
This method is used symmetrically with L{setServiceParent} in that it
sets the C{parent} attribute on the child.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, L{None}).
"""
def startService():
"""
Start the service.
"""
def stopService():
"""
Stop the service.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, L{None}).
"""
def privilegedStartService():
"""
Do preparation work for starting the service.
Here things which should be done before changing directory,
root or shedding privileges are done.
"""
@implementer(IService)
class Service(object):
"""
Base class for services.
Most services should inherit from this class. It handles the
book-keeping responsibilities of starting and stopping, as well
as not serializing this book-keeping information.
"""
running = 0
name = None
parent = None
def __getstate__(self):
dict = self.__dict__.copy()
if "running" in dict:
del dict['running']
return dict
def setName(self, name):
if self.parent is not None:
raise RuntimeError("cannot change name when parent exists")
self.name = name
def setServiceParent(self, parent):
if self.parent is not None:
self.disownServiceParent()
parent = IServiceCollection(parent, parent)
self.parent = parent
self.parent.addService(self)
def disownServiceParent(self):
d = self.parent.removeService(self)
self.parent = None
return d
def privilegedStartService(self):
pass
def startService(self):
self.running = 1
def stopService(self):
self.running = 0
class IServiceCollection(Interface):
"""
Collection of services.
Contain several services, and manage their start-up/shut-down.
Services can be accessed by name if they have a name, and it
is always possible to iterate over them.
"""
def getServiceNamed(name):
"""
Get the child service with a given name.
@type name: C{str}
@rtype: L{IService}
@raise KeyError: Raised if the service has no child with the
given name.
"""
def __iter__():
"""
Get an iterator over all child services.
"""
def addService(service):
"""
Add a child service.
Only implementations of L{IService.setServiceParent} should use this
method.
@type service: L{IService}
@raise RuntimeError: Raised if the service has a child with
the given name.
"""
def removeService(service):
"""
Remove a child service.
Only implementations of L{IService.disownServiceParent} should
use this method.
@type service: L{IService}
@raise ValueError: Raised if the given service is not a child.
@rtype: L{Deferred<defer.Deferred>}
@return: a L{Deferred<defer.Deferred>} which is triggered when the
service has finished shutting down. If shutting down is immediate,
a value can be returned (usually, L{None}).
"""
@implementer(IServiceCollection)
class MultiService(Service):
"""
Straightforward Service Container.
Hold a collection of services, and manage them in a simplistic
way. No service will wait for another, but this object itself
will not finish shutting down until all of its child services
will finish.
"""
def __init__(self):
self.services = []
self.namedServices = {}
self.parent = None
def privilegedStartService(self):
Service.privilegedStartService(self)
for service in self:
service.privilegedStartService()
def startService(self):
Service.startService(self)
for service in self:
service.startService()
def stopService(self):
Service.stopService(self)
l = []
services = list(self)
services.reverse()
for service in services:
l.append(defer.maybeDeferred(service.stopService))
return defer.DeferredList(l)
def getServiceNamed(self, name):
return self.namedServices[name]
def __iter__(self):
return iter(self.services)
def addService(self, service):
if service.name is not None:
if service.name in self.namedServices:
raise RuntimeError("cannot have two services with same name"
" '%s'" % service.name)
self.namedServices[service.name] = service
self.services.append(service)
if self.running:
# It may be too late for that, but we will do our best
service.privilegedStartService()
service.startService()
def removeService(self, service):
if service.name:
del self.namedServices[service.name]
self.services.remove(service)
if self.running:
# Returning this so as not to lose information from the
# MultiService.stopService deferred.
return service.stopService()
else:
return None
class IProcess(Interface):
"""
Process running parameters.
Represents parameters for how processes should be run.
"""
processName = Attribute(
"""
A C{str} giving the name the process should have in ps (or L{None}
to leave the name alone).
""")
uid = Attribute(
"""
An C{int} giving the user id as which the process should run (or
L{None} to leave the UID alone).
""")
gid = Attribute(
"""
An C{int} giving the group id as which the process should run (or
L{None} to leave the GID alone).
""")
@implementer(IProcess)
@_oldStyle
class Process:
"""
Process running parameters.
Sets up uid/gid in the constructor, and has a default
of L{None} as C{processName}.
"""
processName = None
def __init__(self, uid=None, gid=None):
"""
Set uid and gid.
@param uid: The user ID as whom to execute the process. If
this is L{None}, no attempt will be made to change the UID.
@param gid: The group ID as whom to execute the process. If
this is L{None}, no attempt will be made to change the GID.
"""
self.uid = uid
self.gid = gid
def Application(name, uid=None, gid=None):
"""
Return a compound class.
Return an object supporting the L{IService}, L{IServiceCollection},
L{IProcess} and L{sob.IPersistable} interfaces, with the given
parameters. Always access the return value by explicit casting to
one of the interfaces.
"""
ret = components.Componentized()
availableComponents = [MultiService(), Process(uid, gid),
sob.Persistent(ret, name)]
for comp in availableComponents:
ret.addComponent(comp, ignoreClass=1)
IService(ret).setName(name)
return ret
def loadApplication(filename, kind, passphrase=None):
"""
Load Application from a given file.
The serialization format it was saved in should be given as
C{kind}, and is one of C{pickle}, C{source}, C{xml} or C{python}. If
C{passphrase} is given, the application was encrypted with the
given passphrase.
@type filename: C{str}
@type kind: C{str}
@type passphrase: C{str}
"""
if kind == 'python':
application = sob.loadValueFromFile(filename, 'application')
else:
application = sob.load(filename, kind)
return application
__all__ = ['IServiceMaker', 'IService', 'Service',
'IServiceCollection', 'MultiService',
'IProcess', 'Process', 'Application', 'loadApplication']

View file

@ -0,0 +1,70 @@
# -*- test-case-name: twisted.test.test_strports -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Construct listening port services from a simple string description.
@see: L{twisted.internet.endpoints.serverFromString}
@see: L{twisted.internet.endpoints.clientFromString}
"""
from __future__ import absolute_import, division
from twisted.application.internet import StreamServerEndpointService
from twisted.internet import endpoints
def service(description, factory, reactor=None):
"""
Return the service corresponding to a description.
@param description: The description of the listening port, in the syntax
described by L{twisted.internet.endpoints.serverFromString}.
@type description: C{str}
@param factory: The protocol factory which will build protocols for
connections to this service.
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
@rtype: C{twisted.application.service.IService}
@return: the service corresponding to a description of a reliable stream
server.
@see: L{twisted.internet.endpoints.serverFromString}
"""
if reactor is None:
from twisted.internet import reactor
svc = StreamServerEndpointService(
endpoints.serverFromString(reactor, description), factory)
svc._raiseSynchronously = True
return svc
def listen(description, factory):
"""
Listen on a port corresponding to a description.
@param description: The description of the connecting port, in the syntax
described by L{twisted.internet.endpoints.serverFromString}.
@type description: L{str}
@param factory: The protocol factory which will build protocols on
connection.
@type factory: L{twisted.internet.interfaces.IProtocolFactory}
@rtype: L{twisted.internet.interfaces.IListeningPort}
@return: the port corresponding to a description of a reliable virtual
circuit server.
@see: L{twisted.internet.endpoints.serverFromString}
"""
from twisted.internet import reactor
name, args, kw = endpoints._parseServer(description, factory)
return getattr(reactor, 'listen' + name)(*args, **kw)
__all__ = ['service', 'listen']

View file

@ -0,0 +1,6 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.internet.application}.
"""

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,188 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.service}.
"""
from __future__ import absolute_import, division
from zope.interface import implementer
from zope.interface.exceptions import BrokenImplementation
from zope.interface.verify import verifyObject
from twisted.persisted.sob import IPersistable
from twisted.application.service import Application, IProcess
from twisted.application.service import IService, IServiceCollection
from twisted.application.service import Service
from twisted.trial.unittest import TestCase
@implementer(IService)
class AlmostService(object):
"""
Implement IService in a way that can fail.
In general, classes should maintain invariants that adhere
to the interfaces that they claim to implement --
otherwise, it is a bug.
This is a buggy class -- the IService implementation is fragile,
and several methods will break it. These bugs are intentional,
as the tests trigger them -- and then check that the class,
indeed, no longer complies with the interface (IService)
that it claims to comply with.
Since the verification will, by definition, only fail on buggy classes --
in other words, those which do not actually support the interface they
claim to support, we have to write a buggy class to properly verify
the interface.
"""
def __init__(self, name, parent, running):
self.name = name
self.parent = parent
self.running = running
def makeInvalidByDeletingName(self):
"""
Probably not a wise method to call.
This method removes the :code:`name` attribute,
which has to exist in IService classes.
"""
del self.name
def makeInvalidByDeletingParent(self):
"""
Probably not a wise method to call.
This method removes the :code:`parent` attribute,
which has to exist in IService classes.
"""
del self.parent
def makeInvalidByDeletingRunning(self):
"""
Probably not a wise method to call.
This method removes the :code:`running` attribute,
which has to exist in IService classes.
"""
del self.running
def setName(self, name):
"""
See L{twisted.application.service.IService}.
@param name: ignored
"""
def setServiceParent(self, parent):
"""
See L{twisted.application.service.IService}.
@param parent: ignored
"""
def disownServiceParent(self):
"""
See L{twisted.application.service.IService}.
"""
def privilegedStartService(self):
"""
See L{twisted.application.service.IService}.
"""
def startService(self):
"""
See L{twisted.application.service.IService}.
"""
def stopService(self):
"""
See L{twisted.application.service.IService}.
"""
class ServiceInterfaceTests(TestCase):
"""
Tests for L{twisted.application.service.IService} implementation.
"""
def setUp(self):
"""
Build something that implements IService.
"""
self.almostService = AlmostService(parent=None, running=False,
name=None)
def test_realService(self):
"""
Service implements IService.
"""
myService = Service()
verifyObject(IService, myService)
def test_hasAll(self):
"""
AlmostService implements IService.
"""
verifyObject(IService, self.almostService)
def test_noName(self):
"""
AlmostService with no name does not implement IService.
"""
self.almostService.makeInvalidByDeletingName()
with self.assertRaises(BrokenImplementation):
verifyObject(IService, self.almostService)
def test_noParent(self):
"""
AlmostService with no parent does not implement IService.
"""
self.almostService.makeInvalidByDeletingParent()
with self.assertRaises(BrokenImplementation):
verifyObject(IService, self.almostService)
def test_noRunning(self):
"""
AlmostService with no running does not implement IService.
"""
self.almostService.makeInvalidByDeletingRunning()
with self.assertRaises(BrokenImplementation):
verifyObject(IService, self.almostService)
class ApplicationTests(TestCase):
"""
Tests for L{twisted.application.service.Application}.
"""
def test_applicationComponents(self):
"""
Check L{twisted.application.service.Application} instantiation.
"""
app = Application('app-name')
self.assertTrue(verifyObject(IService, IService(app)))
self.assertTrue(
verifyObject(IServiceCollection, IServiceCollection(app)))
self.assertTrue(verifyObject(IProcess, IProcess(app)))
self.assertTrue(verifyObject(IPersistable, IPersistable(app)))

View file

@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.twist.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
C{twist} command line tool.
"""

View file

@ -0,0 +1,205 @@
# -*- test-case-name: twisted.application.twist.test.test_options -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Command line options for C{twist}.
"""
from sys import stdout, stderr
from textwrap import dedent
from twisted.copyright import version
from twisted.python.usage import Options, UsageError
from twisted.logger import (
LogLevel, InvalidLogLevelError,
textFileLogObserver, jsonFileLogObserver,
)
from twisted.plugin import getPlugins
from ..reactors import installReactor, NoSuchReactor, getReactorTypes
from ..runner._exit import exit, ExitStatus
from ..service import IServiceMaker
openFile = open
class TwistOptions(Options):
"""
Command line options for C{twist}.
"""
defaultReactorName = "default"
defaultLogLevel = LogLevel.info
def __init__(self):
Options.__init__(self)
self["reactorName"] = self.defaultReactorName
self["logLevel"] = self.defaultLogLevel
self["logFile"] = stdout
def getSynopsis(self):
return "{} plugin [plugin_options]".format(
Options.getSynopsis(self)
)
def opt_version(self):
"""
Print version and exit.
"""
exit(ExitStatus.EX_OK, "{}".format(version))
def opt_reactor(self, name):
"""
The name of the reactor to use.
(options: {options})
"""
# Actually actually actually install the reactor right at this very
# moment, before any other code (for example, a sub-command plugin)
# runs and accidentally imports and installs the default reactor.
try:
self["reactor"] = self.installReactor(name)
except NoSuchReactor:
raise UsageError("Unknown reactor: {}".format(name))
else:
self["reactorName"] = name
opt_reactor.__doc__ = dedent(opt_reactor.__doc__).format(
options=", ".join(
'"{}"'.format(rt.shortName) for rt in getReactorTypes()
),
)
def installReactor(self, name):
"""
Install the reactor.
"""
if name == self.defaultReactorName:
from twisted.internet import reactor
return reactor
else:
return installReactor(name)
def opt_log_level(self, levelName):
"""
Set default log level.
(options: {options}; default: "{default}")
"""
try:
self["logLevel"] = LogLevel.levelWithName(levelName)
except InvalidLogLevelError:
raise UsageError("Invalid log level: {}".format(levelName))
opt_log_level.__doc__ = dedent(opt_log_level.__doc__).format(
options=", ".join(
'"{}"'.format(l.name) for l in LogLevel.iterconstants()
),
default=defaultLogLevel.name,
)
def opt_log_file(self, fileName):
"""
Log to file. ("-" for stdout, "+" for stderr; default: "-")
"""
if fileName == "-":
self["logFile"] = stdout
return
if fileName == "+":
self["logFile"] = stderr
return
try:
self["logFile"] = openFile(fileName, "a")
except EnvironmentError as e:
exit(
ExitStatus.EX_IOERR,
"Unable to open log file {!r}: {}".format(fileName, e)
)
def opt_log_format(self, format):
"""
Log file format.
(options: "text", "json"; default: "text" if the log file is a tty,
otherwise "json")
"""
format = format.lower()
if format == "text":
self["fileLogObserverFactory"] = textFileLogObserver
elif format == "json":
self["fileLogObserverFactory"] = jsonFileLogObserver
else:
raise UsageError("Invalid log format: {}".format(format))
self["logFormat"] = format
opt_log_format.__doc__ = dedent(opt_log_format.__doc__)
def selectDefaultLogObserver(self):
"""
Set C{fileLogObserverFactory} to the default appropriate for the
chosen C{logFile}.
"""
if "fileLogObserverFactory" not in self:
logFile = self["logFile"]
if hasattr(logFile, "isatty") and logFile.isatty():
self["fileLogObserverFactory"] = textFileLogObserver
self["logFormat"] = "text"
else:
self["fileLogObserverFactory"] = jsonFileLogObserver
self["logFormat"] = "json"
def parseOptions(self, options=None):
self.selectDefaultLogObserver()
Options.parseOptions(self, options=options)
if "reactor" not in self:
self["reactor"] = self.installReactor(self["reactorName"])
@property
def plugins(self):
if "plugins" not in self:
plugins = {}
for plugin in getPlugins(IServiceMaker):
plugins[plugin.tapname] = plugin
self["plugins"] = plugins
return self["plugins"]
@property
def subCommands(self):
plugins = self.plugins
for name in sorted(plugins):
plugin = plugins[name]
yield (
plugin.tapname,
None,
# Avoid resolving the options attribute right away, in case
# it's a property with a non-trivial getter (eg, one which
# imports modules).
lambda plugin=plugin: plugin.options(),
plugin.description,
)
def postOptions(self):
Options.postOptions(self)
if self.subCommand is None:
raise UsageError("No plugin specified.")

View file

@ -0,0 +1,128 @@
# -*- test-case-name: twisted.application.twist.test.test_twist -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Run a Twisted application.
"""
import sys
from twisted.python.usage import UsageError
from ..service import Application, IService
from ..runner._exit import exit, ExitStatus
from ..runner._runner import Runner
from ._options import TwistOptions
from twisted.application.app import _exitWithSignal
from twisted.internet.interfaces import _ISupportsExitSignalCapturing
class Twist(object):
"""
Run a Twisted application.
"""
@staticmethod
def options(argv):
"""
Parse command line options.
@param argv: Command line arguments.
@type argv: L{list}
@return: The parsed options.
@rtype: L{TwistOptions}
"""
options = TwistOptions()
try:
options.parseOptions(argv[1:])
except UsageError as e:
exit(ExitStatus.EX_USAGE, "Error: {}\n\n{}".format(e, options))
return options
@staticmethod
def service(plugin, options):
"""
Create the application service.
@param plugin: The name of the plugin that implements the service
application to run.
@type plugin: L{str}
@param options: Options to pass to the application.
@type options: L{twisted.python.usage.Options}
@return: The created application service.
@rtype: L{IService}
"""
service = plugin.makeService(options)
application = Application(plugin.tapname)
service.setServiceParent(application)
return IService(application)
@staticmethod
def startService(reactor, service):
"""
Start the application service.
@param reactor: The reactor to run the service with.
@type reactor: L{twisted.internet.interfaces.IReactorCore}
@param service: The application service to run.
@type service: L{IService}
"""
service.startService()
# Ask the reactor to stop the service before shutting down
reactor.addSystemEventTrigger(
"before", "shutdown", service.stopService
)
@staticmethod
def run(twistOptions):
"""
Run the application service.
@param twistOptions: Command line options to convert to runner
arguments.
@type twistOptions: L{TwistOptions}
"""
runner = Runner(
reactor=twistOptions["reactor"],
defaultLogLevel=twistOptions["logLevel"],
logFile=twistOptions["logFile"],
fileLogObserverFactory=twistOptions["fileLogObserverFactory"],
)
runner.run()
reactor = twistOptions["reactor"]
if _ISupportsExitSignalCapturing.providedBy(reactor):
if reactor._exitSignal is not None:
_exitWithSignal(reactor._exitSignal)
@classmethod
def main(cls, argv=sys.argv):
"""
Executable entry point for L{Twist}.
Processes options and run a twisted reactor with a service.
@param argv: Command line arguments.
@type argv: L{list}
"""
options = cls.options(argv)
reactor = options["reactor"]
service = cls.service(
plugin=options.plugins[options.subCommand],
options=options.subOptions,
)
cls.startService(reactor, service)
cls.run(options)

View file

@ -0,0 +1,7 @@
# -*- test-case-name: twisted.application.twist.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.twist}.
"""

View file

@ -0,0 +1,385 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.twist._options}.
"""
from sys import stdout, stderr
from twisted.internet import reactor
from twisted.copyright import version
from twisted.python.usage import UsageError
from twisted.logger import LogLevel, textFileLogObserver, jsonFileLogObserver
from twisted.test.proto_helpers import MemoryReactor
from ...reactors import NoSuchReactor
from ...service import ServiceMaker
from ...runner._exit import ExitStatus
from ...runner.test.test_runner import DummyExit
from ...twist import _options
from .._options import TwistOptions
import twisted.trial.unittest
class OptionsTests(twisted.trial.unittest.TestCase):
"""
Tests for L{TwistOptions}.
"""
def patchExit(self):
"""
Patch L{_twist.exit} so we can capture usage and prevent actual exits.
"""
self.exit = DummyExit()
self.patch(_options, "exit", self.exit)
def patchOpen(self):
"""
Patch L{_options.open} so we can capture usage and prevent actual opens.
"""
self.opened = []
def fakeOpen(name, mode=None):
if name == "nocanopen":
raise IOError(None, None, name)
self.opened.append((name, mode))
return NotImplemented
self.patch(_options, "openFile", fakeOpen)
def patchInstallReactor(self):
"""
Patch C{_options.installReactor} so we can capture usage and prevent
actual installs.
"""
self.installedReactors = {}
def installReactor(name):
if name != "fusion":
raise NoSuchReactor()
reactor = MemoryReactor()
self.installedReactors[name] = reactor
return reactor
self.patch(_options, "installReactor", installReactor)
def test_synopsis(self):
"""
L{TwistOptions.getSynopsis} appends arguments.
"""
options = TwistOptions()
self.assertTrue(
options.getSynopsis().endswith(" plugin [plugin_options]")
)
def test_version(self):
"""
L{TwistOptions.opt_version} exits with L{ExitStatus.EX_OK} and prints
the version.
"""
self.patchExit()
options = TwistOptions()
options.opt_version()
self.assertEquals(self.exit.status, ExitStatus.EX_OK)
self.assertEquals(self.exit.message, version)
def test_reactor(self):
"""
L{TwistOptions.installReactor} installs the chosen reactor and sets
the reactor name.
"""
self.patchInstallReactor()
options = TwistOptions()
options.opt_reactor("fusion")
self.assertEqual(set(self.installedReactors), set(["fusion"]))
self.assertEquals(options["reactorName"], "fusion")
def test_installCorrectReactor(self):
"""
L{TwistOptions.installReactor} installs the chosen reactor after the
command line options have been parsed.
"""
self.patchInstallReactor()
options = TwistOptions()
options.subCommand = "test-subcommand"
options.parseOptions(["--reactor=fusion"])
self.assertEqual(set(self.installedReactors), set(["fusion"]))
def test_installReactorBogus(self):
"""
L{TwistOptions.installReactor} raises UsageError if an unknown reactor
is specified.
"""
self.patchInstallReactor()
options = TwistOptions()
self.assertRaises(UsageError, options.opt_reactor, "coal")
def test_installReactorDefault(self):
"""
L{TwistOptions.installReactor} returns the currently installed reactor
when the default reactor name is specified.
"""
options = TwistOptions()
self.assertIdentical(reactor, options.installReactor('default'))
def test_logLevelValid(self):
"""
L{TwistOptions.opt_log_level} sets the corresponding log level.
"""
options = TwistOptions()
options.opt_log_level("warn")
self.assertIdentical(options["logLevel"], LogLevel.warn)
def test_logLevelInvalid(self):
"""
L{TwistOptions.opt_log_level} with an invalid log level name raises
UsageError.
"""
options = TwistOptions()
self.assertRaises(UsageError, options.opt_log_level, "cheese")
def _testLogFile(self, name, expectedStream):
"""
Set log file name and check the selected output stream.
@param name: The name of the file.
@param expectedStream: The expected stream.
"""
options = TwistOptions()
options.opt_log_file(name)
self.assertIdentical(options["logFile"], expectedStream)
def test_logFileStdout(self):
"""
L{TwistOptions.opt_log_file} given C{"-"} as a file name uses stdout.
"""
self._testLogFile("-", stdout)
def test_logFileStderr(self):
"""
L{TwistOptions.opt_log_file} given C{"+"} as a file name uses stderr.
"""
self._testLogFile("+", stderr)
def test_logFileNamed(self):
"""
L{TwistOptions.opt_log_file} opens the given file name in append mode.
"""
self.patchOpen()
options = TwistOptions()
options.opt_log_file("mylog")
self.assertEqual([("mylog", "a")], self.opened)
def test_logFileCantOpen(self):
"""
L{TwistOptions.opt_log_file} exits with L{ExitStatus.EX_IOERR} if
unable to open the log file due to an L{EnvironmentError}.
"""
self.patchExit()
self.patchOpen()
options = TwistOptions()
options.opt_log_file("nocanopen")
self.assertEquals(self.exit.status, ExitStatus.EX_IOERR)
self.assertTrue(
self.exit.message.startswith(
"Unable to open log file 'nocanopen': "
)
)
def _testLogFormat(self, format, expectedObserver):
"""
Set log file format and check the selected observer.
@param format: The format of the file.
@param expectedObserver: The expected observer.
"""
options = TwistOptions()
options.opt_log_format(format)
self.assertIdentical(
options["fileLogObserverFactory"], expectedObserver
)
self.assertEqual(options["logFormat"], format)
def test_logFormatText(self):
"""
L{TwistOptions.opt_log_format} given C{"text"} uses a
L{textFileLogObserver}.
"""
self._testLogFormat("text", textFileLogObserver)
def test_logFormatJSON(self):
"""
L{TwistOptions.opt_log_format} given C{"text"} uses a
L{textFileLogObserver}.
"""
self._testLogFormat("json", jsonFileLogObserver)
def test_logFormatInvalid(self):
"""
L{TwistOptions.opt_log_format} given an invalid format name raises
L{UsageError}.
"""
options = TwistOptions()
self.assertRaises(UsageError, options.opt_log_format, "frommage")
def test_selectDefaultLogObserverNoOverride(self):
"""
L{TwistOptions.selectDefaultLogObserver} will not override an already
selected observer.
"""
self.patchOpen()
options = TwistOptions()
options.opt_log_format("text") # Ask for text
options.opt_log_file("queso") # File, not a tty
options.selectDefaultLogObserver()
# Because we didn't select a file that is a tty, the default is JSON,
# but since we asked for text, we should get text.
self.assertIdentical(
options["fileLogObserverFactory"], textFileLogObserver
)
self.assertEqual(options["logFormat"], "text")
def test_selectDefaultLogObserverDefaultWithTTY(self):
"""
L{TwistOptions.selectDefaultLogObserver} will not override an already
selected observer.
"""
class TTYFile(object):
def isatty(self):
return True
# stdout may not be a tty, so let's make sure it thinks it is
self.patch(_options, "stdout", TTYFile())
options = TwistOptions()
options.opt_log_file("-") # stdout, a tty
options.selectDefaultLogObserver()
self.assertIdentical(
options["fileLogObserverFactory"], textFileLogObserver
)
self.assertEqual(options["logFormat"], "text")
def test_selectDefaultLogObserverDefaultWithoutTTY(self):
"""
L{TwistOptions.selectDefaultLogObserver} will not override an already
selected observer.
"""
self.patchOpen()
options = TwistOptions()
options.opt_log_file("queso") # File, not a tty
options.selectDefaultLogObserver()
self.assertIdentical(
options["fileLogObserverFactory"], jsonFileLogObserver
)
self.assertEqual(options["logFormat"], "json")
def test_pluginsType(self):
"""
L{TwistOptions.plugins} is a mapping of available plug-ins.
"""
options = TwistOptions()
plugins = options.plugins
for name in plugins:
self.assertIsInstance(name, str)
self.assertIsInstance(plugins[name], ServiceMaker)
def test_pluginsIncludeWeb(self):
"""
L{TwistOptions.plugins} includes a C{"web"} plug-in.
This is an attempt to verify that something we expect to be in the list
is in there without enumerating all of the built-in plug-ins.
"""
options = TwistOptions()
self.assertIn("web", options.plugins)
def test_subCommandsType(self):
"""
L{TwistOptions.subCommands} is an iterable of tuples as expected by
L{twisted.python.usage.Options}.
"""
options = TwistOptions()
for name, shortcut, parser, doc in options.subCommands:
self.assertIsInstance(name, str)
self.assertIdentical(shortcut, None)
self.assertTrue(callable(parser))
self.assertIsInstance(doc, str)
def test_subCommandsIncludeWeb(self):
"""
L{TwistOptions.subCommands} includes a sub-command for every plug-in.
"""
options = TwistOptions()
plugins = set(options.plugins)
subCommands = set(
name for name, shortcut, parser, doc in options.subCommands
)
self.assertEqual(subCommands, plugins)
def test_postOptionsNoSubCommand(self):
"""
L{TwistOptions.postOptions} raises L{UsageError} is it has no
sub-command.
"""
self.patchInstallReactor()
options = TwistOptions()
self.assertRaises(UsageError, options.postOptions)

View file

@ -0,0 +1,269 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.application.twist._twist}.
"""
from sys import stdout
from twisted.logger import LogLevel, jsonFileLogObserver
from twisted.test.proto_helpers import MemoryReactor
from ...service import IService, MultiService
from ...runner._exit import ExitStatus
from ...runner._runner import Runner
from ...runner.test.test_runner import DummyExit
from ...twist import _twist
from .._options import TwistOptions
from .._twist import Twist
from twisted.test.test_twistd import SignalCapturingMemoryReactor
import twisted.trial.unittest
class TwistTests(twisted.trial.unittest.TestCase):
"""
Tests for L{Twist}.
"""
def setUp(self):
self.patchInstallReactor()
def patchExit(self):
"""
Patch L{_twist.exit} so we can capture usage and prevent actual exits.
"""
self.exit = DummyExit()
self.patch(_twist, "exit", self.exit)
def patchInstallReactor(self):
"""
Patch C{_options.installReactor} so we can capture usage and prevent
actual installs.
"""
self.installedReactors = {}
def installReactor(_, name):
reactor = MemoryReactor()
self.installedReactors[name] = reactor
return reactor
self.patch(TwistOptions, "installReactor", installReactor)
def patchStartService(self):
"""
Patch L{MultiService.startService} so we can capture usage and prevent
actual starts.
"""
self.serviceStarts = []
def startService(service):
self.serviceStarts.append(service)
self.patch(MultiService, "startService", startService)
def test_optionsValidArguments(self):
"""
L{Twist.options} given valid arguments returns options.
"""
options = Twist.options(["twist", "web"])
self.assertIsInstance(options, TwistOptions)
def test_optionsInvalidArguments(self):
"""
L{Twist.options} given invalid arguments exits with
L{ExitStatus.EX_USAGE} and an error/usage message.
"""
self.patchExit()
Twist.options(["twist", "--bogus-bagels"])
self.assertIdentical(self.exit.status, ExitStatus.EX_USAGE)
self.assertTrue(self.exit.message.startswith("Error: "))
self.assertTrue(self.exit.message.endswith(
"\n\n{}".format(TwistOptions())
))
def test_service(self):
"""
L{Twist.service} returns an L{IService}.
"""
options = Twist.options(["twist", "web"]) # web should exist
service = Twist.service(options.plugins["web"], options.subOptions)
self.assertTrue(IService.providedBy(service))
def test_startService(self):
"""
L{Twist.startService} starts the service and registers a trigger to
stop the service when the reactor shuts down.
"""
options = Twist.options(["twist", "web"])
reactor = options["reactor"]
service = Twist.service(
plugin=options.plugins[options.subCommand],
options=options.subOptions,
)
self.patchStartService()
Twist.startService(reactor, service)
self.assertEqual(self.serviceStarts, [service])
self.assertEqual(
reactor.triggers["before"]["shutdown"],
[(service.stopService, (), {})]
)
def test_run(self):
"""
L{Twist.run} runs the runner with arguments corresponding to the given
options.
"""
argsSeen = []
self.patch(
Runner, "__init__", lambda self, **args: argsSeen.append(args)
)
self.patch(
Runner, "run", lambda self: None
)
twistOptions = Twist.options([
"twist", "--reactor=default", "--log-format=json", "web"
])
Twist.run(twistOptions)
self.assertEqual(len(argsSeen), 1)
self.assertEqual(
argsSeen[0],
dict(
reactor=self.installedReactors["default"],
defaultLogLevel=LogLevel.info,
logFile=stdout,
fileLogObserverFactory=jsonFileLogObserver,
)
)
def test_main(self):
"""
L{Twist.main} runs the runner with arguments corresponding to the given
command line arguments.
"""
self.patchStartService()
runners = []
class Runner(object):
def __init__(self, **kwargs):
self.args = kwargs
self.runs = 0
runners.append(self)
def run(self):
self.runs += 1
self.patch(_twist, "Runner", Runner)
Twist.main([
"twist", "--reactor=default", "--log-format=json", "web"
])
self.assertEqual(len(self.serviceStarts), 1)
self.assertEqual(len(runners), 1)
self.assertEqual(
runners[0].args,
dict(
reactor=self.installedReactors["default"],
defaultLogLevel=LogLevel.info,
logFile=stdout,
fileLogObserverFactory=jsonFileLogObserver,
)
)
self.assertEqual(runners[0].runs, 1)
class TwistExitTests(twisted.trial.unittest.TestCase):
"""
Tests to verify that the Twist script takes the expected actions related
to signals and the reactor.
"""
def setUp(self):
self.exitWithSignalCalled = False
def fakeExitWithSignal(sig):
"""
Fake to capture whether L{twisted.application._exitWithSignal
was called.
@param sig: Signal value
@type sig: C{int}
"""
self.exitWithSignalCalled = True
self.patch(_twist, '_exitWithSignal', fakeExitWithSignal)
def startLogging(_):
"""
Prevent Runner from adding new log observers or other
tests outside this module will fail.
@param _: Unused self param
"""
self.patch(Runner, 'startLogging', startLogging)
def test_twistReactorDoesntExitWithSignal(self):
"""
_exitWithSignal is not called if the reactor's _exitSignal attribute
is zero.
"""
reactor = SignalCapturingMemoryReactor()
reactor._exitSignal = None
options = TwistOptions()
options["reactor"] = reactor
options["fileLogObserverFactory"] = jsonFileLogObserver
Twist.run(options)
self.assertFalse(self.exitWithSignalCalled)
def test_twistReactorHasNoExitSignalAttr(self):
"""
_exitWithSignal is not called if the runner's reactor does not
implement L{twisted.internet.interfaces._ISupportsExitSignalCapturing}
"""
reactor = MemoryReactor()
options = TwistOptions()
options["reactor"] = reactor
options["fileLogObserverFactory"] = jsonFileLogObserver
Twist.run(options)
self.assertFalse(self.exitWithSignalCalled)
def test_twistReactorExitsWithSignal(self):
"""
_exitWithSignal is called if the runner's reactor exits due
to a signal.
"""
reactor = SignalCapturingMemoryReactor()
reactor._exitSignal = 2
options = TwistOptions()
options["reactor"] = reactor
options["fileLogObserverFactory"] = jsonFileLogObserver
Twist.run(options)
self.assertTrue(self.exitWithSignalCalled)

View file

@ -0,0 +1,7 @@
# -*- test-case-name: twisted.conch.test -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Twisted Conch: The Twisted Shell. Terminal emulation, SSHv2 and telnet.
"""

View file

@ -0,0 +1,45 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
from __future__ import absolute_import, division
from zope.interface import implementer
from twisted.conch.error import ConchError
from twisted.conch.interfaces import IConchUser
from twisted.conch.ssh.connection import OPEN_UNKNOWN_CHANNEL_TYPE
from twisted.python import log
from twisted.python.compat import nativeString
@implementer(IConchUser)
class ConchUser:
def __init__(self):
self.channelLookup = {}
self.subsystemLookup = {}
def lookupChannel(self, channelType, windowSize, maxPacket, data):
klass = self.channelLookup.get(channelType, None)
if not klass:
raise ConchError(OPEN_UNKNOWN_CHANNEL_TYPE, "unknown channel")
else:
return klass(remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data, avatar=self)
def lookupSubsystem(self, subsystem, data):
log.msg(repr(self.subsystemLookup))
klass = self.subsystemLookup.get(subsystem, None)
if not klass:
return False
return klass(data, avatar=self)
def gotGlobalRequest(self, requestType, data):
# XXX should this use method dispatch?
requestType = nativeString(requestType.replace(b'-', b'_'))
f = getattr(self, "global_%s" % requestType, None)
if not f:
return 0
return f(data)

View file

@ -0,0 +1,592 @@
# -*- test-case-name: twisted.conch.test.test_checkers -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Provide L{ICredentialsChecker} implementations to be used in Conch protocols.
"""
from __future__ import absolute_import, division
import sys
import binascii
import errno
try:
import pwd
except ImportError:
pwd = None
else:
import crypt
try:
import spwd
except ImportError:
spwd = None
from zope.interface import providedBy, implementer, Interface
from incremental import Version
from twisted.conch import error
from twisted.conch.ssh import keys
from twisted.cred.checkers import ICredentialsChecker
from twisted.cred.credentials import IUsernamePassword, ISSHPrivateKey
from twisted.cred.error import UnauthorizedLogin, UnhandledCredentials
from twisted.internet import defer
from twisted.python.compat import _keys, _PY3, _b64decodebytes
from twisted.python import failure, reflect, log
from twisted.python.deprecate import deprecatedModuleAttribute
from twisted.python.util import runAsEffectiveUser
from twisted.python.filepath import FilePath
def verifyCryptedPassword(crypted, pw):
"""
Check that the password, when crypted, matches the stored crypted password.
@param crypted: The stored crypted password.
@type crypted: L{str}
@param pw: The password the user has given.
@type pw: L{str}
@rtype: L{bool}
"""
return crypt.crypt(pw, crypted) == crypted
def _pwdGetByName(username):
"""
Look up a user in the /etc/passwd database using the pwd module. If the
pwd module is not available, return None.
@param username: the username of the user to return the passwd database
information for.
@type username: L{str}
"""
if pwd is None:
return None
return pwd.getpwnam(username)
def _shadowGetByName(username):
"""
Look up a user in the /etc/shadow database using the spwd module. If it is
not available, return L{None}.
@param username: the username of the user to return the shadow database
information for.
@type username: L{str}
"""
if spwd is not None:
f = spwd.getspnam
else:
return None
return runAsEffectiveUser(0, 0, f, username)
@implementer(ICredentialsChecker)
class UNIXPasswordDatabase:
"""
A checker which validates users out of the UNIX password databases, or
databases of a compatible format.
@ivar _getByNameFunctions: a C{list} of functions which are called in order
to valid a user. The default value is such that the C{/etc/passwd}
database will be tried first, followed by the C{/etc/shadow} database.
"""
credentialInterfaces = IUsernamePassword,
def __init__(self, getByNameFunctions=None):
if getByNameFunctions is None:
getByNameFunctions = [_pwdGetByName, _shadowGetByName]
self._getByNameFunctions = getByNameFunctions
def requestAvatarId(self, credentials):
# We get bytes, but the Py3 pwd module uses str. So attempt to decode
# it using the same method that CPython does for the file on disk.
if _PY3:
username = credentials.username.decode(sys.getfilesystemencoding())
password = credentials.password.decode(sys.getfilesystemencoding())
else:
username = credentials.username
password = credentials.password
for func in self._getByNameFunctions:
try:
pwnam = func(username)
except KeyError:
return defer.fail(UnauthorizedLogin("invalid username"))
else:
if pwnam is not None:
crypted = pwnam[1]
if crypted == '':
continue
if verifyCryptedPassword(crypted, password):
return defer.succeed(credentials.username)
# fallback
return defer.fail(UnauthorizedLogin("unable to verify password"))
@implementer(ICredentialsChecker)
class SSHPublicKeyDatabase:
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
"""
credentialInterfaces = (ISSHPrivateKey,)
_userdb = pwd
def requestAvatarId(self, credentials):
d = defer.maybeDeferred(self.checkKey, credentials)
d.addCallback(self._cbRequestAvatarId, credentials)
d.addErrback(self._ebRequestAvatarId)
return d
def _cbRequestAvatarId(self, validKey, credentials):
"""
Check whether the credentials themselves are valid, now that we know
if the key matches the user.
@param validKey: A boolean indicating whether or not the public key
matches a key in the user's authorized_keys file.
@param credentials: The credentials offered by the user.
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: (as a failure) if the key does not match the
user in C{credentials}. Also raised if the user provides an invalid
signature.
@raise ValidPublicKey: (as a failure) if the key matches the user but
the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@return: The user's username, if authentication was successful.
"""
if not validKey:
return failure.Failure(UnauthorizedLogin("invalid key"))
if not credentials.signature:
return failure.Failure(error.ValidPublicKey())
else:
try:
pubKey = keys.Key.fromString(credentials.blob)
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except: # any error should be treated as a failed login
log.err()
return failure.Failure(UnauthorizedLogin('error while verifying key'))
return failure.Failure(UnauthorizedLogin("unable to verify key"))
def getAuthorizedKeysFiles(self, credentials):
"""
Return a list of L{FilePath} instances for I{authorized_keys} files
which might contain information about authorized keys for the given
credentials.
On OpenSSH servers, the default location of the file containing the
list of authorized public keys is
U{$HOME/.ssh/authorized_keys<http://www.openbsd.org/cgi-bin/man.cgi?query=sshd_config>}.
I{$HOME/.ssh/authorized_keys2} is also returned, though it has been
U{deprecated by OpenSSH since
2001<http://marc.info/?m=100508718416162>}.
@return: A list of L{FilePath} instances to files with the authorized keys.
"""
pwent = self._userdb.getpwnam(credentials.username)
root = FilePath(pwent.pw_dir).child('.ssh')
files = ['authorized_keys', 'authorized_keys2']
return [root.child(f) for f in files]
def checkKey(self, credentials):
"""
Retrieve files containing authorized keys and check against user
credentials.
"""
ouid, ogid = self._userdb.getpwnam(credentials.username)[2:4]
for filepath in self.getAuthorizedKeysFiles(credentials):
if not filepath.exists():
continue
try:
lines = filepath.open()
except IOError as e:
if e.errno == errno.EACCES:
lines = runAsEffectiveUser(ouid, ogid, filepath.open)
else:
raise
with lines:
for l in lines:
l2 = l.split()
if len(l2) < 2:
continue
try:
if _b64decodebytes(l2[1]) == credentials.blob:
return True
except binascii.Error:
continue
return False
def _ebRequestAvatarId(self, f):
if not f.check(UnauthorizedLogin):
log.msg(f)
return failure.Failure(UnauthorizedLogin("unable to get avatar id"))
return f
@implementer(ICredentialsChecker)
class SSHProtocolChecker:
"""
SSHProtocolChecker is a checker that requires multiple authentications
to succeed. To add a checker, call my registerChecker method with
the checker and the interface.
After each successful authenticate, I call my areDone method with the
avatar id. To get a list of the successful credentials for an avatar id,
use C{SSHProcotolChecker.successfulCredentials[avatarId]}. If L{areDone}
returns True, the authentication has succeeded.
"""
def __init__(self):
self.checkers = {}
self.successfulCredentials = {}
def get_credentialInterfaces(self):
return _keys(self.checkers)
credentialInterfaces = property(get_credentialInterfaces)
def registerChecker(self, checker, *credentialInterfaces):
if not credentialInterfaces:
credentialInterfaces = checker.credentialInterfaces
for credentialInterface in credentialInterfaces:
self.checkers[credentialInterface] = checker
def requestAvatarId(self, credentials):
"""
Part of the L{ICredentialsChecker} interface. Called by a portal with
some credentials to check if they'll authenticate a user. We check the
interfaces that the credentials provide against our list of acceptable
checkers. If one of them matches, we ask that checker to verify the
credentials. If they're valid, we call our L{_cbGoodAuthentication}
method to continue.
@param credentials: the credentials the L{Portal} wants us to verify
"""
ifac = providedBy(credentials)
for i in ifac:
c = self.checkers.get(i)
if c is not None:
d = defer.maybeDeferred(c.requestAvatarId, credentials)
return d.addCallback(self._cbGoodAuthentication,
credentials)
return defer.fail(UnhandledCredentials("No checker for %s" % \
', '.join(map(reflect.qual, ifac))))
def _cbGoodAuthentication(self, avatarId, credentials):
"""
Called if a checker has verified the credentials. We call our
L{areDone} method to see if the whole of the successful authentications
are enough. If they are, we return the avatar ID returned by the first
checker.
"""
if avatarId not in self.successfulCredentials:
self.successfulCredentials[avatarId] = []
self.successfulCredentials[avatarId].append(credentials)
if self.areDone(avatarId):
del self.successfulCredentials[avatarId]
return avatarId
else:
raise error.NotEnoughAuthentication()
def areDone(self, avatarId):
"""
Override to determine if the authentication is finished for a given
avatarId.
@param avatarId: the avatar returned by the first checker. For
this checker to function correctly, all the checkers must
return the same avatar ID.
"""
return True
deprecatedModuleAttribute(
Version("Twisted", 15, 0, 0),
("Please use twisted.conch.checkers.SSHPublicKeyChecker, "
"initialized with an instance of "
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead."),
__name__, "SSHPublicKeyDatabase")
class IAuthorizedKeysDB(Interface):
"""
An object that provides valid authorized ssh keys mapped to usernames.
@since: 15.0
"""
def getAuthorizedKeys(avatarId):
"""
Gets an iterable of authorized keys that are valid for the given
C{avatarId}.
@param avatarId: the ID of the avatar
@type avatarId: valid return value of
L{twisted.cred.checkers.ICredentialsChecker.requestAvatarId}
@return: an iterable of L{twisted.conch.ssh.keys.Key}
"""
def readAuthorizedKeyFile(fileobj, parseKey=keys.Key.fromString):
"""
Reads keys from an authorized keys file. Any non-comment line that cannot
be parsed as a key will be ignored, although that particular line will
be logged.
@param fileobj: something from which to read lines which can be parsed
as keys
@type fileobj: L{file}-like object
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
@type parseKey: L{callable}
@return: an iterable of L{twisted.conch.ssh.keys.Key}
@rtype: iterable
@since: 15.0
"""
for line in fileobj:
line = line.strip()
if line and not line.startswith(b'#'): # for comments
try:
yield parseKey(line)
except keys.BadKeyError as e:
log.msg('Unable to parse line "{0}" as a key: {1!s}'
.format(line, e))
def _keysFromFilepaths(filepaths, parseKey):
"""
Helper function that turns an iterable of filepaths into a generator of
keys. If any file cannot be read, a message is logged but it is
otherwise ignored.
@param filepaths: iterable of L{twisted.python.filepath.FilePath}.
@type filepaths: iterable
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}
@type parseKey: L{callable}
@return: generator of L{twisted.conch.ssh.keys.Key}
@rtype: generator
@since: 15.0
"""
for fp in filepaths:
if fp.exists():
try:
with fp.open() as f:
for key in readAuthorizedKeyFile(f, parseKey):
yield key
except (IOError, OSError) as e:
log.msg("Unable to read {0}: {1!s}".format(fp.path, e))
@implementer(IAuthorizedKeysDB)
class InMemorySSHKeyDB(object):
"""
Object that provides SSH public keys based on a dictionary of usernames
mapped to L{twisted.conch.ssh.keys.Key}s.
@since: 15.0
"""
def __init__(self, mapping):
"""
Initializes a new L{InMemorySSHKeyDB}.
@param mapping: mapping of usernames to iterables of
L{twisted.conch.ssh.keys.Key}s
@type mapping: L{dict}
"""
self._mapping = mapping
def getAuthorizedKeys(self, username):
return self._mapping.get(username, [])
@implementer(IAuthorizedKeysDB)
class UNIXAuthorizedKeysFiles(object):
"""
Object that provides SSH public keys based on public keys listed in
authorized_keys and authorized_keys2 files in UNIX user .ssh/ directories.
If any of the files cannot be read, a message is logged but that file is
otherwise ignored.
@since: 15.0
"""
def __init__(self, userdb=None, parseKey=keys.Key.fromString):
"""
Initializes a new L{UNIXAuthorizedKeysFiles}.
@param userdb: access to the Unix user account and password database
(default is the Python module L{pwd})
@type userdb: L{pwd}-like object
@param parseKey: a callable that takes a string and returns a
L{twisted.conch.ssh.keys.Key}, mainly to be used for testing. The
default is L{twisted.conch.ssh.keys.Key.fromString}.
@type parseKey: L{callable}
"""
self._userdb = userdb
self._parseKey = parseKey
if userdb is None:
self._userdb = pwd
def getAuthorizedKeys(self, username):
try:
passwd = self._userdb.getpwnam(username)
except KeyError:
return ()
root = FilePath(passwd.pw_dir).child('.ssh')
files = ['authorized_keys', 'authorized_keys2']
return _keysFromFilepaths((root.child(f) for f in files),
self._parseKey)
@implementer(ICredentialsChecker)
class SSHPublicKeyChecker(object):
"""
Checker that authenticates SSH public keys, based on public keys listed in
authorized_keys and authorized_keys2 files in user .ssh/ directories.
Initializing this checker with a L{UNIXAuthorizedKeysFiles} should be
used instead of L{twisted.conch.checkers.SSHPublicKeyDatabase}.
@since: 15.0
"""
credentialInterfaces = (ISSHPrivateKey,)
def __init__(self, keydb):
"""
Initializes a L{SSHPublicKeyChecker}.
@param keydb: a provider of L{IAuthorizedKeysDB}
@type keydb: L{IAuthorizedKeysDB} provider
"""
self._keydb = keydb
def requestAvatarId(self, credentials):
d = defer.maybeDeferred(self._sanityCheckKey, credentials)
d.addCallback(self._checkKey, credentials)
d.addCallback(self._verifyKey, credentials)
return d
def _sanityCheckKey(self, credentials):
"""
Checks whether the provided credentials are a valid SSH key with a
signature (does not actually verify the signature).
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise ValidPublicKey: the credentials do not include a signature. See
L{error.ValidPublicKey} for more information.
@raise BadKeyError: The key included with the credentials is not
recognized as a key.
@return: the key in the credentials
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if not credentials.signature:
raise error.ValidPublicKey()
return keys.Key.fromString(credentials.blob)
def _checkKey(self, pubKey, credentials):
"""
Checks the public key against all authorized keys (if any) for the
user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
@type pubKey:
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key is not authorized, or if there
was any error obtaining a list of authorized keys for the user.
@return: C{pubKey} if the key is authorized
@rtype: L{twisted.conch.ssh.keys.Key}
"""
if any(key == pubKey for key in
self._keydb.getAuthorizedKeys(credentials.username)):
return pubKey
raise UnauthorizedLogin("Key not authorized")
def _verifyKey(self, pubKey, credentials):
"""
Checks whether the credentials themselves are valid, now that we know
if the key matches the user.
@param pubKey: the key in the credentials (just to prevent it from
having to be calculated again)
@type pubKey: L{twisted.conch.ssh.keys.Key}
@param credentials: the credentials offered by the user
@type credentials: L{ISSHPrivateKey} provider
@raise UnauthorizedLogin: If the key signature is invalid or there
was any error verifying the signature.
@return: The user's username, if authentication was successful
@rtype: L{bytes}
"""
try:
if pubKey.verify(credentials.signature, credentials.sigData):
return credentials.username
except: # Any error should be treated as a failed login
log.err()
raise UnauthorizedLogin('Error while verifying key')
raise UnauthorizedLogin("Key signature invalid.")

View file

@ -0,0 +1,9 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Client support code for Conch.
Maintainer: Paul Swartz
"""

View file

@ -0,0 +1,73 @@
# -*- test-case-name: twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Accesses the key agent for user authentication.
Maintainer: Paul Swartz
"""
import os
from twisted.conch.ssh import agent, channel, keys
from twisted.internet import protocol, reactor
from twisted.python import log
class SSHAgentClient(agent.SSHAgentClient):
def __init__(self):
agent.SSHAgentClient.__init__(self)
self.blobs = []
def getPublicKeys(self):
return self.requestIdentities().addCallback(self._cbPublicKeys)
def _cbPublicKeys(self, blobcomm):
log.msg('got %i public keys' % len(blobcomm))
self.blobs = [x[0] for x in blobcomm]
def getPublicKey(self):
"""
Return a L{Key} from the first blob in C{self.blobs}, if any, or
return L{None}.
"""
if self.blobs:
return keys.Key.fromString(self.blobs.pop(0))
return None
class SSHAgentForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal)
d = cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
d.addCallback(self._cbGotLocal)
d.addErrback(lambda x:self.loseConnection())
self.buf = ''
def _cbGotLocal(self, local):
self.local = local
self.dataReceived = self.local.transport.write
self.local.dataReceived = self.write
def dataReceived(self, data):
self.buf += data
def closed(self):
if self.local:
self.local.loseConnection()
self.local = None
class SSHAgentForwardingLocal(protocol.Protocol):
pass

View file

@ -0,0 +1,21 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
from twisted.conch.client import direct
connectTypes = {"direct" : direct.connect}
def connect(host, port, options, verifyHostKey, userAuthObject):
useConnects = ['direct']
return _ebConnect(None, useConnects, host, port, options, verifyHostKey,
userAuthObject)
def _ebConnect(f, useConnects, host, port, options, vhk, uao):
if not useConnects:
return f
connectType = useConnects.pop(0)
f = connectTypes[connectType]
d = f(host, port, options, vhk, uao)
d.addErrback(_ebConnect, useConnects, host, port, options, vhk, uao)
return d

View file

@ -0,0 +1,349 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts,twisted.conch.test.test_default -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Various classes and functions for implementing user-interaction in the
command-line conch client.
You probably shouldn't use anything in this module directly, since it assumes
you are sitting at an interactive terminal. For example, to programmatically
interact with a known_hosts database, use L{twisted.conch.client.knownhosts}.
"""
from __future__ import print_function
from twisted.python import log
from twisted.python.compat import (
nativeString, raw_input, _PY3, _b64decodebytes as decodebytes)
from twisted.python.filepath import FilePath
from twisted.conch.error import ConchError
from twisted.conch.ssh import common, keys, userauth
from twisted.internet import defer, protocol, reactor
from twisted.conch.client.knownhosts import KnownHostsFile, ConsoleUI
from twisted.conch.client import agent
import os, sys, getpass, contextlib
if _PY3:
import io
# The default location of the known hosts file (probably should be parsed out
# of an ssh config file someday).
_KNOWN_HOSTS = "~/.ssh/known_hosts"
# This name is bound so that the unit tests can use 'patch' to override it.
_open = open
def verifyHostKey(transport, host, pubKey, fingerprint):
"""
Verify a host's key.
This function is a gross vestige of some bad factoring in the client
internals. The actual implementation, and a better signature of this logic
is in L{KnownHostsFile.verifyHostKey}. This function is not deprecated yet
because the callers have not yet been rehabilitated, but they should
eventually be changed to call that method instead.
However, this function does perform two functions not implemented by
L{KnownHostsFile.verifyHostKey}. It determines the path to the user's
known_hosts file based on the options (which should really be the options
object's job), and it provides an opener to L{ConsoleUI} which opens
'/dev/tty' so that the user will be prompted on the tty of the process even
if the input and output of the process has been redirected. This latter
part is, somewhat obviously, not portable, but I don't know of a portable
equivalent that could be used.
@param host: Due to a bug in L{SSHClientTransport.verifyHostKey}, this is
always the dotted-quad IP address of the host being connected to.
@type host: L{str}
@param transport: the client transport which is attempting to connect to
the given host.
@type transport: L{SSHClientTransport}
@param fingerprint: the fingerprint of the given public key, in
xx:xx:xx:... format. This is ignored in favor of getting the fingerprint
from the key itself.
@type fingerprint: L{str}
@param pubKey: The public key of the server being connected to.
@type pubKey: L{str}
@return: a L{Deferred} which fires with C{1} if the key was successfully
verified, or fails if the key could not be successfully verified. Failure
types may include L{HostKeyChanged}, L{UserRejectedKey}, L{IOError} or
L{KeyboardInterrupt}.
"""
actualHost = transport.factory.options['host']
actualKey = keys.Key.fromString(pubKey)
kh = KnownHostsFile.fromPath(FilePath(
transport.factory.options['known-hosts']
or os.path.expanduser(_KNOWN_HOSTS)
))
ui = ConsoleUI(lambda : _open("/dev/tty", "r+b", buffering=0))
return kh.verifyHostKey(ui, actualHost, host, actualKey)
def isInKnownHosts(host, pubKey, options):
"""
Checks to see if host is in the known_hosts file for the user.
@return: 0 if it isn't, 1 if it is and is the same, 2 if it's changed.
@rtype: L{int}
"""
keyType = common.getNS(pubKey)[0]
retVal = 0
if not options['known-hosts'] and not os.path.exists(os.path.expanduser('~/.ssh/')):
print('Creating ~/.ssh directory...')
os.mkdir(os.path.expanduser('~/.ssh'))
kh_file = options['known-hosts'] or _KNOWN_HOSTS
try:
known_hosts = open(os.path.expanduser(kh_file), 'rb')
except IOError:
return 0
with known_hosts:
for line in known_hosts.readlines():
split = line.split()
if len(split) < 3:
continue
hosts, hostKeyType, encodedKey = split[:3]
if host not in hosts.split(b','): # incorrect host
continue
if hostKeyType != keyType: # incorrect type of key
continue
try:
decodedKey = decodebytes(encodedKey)
except:
continue
if decodedKey == pubKey:
return 1
else:
retVal = 2
return retVal
def getHostKeyAlgorithms(host, options):
"""
Look in known_hosts for a key corresponding to C{host}.
This can be used to change the order of supported key types
in the KEXINIT packet.
@type host: L{str}
@param host: the host to check in known_hosts
@type options: L{twisted.conch.client.options.ConchOptions}
@param options: options passed to client
@return: L{list} of L{str} representing key types or L{None}.
"""
knownHosts = KnownHostsFile.fromPath(FilePath(
options['known-hosts']
or os.path.expanduser(_KNOWN_HOSTS)
))
keyTypes = []
for entry in knownHosts.iterentries():
if entry.matchesHost(host):
if entry.keyType not in keyTypes:
keyTypes.append(entry.keyType)
return keyTypes or None
class SSHUserAuthClient(userauth.SSHUserAuthClient):
def __init__(self, user, options, *args):
userauth.SSHUserAuthClient.__init__(self, user, *args)
self.keyAgent = None
self.options = options
self.usedFiles = []
if not options.identitys:
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
def serviceStarted(self):
if 'SSH_AUTH_SOCK' in os.environ and not self.options['noagent']:
log.msg('using agent')
cc = protocol.ClientCreator(reactor, agent.SSHAgentClient)
d = cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
d.addCallback(self._setAgent)
d.addErrback(self._ebSetAgent)
else:
userauth.SSHUserAuthClient.serviceStarted(self)
def serviceStopped(self):
if self.keyAgent:
self.keyAgent.transport.loseConnection()
self.keyAgent = None
def _setAgent(self, a):
self.keyAgent = a
d = self.keyAgent.getPublicKeys()
d.addBoth(self._ebSetAgent)
return d
def _ebSetAgent(self, f):
userauth.SSHUserAuthClient.serviceStarted(self)
def _getPassword(self, prompt):
"""
Prompt for a password using L{getpass.getpass}.
@param prompt: Written on tty to ask for the input.
@type prompt: L{str}
@return: The input.
@rtype: L{str}
"""
with self._replaceStdoutStdin():
try:
p = getpass.getpass(prompt)
return p
except (KeyboardInterrupt, IOError):
print()
raise ConchError('PEBKAC')
def getPassword(self, prompt = None):
if prompt:
prompt = nativeString(prompt)
else:
prompt = ("%s@%s's password: " %
(nativeString(self.user), self.transport.transport.getPeer().host))
try:
# We don't know the encoding the other side is using,
# signaling that is not part of the SSH protocol. But
# using our defaultencoding is better than just going for
# ASCII.
p = self._getPassword(prompt).encode(sys.getdefaultencoding())
return defer.succeed(p)
except ConchError:
return defer.fail()
def getPublicKey(self):
"""
Get a public key from the key agent if possible, otherwise look in
the next configured identity file for one.
"""
if self.keyAgent:
key = self.keyAgent.getPublicKey()
if key is not None:
return key
files = [x for x in self.options.identitys if x not in self.usedFiles]
log.msg(str(self.options.identitys))
log.msg(str(files))
if not files:
return None
file = files[0]
log.msg(file)
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += '.pub'
if not os.path.exists(file):
return self.getPublicKey() # try again
try:
return keys.Key.fromFile(file)
except keys.BadKeyError:
return self.getPublicKey() # try again
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: L{bytes}
"""
if not self.usedFiles: # agent key
return self.keyAgent.signData(publicKey.blob(), signData)
else:
return userauth.SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Try to load the private key from the last used file identified by
C{getPublicKey}, potentially asking for the passphrase if the key is
encrypted.
"""
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file))
except keys.EncryptedKeyError:
for i in range(3):
prompt = "Enter passphrase for key '%s': " % self.usedFiles[-1]
try:
p = self._getPassword(prompt).encode(
sys.getfilesystemencoding())
return defer.succeed(keys.Key.fromFile(file, passphrase=p))
except (keys.BadKeyError, ConchError):
pass
return defer.fail(ConchError('bad password'))
raise
except KeyboardInterrupt:
print()
reactor.stop()
def getGenericAnswers(self, name, instruction, prompts):
responses = []
with self._replaceStdoutStdin():
if name:
print(name.decode("utf-8"))
if instruction:
print(instruction.decode("utf-8"))
for prompt, echo in prompts:
prompt = prompt.decode("utf-8")
if echo:
responses.append(raw_input(prompt))
else:
responses.append(getpass.getpass(prompt))
return defer.succeed(responses)
@classmethod
def _openTty(cls):
"""
Open /dev/tty as two streams one in read, one in write mode,
and return them.
@return: File objects for reading and writing to /dev/tty,
corresponding to standard input and standard output.
@rtype: A L{tuple} of L{io.TextIOWrapper} on Python 3.
A L{tuple} of binary files on Python 2.
"""
stdin = open("/dev/tty", "rb")
stdout = open("/dev/tty", "wb")
if _PY3:
stdin = io.TextIOWrapper(stdin)
stdout = io.TextIOWrapper(stdout)
return stdin, stdout
@classmethod
@contextlib.contextmanager
def _replaceStdoutStdin(cls):
"""
Contextmanager that replaces stdout and stdin with /dev/tty
and resets them when it is done.
"""
oldout, oldin = sys.stdout, sys.stdin
sys.stdin, sys.stdout = cls._openTty()
try:
yield
finally:
sys.stdout.close()
sys.stdin.close()
sys.stdout, sys.stdin = oldout, oldin

View file

@ -0,0 +1,109 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import print_function
from twisted.internet import defer, protocol, reactor
from twisted.conch import error
from twisted.conch.ssh import transport
from twisted.python import log
class SSHClientFactory(protocol.ClientFactory):
def __init__(self, d, options, verifyHostKey, userAuthObject):
self.d = d
self.options = options
self.verifyHostKey = verifyHostKey
self.userAuthObject = userAuthObject
def clientConnectionLost(self, connector, reason):
if self.options['reconnect']:
connector.connect()
def clientConnectionFailed(self, connector, reason):
if self.d is None:
return
d, self.d = self.d, None
d.errback(reason)
def buildProtocol(self, addr):
trans = SSHClientTransport(self)
if self.options['ciphers']:
trans.supportedCiphers = self.options['ciphers']
if self.options['macs']:
trans.supportedMACs = self.options['macs']
if self.options['compress']:
trans.supportedCompressions[0:1] = ['zlib']
if self.options['host-key-algorithms']:
trans.supportedPublicKeys = self.options['host-key-algorithms']
return trans
class SSHClientTransport(transport.SSHClientTransport):
def __init__(self, factory):
self.factory = factory
self.unixServer = None
def connectionLost(self, reason):
if self.unixServer:
d = self.unixServer.stopListening()
self.unixServer = None
else:
d = defer.succeed(None)
d.addCallback(lambda x:
transport.SSHClientTransport.connectionLost(self, reason))
def receiveError(self, code, desc):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
d.errback(error.ConchError(desc, code))
def sendDisconnect(self, code, reason):
if self.factory.d is None:
return
d, self.factory.d = self.factory.d, None
transport.SSHClientTransport.sendDisconnect(self, code, reason)
d.errback(error.ConchError(reason, code))
def receiveDebug(self, alwaysDisplay, message, lang):
log.msg('Received Debug Message: %s' % message)
if alwaysDisplay: # XXX what should happen here?
print(message)
def verifyHostKey(self, pubKey, fingerprint):
return self.factory.verifyHostKey(self, self.transport.getPeer().host, pubKey,
fingerprint)
def setService(self, service):
log.msg('setting client server to %s' % service)
transport.SSHClientTransport.setService(self, service)
if service.name != 'ssh-userauth' and self.factory.d is not None:
d, self.factory.d = self.factory.d, None
d.callback(None)
def connectionSecure(self):
self.requestService(self.factory.userAuthObject)
def connect(host, port, options, verifyHostKey, userAuthObject):
d = defer.Deferred()
factory = SSHClientFactory(d, options, verifyHostKey, userAuthObject)
reactor.connectTCP(host, port, factory)
return d

View file

@ -0,0 +1,630 @@
# -*- test-case-name: twisted.conch.test.test_knownhosts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An implementation of the OpenSSH known_hosts database.
@since: 8.2
"""
from __future__ import absolute_import, division
import hmac
from binascii import Error as DecodeError, b2a_base64, a2b_base64
from contextlib import closing
from hashlib import sha1
import sys
from zope.interface import implementer
from twisted.conch.interfaces import IKnownHostEntry
from twisted.conch.error import HostKeyChanged, UserRejectedKey, InvalidEntry
from twisted.conch.ssh.keys import Key, BadKeyError, FingerprintFormats
from twisted.internet import defer
from twisted.python import log
from twisted.python.compat import nativeString, unicode
from twisted.python.randbytes import secureRandom
from twisted.python.util import FancyEqMixin
def _b64encode(s):
"""
Encode a binary string as base64 with no trailing newline.
@param s: The string to encode.
@type s: L{bytes}
@return: The base64-encoded string.
@rtype: L{bytes}
"""
return b2a_base64(s).strip()
def _extractCommon(string):
"""
Extract common elements of base64 keys from an entry in a hosts file.
@param string: A known hosts file entry (a single line).
@type string: L{bytes}
@return: a 4-tuple of hostname data (L{bytes}), ssh key type (L{bytes}), key
(L{Key}), and comment (L{bytes} or L{None}). The hostname data is
simply the beginning of the line up to the first occurrence of
whitespace.
@rtype: L{tuple}
"""
elements = string.split(None, 2)
if len(elements) != 3:
raise InvalidEntry()
hostnames, keyType, keyAndComment = elements
splitkey = keyAndComment.split(None, 1)
if len(splitkey) == 2:
keyString, comment = splitkey
comment = comment.rstrip(b"\n")
else:
keyString = splitkey[0]
comment = None
key = Key.fromString(a2b_base64(keyString))
return hostnames, keyType, key, comment
class _BaseEntry(object):
"""
Abstract base of both hashed and non-hashed entry objects, since they
represent keys and key types the same way.
@ivar keyType: The type of the key; either ssh-dss or ssh-rsa.
@type keyType: L{bytes}
@ivar publicKey: The server public key indicated by this line.
@type publicKey: L{twisted.conch.ssh.keys.Key}
@ivar comment: Trailing garbage after the key line.
@type comment: L{bytes}
"""
def __init__(self, keyType, publicKey, comment):
self.keyType = keyType
self.publicKey = publicKey
self.comment = comment
def matchesKey(self, keyObject):
"""
Check to see if this entry matches a given key object.
@param keyObject: A public key object to check.
@type keyObject: L{Key}
@return: C{True} if this entry's key matches C{keyObject}, C{False}
otherwise.
@rtype: L{bool}
"""
return self.publicKey == keyObject
@implementer(IKnownHostEntry)
class PlainEntry(_BaseEntry):
"""
A L{PlainEntry} is a representation of a plain-text entry in a known_hosts
file.
@ivar _hostnames: the list of all host-names associated with this entry.
@type _hostnames: L{list} of L{bytes}
"""
def __init__(self, hostnames, keyType, publicKey, comment):
self._hostnames = hostnames
super(PlainEntry, self).__init__(keyType, publicKey, comment)
@classmethod
def fromString(cls, string):
"""
Parse a plain-text entry in a known_hosts file, and return a
corresponding L{PlainEntry}.
@param string: a space-separated string formatted like "hostname
key-type base64-key-data comment".
@type string: L{bytes}
@raise DecodeError: if the key is not valid encoded as valid base64.
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: an IKnownHostEntry representing the hostname and key in the
input line.
@rtype: L{PlainEntry}
"""
hostnames, keyType, key, comment = _extractCommon(string)
self = cls(hostnames.split(b","), keyType, key, comment)
return self
def matchesHost(self, hostname):
"""
Check to see if this entry matches a given hostname.
@param hostname: A hostname or IP address literal to check against this
entry.
@type hostname: L{bytes}
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
@rtype: L{bool}
"""
if isinstance(hostname, unicode):
hostname = hostname.encode("utf-8")
return hostname in self._hostnames
def toString(self):
"""
Implement L{IKnownHostEntry.toString} by recording the comma-separated
hostnames, key type, and base-64 encoded key.
@return: The string representation of this entry, with unhashed hostname
information.
@rtype: L{bytes}
"""
fields = [b','.join(self._hostnames),
self.keyType,
_b64encode(self.publicKey.blob())]
if self.comment is not None:
fields.append(self.comment)
return b' '.join(fields)
@implementer(IKnownHostEntry)
class UnparsedEntry(object):
"""
L{UnparsedEntry} is an entry in a L{KnownHostsFile} which can't actually be
parsed; therefore it matches no keys and no hosts.
"""
def __init__(self, string):
"""
Create an unparsed entry from a line in a known_hosts file which cannot
otherwise be parsed.
"""
self._string = string
def matchesHost(self, hostname):
"""
Always returns False.
"""
return False
def matchesKey(self, key):
"""
Always returns False.
"""
return False
def toString(self):
"""
Returns the input line, without its newline if one was given.
@return: The string representation of this entry, almost exactly as was
used to initialize this entry but without a trailing newline.
@rtype: L{bytes}
"""
return self._string.rstrip(b"\n")
def _hmacedString(key, string):
"""
Return the SHA-1 HMAC hash of the given key and string.
@param key: The HMAC key.
@type key: L{bytes}
@param string: The string to be hashed.
@type string: L{bytes}
@return: The keyed hash value.
@rtype: L{bytes}
"""
hash = hmac.HMAC(key, digestmod=sha1)
if isinstance(string, unicode):
string = string.encode("utf-8")
hash.update(string)
return hash.digest()
@implementer(IKnownHostEntry)
class HashedEntry(_BaseEntry, FancyEqMixin):
"""
A L{HashedEntry} is a representation of an entry in a known_hosts file
where the hostname has been hashed and salted.
@ivar _hostSalt: the salt to combine with a hostname for hashing.
@ivar _hostHash: the hashed representation of the hostname.
@cvar MAGIC: the 'hash magic' string used to identify a hashed line in a
known_hosts file as opposed to a plaintext one.
"""
MAGIC = b'|1|'
compareAttributes = (
"_hostSalt", "_hostHash", "keyType", "publicKey", "comment")
def __init__(self, hostSalt, hostHash, keyType, publicKey, comment):
self._hostSalt = hostSalt
self._hostHash = hostHash
super(HashedEntry, self).__init__(keyType, publicKey, comment)
@classmethod
def fromString(cls, string):
"""
Load a hashed entry from a string representing a line in a known_hosts
file.
@param string: A complete single line from a I{known_hosts} file,
formatted as defined by OpenSSH.
@type string: L{bytes}
@raise DecodeError: if the key, the hostname, or the is not valid
encoded as valid base64
@raise InvalidEntry: if the entry does not have the right number of
elements and is therefore invalid, or the host/hash portion contains
more items than just the host and hash.
@raise BadKeyError: if the key, once decoded from base64, is not
actually an SSH key.
@return: The newly created L{HashedEntry} instance, initialized with the
information from C{string}.
"""
stuff, keyType, key, comment = _extractCommon(string)
saltAndHash = stuff[len(cls.MAGIC):].split(b"|")
if len(saltAndHash) != 2:
raise InvalidEntry()
hostSalt, hostHash = saltAndHash
self = cls(a2b_base64(hostSalt), a2b_base64(hostHash),
keyType, key, comment)
return self
def matchesHost(self, hostname):
"""
Implement L{IKnownHostEntry.matchesHost} to compare the hash of the
input to the stored hash.
@param hostname: A hostname or IP address literal to check against this
entry.
@type hostname: L{bytes}
@return: C{True} if this entry is for the given hostname or IP address,
C{False} otherwise.
@rtype: L{bool}
"""
return (_hmacedString(self._hostSalt, hostname) == self._hostHash)
def toString(self):
"""
Implement L{IKnownHostEntry.toString} by base64-encoding the salt, host
hash, and key.
@return: The string representation of this entry, with the hostname part
hashed.
@rtype: L{bytes}
"""
fields = [self.MAGIC + b'|'.join([_b64encode(self._hostSalt),
_b64encode(self._hostHash)]),
self.keyType,
_b64encode(self.publicKey.blob())]
if self.comment is not None:
fields.append(self.comment)
return b' '.join(fields)
class KnownHostsFile(object):
"""
A structured representation of an OpenSSH-format ~/.ssh/known_hosts file.
@ivar _added: A list of L{IKnownHostEntry} providers which have been added
to this instance in memory but not yet saved.
@ivar _clobber: A flag indicating whether the current contents of the save
path will be disregarded and potentially overwritten or not. If
C{True}, this will be done. If C{False}, entries in the save path will
be read and new entries will be saved by appending rather than
overwriting.
@type _clobber: L{bool}
@ivar _savePath: See C{savePath} parameter of L{__init__}.
"""
def __init__(self, savePath):
"""
Create a new, empty KnownHostsFile.
Unless you want to erase the current contents of C{savePath}, you want
to use L{KnownHostsFile.fromPath} instead.
@param savePath: The L{FilePath} to which to save new entries.
@type savePath: L{FilePath}
"""
self._added = []
self._savePath = savePath
self._clobber = True
@property
def savePath(self):
"""
@see: C{savePath} parameter of L{__init__}
"""
return self._savePath
def iterentries(self):
"""
Iterate over the host entries in this file.
@return: An iterable the elements of which provide L{IKnownHostEntry}.
There is an element for each entry in the file as well as an element
for each added but not yet saved entry.
@rtype: iterable of L{IKnownHostEntry} providers
"""
for entry in self._added:
yield entry
if self._clobber:
return
try:
fp = self._savePath.open()
except IOError:
return
with fp:
for line in fp:
try:
if line.startswith(HashedEntry.MAGIC):
entry = HashedEntry.fromString(line)
else:
entry = PlainEntry.fromString(line)
except (DecodeError, InvalidEntry, BadKeyError):
entry = UnparsedEntry(line)
yield entry
def hasHostKey(self, hostname, key):
"""
Check for an entry with matching hostname and key.
@param hostname: A hostname or IP address literal to check for.
@type hostname: L{bytes}
@param key: The public key to check for.
@type key: L{Key}
@return: C{True} if the given hostname and key are present in this file,
C{False} if they are not.
@rtype: L{bool}
@raise HostKeyChanged: if the host key found for the given hostname
does not match the given key.
"""
for lineidx, entry in enumerate(self.iterentries(), -len(self._added)):
if entry.matchesHost(hostname) and entry.keyType == key.sshType():
if entry.matchesKey(key):
return True
else:
# Notice that lineidx is 0-based but HostKeyChanged.lineno
# is 1-based.
if lineidx < 0:
line = None
path = None
else:
line = lineidx + 1
path = self._savePath
raise HostKeyChanged(entry, path, line)
return False
def verifyHostKey(self, ui, hostname, ip, key):
"""
Verify the given host key for the given IP and host, asking for
confirmation from, and notifying, the given UI about changes to this
file.
@param ui: The user interface to request an IP address from.
@param hostname: The hostname that the user requested to connect to.
@param ip: The string representation of the IP address that is actually
being connected to.
@param key: The public key of the server.
@return: a L{Deferred} that fires with True when the key has been
verified, or fires with an errback when the key either cannot be
verified or has changed.
@rtype: L{Deferred}
"""
hhk = defer.maybeDeferred(self.hasHostKey, hostname, key)
def gotHasKey(result):
if result:
if not self.hasHostKey(ip, key):
ui.warn("Warning: Permanently added the %s host key for "
"IP address '%s' to the list of known hosts." %
(key.type(), nativeString(ip)))
self.addHostKey(ip, key)
self.save()
return result
else:
def promptResponse(response):
if response:
self.addHostKey(hostname, key)
self.addHostKey(ip, key)
self.save()
return response
else:
raise UserRejectedKey()
keytype = key.type()
if keytype == "EC":
keytype = "ECDSA"
prompt = (
"The authenticity of host '%s (%s)' "
"can't be established.\n"
"%s key fingerprint is SHA256:%s.\n"
"Are you sure you want to continue connecting (yes/no)? " %
(nativeString(hostname), nativeString(ip), keytype,
key.fingerprint(format=FingerprintFormats.SHA256_BASE64)))
proceed = ui.prompt(prompt.encode(sys.getdefaultencoding()))
return proceed.addCallback(promptResponse)
return hhk.addCallback(gotHasKey)
def addHostKey(self, hostname, key):
"""
Add a new L{HashedEntry} to the key database.
Note that you still need to call L{KnownHostsFile.save} if you wish
these changes to be persisted.
@param hostname: A hostname or IP address literal to associate with the
new entry.
@type hostname: L{bytes}
@param key: The public key to associate with the new entry.
@type key: L{Key}
@return: The L{HashedEntry} that was added.
@rtype: L{HashedEntry}
"""
salt = secureRandom(20)
keyType = key.sshType()
entry = HashedEntry(salt, _hmacedString(salt, hostname),
keyType, key, None)
self._added.append(entry)
return entry
def save(self):
"""
Save this L{KnownHostsFile} to the path it was loaded from.
"""
p = self._savePath.parent()
if not p.isdir():
p.makedirs()
if self._clobber:
mode = "wb"
else:
mode = "ab"
with self._savePath.open(mode) as hostsFileObj:
if self._added:
hostsFileObj.write(
b"\n".join([entry.toString() for entry in self._added]) +
b"\n")
self._added = []
self._clobber = False
@classmethod
def fromPath(cls, path):
"""
Create a new L{KnownHostsFile}, potentially reading existing known
hosts information from the given file.
@param path: A path object to use for both reading contents from and
later saving to. If no file exists at this path, it is not an
error; a L{KnownHostsFile} with no entries is returned.
@type path: L{FilePath}
@return: A L{KnownHostsFile} initialized with entries from C{path}.
@rtype: L{KnownHostsFile}
"""
knownHosts = cls(path)
knownHosts._clobber = False
return knownHosts
class ConsoleUI(object):
"""
A UI object that can ask true/false questions and post notifications on the
console, to be used during key verification.
"""
def __init__(self, opener):
"""
@param opener: A no-argument callable which should open a console
binary-mode file-like object to be used for reading and writing.
This initializes the C{opener} attribute.
@type opener: callable taking no arguments and returning a read/write
file-like object
"""
self.opener = opener
def prompt(self, text):
"""
Write the given text as a prompt to the console output, then read a
result from the console input.
@param text: Something to present to a user to solicit a yes or no
response.
@type text: L{bytes}
@return: a L{Deferred} which fires with L{True} when the user answers
'yes' and L{False} when the user answers 'no'. It may errback if
there were any I/O errors.
"""
d = defer.succeed(None)
def body(ignored):
with closing(self.opener()) as f:
f.write(text)
while True:
answer = f.readline().strip().lower()
if answer == b'yes':
return True
elif answer == b'no':
return False
else:
f.write(b"Please type 'yes' or 'no': ")
return d.addCallback(body)
def warn(self, text):
"""
Notify the user (non-interactively) of the provided text, by writing it
to the console.
@param text: Some information the user is to be made aware of.
@type text: L{bytes}
"""
try:
with closing(self.opener()) as f:
f.write(text)
except:
log.err()

View file

@ -0,0 +1,103 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
from twisted.conch.ssh.transport import SSHClientTransport, SSHCiphers
from twisted.python import usage
from twisted.python.compat import unicode
import sys
class ConchOptions(usage.Options):
optParameters = [['user', 'l', None, 'Log in using this user name.'],
['identity', 'i', None],
['ciphers', 'c', None],
['macs', 'm', None],
['port', 'p', None, 'Connect to this port. Server must be on the same port.'],
['option', 'o', None, 'Ignored OpenSSH options'],
['host-key-algorithms', '', None],
['known-hosts', '', None, 'File to check for host keys'],
['user-authentications', '', None, 'Types of user authentications to use.'],
['logfile', '', None, 'File to log to, or - for stdout'],
]
optFlags = [['version', 'V', 'Display version number only.'],
['compress', 'C', 'Enable compression.'],
['log', 'v', 'Enable logging (defaults to stderr)'],
['nox11', 'x', 'Disable X11 connection forwarding (default)'],
['agent', 'A', 'Enable authentication agent forwarding'],
['noagent', 'a', 'Disable authentication agent forwarding (default)'],
['reconnect', 'r', 'Reconnect to the server if the connection is lost.'],
]
compData = usage.Completions(
mutuallyExclusive=[("agent", "noagent")],
optActions={
"user": usage.CompleteUsernames(),
"ciphers": usage.CompleteMultiList(
SSHCiphers.cipherMap.keys(),
descr='ciphers to choose from'),
"macs": usage.CompleteMultiList(
SSHCiphers.macMap.keys(),
descr='macs to choose from'),
"host-key-algorithms": usage.CompleteMultiList(
SSHClientTransport.supportedPublicKeys,
descr='host key algorithms to choose from'),
#"user-authentications": usage.CompleteMultiList(?
# descr='user authentication types' ),
},
extraActions=[usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr='argument',
repeat=True)]
)
def __init__(self, *args, **kw):
usage.Options.__init__(self, *args, **kw)
self.identitys = []
self.conns = None
def opt_identity(self, i):
"""Identity for public-key authentication"""
self.identitys.append(i)
def opt_ciphers(self, ciphers):
"Select encryption algorithms"
ciphers = ciphers.split(',')
for cipher in ciphers:
if cipher not in SSHCiphers.cipherMap:
sys.exit("Unknown cipher type '%s'" % cipher)
self['ciphers'] = ciphers
def opt_macs(self, macs):
"Specify MAC algorithms"
if isinstance(macs, unicode):
macs = macs.encode("utf-8")
macs = macs.split(b',')
for mac in macs:
if mac not in SSHCiphers.macMap:
sys.exit("Unknown mac type '%r'" % mac)
self['macs'] = macs
def opt_host_key_algorithms(self, hkas):
"Select host key algorithms"
if isinstance(hkas, unicode):
hkas = hkas.encode("utf-8")
hkas = hkas.split(b',')
for hka in hkas:
if hka not in SSHClientTransport.supportedPublicKeys:
sys.exit("Unknown host key type '%r'" % hka)
self['host-key-algorithms'] = hkas
def opt_user_authentications(self, uas):
"Choose how to authenticate to the remote server"
if isinstance(uas, unicode):
uas = uas.encode("utf-8")
self['user-authentications'] = uas.split(b',')
# def opt_compress(self):
# "Enable compression"
# self.enableCompression = 1
# SSHClientTransport.supportedCompressions[0:1] = ['zlib']

View file

@ -0,0 +1,872 @@
# -*- test-case-name: twisted.conch.test.test_endpoints -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Endpoint implementations of various SSH interactions.
"""
__all__ = [
'AuthenticationFailed', 'SSHCommandAddress', 'SSHCommandClientEndpoint']
from struct import unpack
from os.path import expanduser
import signal
from zope.interface import Interface, implementer
from twisted.logger import Logger
from twisted.python.compat import nativeString, networkString
from twisted.python.filepath import FilePath
from twisted.python.failure import Failure
from twisted.internet.error import ConnectionDone, ProcessTerminated
from twisted.internet.interfaces import IStreamClientEndpoint
from twisted.internet.protocol import Factory
from twisted.internet.defer import Deferred, succeed, CancelledError
from twisted.internet.endpoints import TCP4ClientEndpoint, connectProtocol
from twisted.conch.ssh.keys import Key
from twisted.conch.ssh.common import getNS, NS
from twisted.conch.ssh.transport import SSHClientTransport
from twisted.conch.ssh.connection import SSHConnection
from twisted.conch.ssh.userauth import SSHUserAuthClient
from twisted.conch.ssh.channel import SSHChannel
from twisted.conch.client.knownhosts import ConsoleUI, KnownHostsFile
from twisted.conch.client.agent import SSHAgentClient
from twisted.conch.client.default import _KNOWN_HOSTS
class AuthenticationFailed(Exception):
"""
An SSH session could not be established because authentication was not
successful.
"""
# This should be public. See #6541.
class _ISSHConnectionCreator(Interface):
"""
An L{_ISSHConnectionCreator} knows how to create SSH connections somehow.
"""
def secureConnection():
"""
Return a new, connected, secured, but not yet authenticated instance of
L{twisted.conch.ssh.transport.SSHServerTransport} or
L{twisted.conch.ssh.transport.SSHClientTransport}.
"""
def cleanupConnection(connection, immediate):
"""
Perform cleanup necessary for a connection object previously returned
from this creator's C{secureConnection} method.
@param connection: An L{twisted.conch.ssh.transport.SSHServerTransport}
or L{twisted.conch.ssh.transport.SSHClientTransport} returned by a
previous call to C{secureConnection}. It is no longer needed by
the caller of that method and may be closed or otherwise cleaned up
as necessary.
@param immediate: If C{True} don't wait for any network communication,
just close the connection immediately and as aggressively as
necessary.
"""
class SSHCommandAddress(object):
"""
An L{SSHCommandAddress} instance represents the address of an SSH server, a
username which was used to authenticate with that server, and a command
which was run there.
@ivar server: See L{__init__}
@ivar username: See L{__init__}
@ivar command: See L{__init__}
"""
def __init__(self, server, username, command):
"""
@param server: The address of the SSH server on which the command is
running.
@type server: L{IAddress} provider
@param username: An authentication username which was used to
authenticate against the server at the given address.
@type username: L{bytes}
@param command: A command which was run in a session channel on the
server at the given address.
@type command: L{bytes}
"""
self.server = server
self.username = username
self.command = command
class _CommandChannel(SSHChannel):
"""
A L{_CommandChannel} executes a command in a session channel and connects
its input and output to an L{IProtocol} provider.
@ivar _creator: See L{__init__}
@ivar _command: See L{__init__}
@ivar _protocolFactory: See L{__init__}
@ivar _commandConnected: See L{__init__}
@ivar _protocol: An L{IProtocol} provider created using C{_protocolFactory}
which is hooked up to the running command's input and output streams.
"""
name = b'session'
_log = Logger()
def __init__(self, creator, command, protocolFactory, commandConnected):
"""
@param creator: The L{_ISSHConnectionCreator} provider which was used
to get the connection which this channel exists on.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command to be executed.
@type command: L{bytes}
@param protocolFactory: A client factory to use to build a L{IProtocol}
provider to use to associate with the running command.
@param commandConnected: A L{Deferred} to use to signal that execution
of the command has failed or that it has succeeded and the command
is now running.
@type commandConnected: L{Deferred}
"""
SSHChannel.__init__(self)
self._creator = creator
self._command = command
self._protocolFactory = protocolFactory
self._commandConnected = commandConnected
self._reason = None
def openFailed(self, reason):
"""
When the request to open a new channel to run this command in fails,
fire the C{commandConnected} deferred with a failure indicating that.
"""
self._commandConnected.errback(reason)
def channelOpen(self, ignored):
"""
When the request to open a new channel to run this command in succeeds,
issue an C{"exec"} request to run the command.
"""
command = self.conn.sendRequest(
self, b'exec', NS(self._command), wantReply=True)
command.addCallbacks(self._execSuccess, self._execFailure)
def _execFailure(self, reason):
"""
When the request to execute the command in this channel fails, fire the
C{commandConnected} deferred with a failure indicating this.
@param reason: The cause of the command execution failure.
@type reason: L{Failure}
"""
self._commandConnected.errback(reason)
def _execSuccess(self, ignored):
"""
When the request to execute the command in this channel succeeds, use
C{protocolFactory} to build a protocol to handle the command's input
and output and connect the protocol to a transport representing those
streams.
Also fire C{commandConnected} with the created protocol after it is
connected to its transport.
@param ignored: The (ignored) result of the execute request
"""
self._protocol = self._protocolFactory.buildProtocol(
SSHCommandAddress(
self.conn.transport.transport.getPeer(),
self.conn.transport.creator.username,
self.conn.transport.creator.command))
self._protocol.makeConnection(self)
self._commandConnected.callback(self._protocol)
def dataReceived(self, data):
"""
When the command's stdout data arrives over the channel, deliver it to
the protocol instance.
@param data: The bytes from the command's stdout.
@type data: L{bytes}
"""
self._protocol.dataReceived(data)
def request_exit_status(self, data):
"""
When the server sends the command's exit status, record it for later
delivery to the protocol.
@param data: The network-order four byte representation of the exit
status of the command.
@type data: L{bytes}
"""
(status,) = unpack('>L', data)
if status != 0:
self._reason = ProcessTerminated(status, None, None)
def request_exit_signal(self, data):
"""
When the server sends the command's exit status, record it for later
delivery to the protocol.
@param data: The network-order four byte representation of the exit
signal of the command.
@type data: L{bytes}
"""
shortSignalName, data = getNS(data)
coreDumped, data = bool(ord(data[0:1])), data[1:]
errorMessage, data = getNS(data)
languageTag, data = getNS(data)
signalName = "SIG%s" % (nativeString(shortSignalName),)
signalID = getattr(signal, signalName, -1)
self._log.info(
"Process exited with signal {shortSignalName!r};"
" core dumped: {coreDumped};"
" error message: {errorMessage};"
" language: {languageTag!r}",
shortSignalName=shortSignalName,
coreDumped=coreDumped,
errorMessage=errorMessage.decode('utf-8'),
languageTag=languageTag,
)
self._reason = ProcessTerminated(None, signalID, None)
def closed(self):
"""
When the channel closes, deliver disconnection notification to the
protocol.
"""
self._creator.cleanupConnection(self.conn, False)
if self._reason is None:
reason = ConnectionDone("ssh channel closed")
else:
reason = self._reason
self._protocol.connectionLost(Failure(reason))
class _ConnectionReady(SSHConnection):
"""
L{_ConnectionReady} is an L{SSHConnection} (an SSH service) which only
propagates the I{serviceStarted} event to a L{Deferred} to be handled
elsewhere.
"""
def __init__(self, ready):
"""
@param ready: A L{Deferred} which should be fired when
I{serviceStarted} happens.
"""
SSHConnection.__init__(self)
self._ready = ready
def serviceStarted(self):
"""
When the SSH I{connection} I{service} this object represents is ready
to be used, fire the C{connectionReady} L{Deferred} to publish that
event to some other interested party.
"""
self._ready.callback(self)
del self._ready
class _UserAuth(SSHUserAuthClient):
"""
L{_UserAuth} implements the client part of SSH user authentication in the
convenient way a user might expect if they are familiar with the
interactive I{ssh} command line client.
L{_UserAuth} supports key-based authentication, password-based
authentication, and delegating authentication to an agent.
"""
password = None
keys = None
agent = None
def getPublicKey(self):
"""
Retrieve the next public key object to offer to the server, possibly
delegating to an authentication agent if there is one.
@return: The public part of a key pair that could be used to
authenticate with the server, or L{None} if there are no more
public keys to try.
@rtype: L{twisted.conch.ssh.keys.Key} or L{None}
"""
if self.agent is not None:
return self.agent.getPublicKey()
if self.keys:
self.key = self.keys.pop(0)
else:
self.key = None
return self.key.public()
def signData(self, publicKey, signData):
"""
Extend the base signing behavior by using an SSH agent to sign the
data, if one is available.
@type publicKey: L{Key}
@type signData: L{str}
"""
if self.agent is not None:
return self.agent.signData(publicKey.blob(), signData)
else:
return SSHUserAuthClient.signData(self, publicKey, signData)
def getPrivateKey(self):
"""
Get the private part of a key pair to use for authentication. The key
corresponds to the public part most recently returned from
C{getPublicKey}.
@return: A L{Deferred} which fires with the private key.
@rtype: L{Deferred}
"""
return succeed(self.key)
def getPassword(self):
"""
Get the password to use for authentication.
@return: A L{Deferred} which fires with the password, or L{None} if the
password was not specified.
"""
if self.password is None:
return
return succeed(self.password)
def ssh_USERAUTH_SUCCESS(self, packet):
"""
Handle user authentication success in the normal way, but also make a
note of the state change on the L{_CommandTransport}.
"""
self.transport._state = b'CHANNELLING'
return SSHUserAuthClient.ssh_USERAUTH_SUCCESS(self, packet)
def connectToAgent(self, endpoint):
"""
Set up a connection to the authentication agent and trigger its
initialization.
@param endpoint: An endpoint which can be used to connect to the
authentication agent.
@type endpoint: L{IStreamClientEndpoint} provider
@return: A L{Deferred} which fires when the agent connection is ready
for use.
"""
factory = Factory()
factory.protocol = SSHAgentClient
d = endpoint.connect(factory)
def connected(agent):
self.agent = agent
return agent.getPublicKeys()
d.addCallback(connected)
return d
def loseAgentConnection(self):
"""
Disconnect the agent.
"""
if self.agent is None:
return
self.agent.transport.loseConnection()
class _CommandTransport(SSHClientTransport):
"""
L{_CommandTransport} is an SSH client I{transport} which includes a host
key verification step before it will proceed to secure the connection.
L{_CommandTransport} also knows how to set up a connection to an
authentication agent if it is told where it can connect to one.
@ivar _userauth: The L{_UserAuth} instance which is in charge of the
overall authentication process or L{None} if the SSH connection has not
reach yet the C{user-auth} service.
@type _userauth: L{_UserAuth}
"""
# STARTING -> SECURING -> AUTHENTICATING -> CHANNELLING -> RUNNING
_state = b'STARTING'
_hostKeyFailure = None
_userauth = None
def __init__(self, creator):
"""
@param creator: The L{_NewConnectionHelper} that created this
connection.
@type creator: L{_NewConnectionHelper}.
"""
self.connectionReady = Deferred(
lambda d: self.transport.abortConnection())
# Clear the reference to that deferred to help the garbage collector
# and to signal to other parts of this implementation (in particular
# connectionLost) that it has already been fired and does not need to
# be fired again.
def readyFired(result):
self.connectionReady = None
return result
self.connectionReady.addBoth(readyFired)
self.creator = creator
def verifyHostKey(self, hostKey, fingerprint):
"""
Ask the L{KnownHostsFile} provider available on the factory which
created this protocol this protocol to verify the given host key.
@return: A L{Deferred} which fires with the result of
L{KnownHostsFile.verifyHostKey}.
"""
hostname = self.creator.hostname
ip = networkString(self.transport.getPeer().host)
self._state = b'SECURING'
d = self.creator.knownHosts.verifyHostKey(
self.creator.ui, hostname, ip, Key.fromString(hostKey))
d.addErrback(self._saveHostKeyFailure)
return d
def _saveHostKeyFailure(self, reason):
"""
When host key verification fails, record the reason for the failure in
order to fire a L{Deferred} with it later.
@param reason: The cause of the host key verification failure.
@type reason: L{Failure}
@return: C{reason}
@rtype: L{Failure}
"""
self._hostKeyFailure = reason
return reason
def connectionSecure(self):
"""
When the connection is secure, start the authentication process.
"""
self._state = b'AUTHENTICATING'
command = _ConnectionReady(self.connectionReady)
self._userauth = _UserAuth(self.creator.username, command)
self._userauth.password = self.creator.password
if self.creator.keys:
self._userauth.keys = list(self.creator.keys)
if self.creator.agentEndpoint is not None:
d = self._userauth.connectToAgent(self.creator.agentEndpoint)
else:
d = succeed(None)
def maybeGotAgent(ignored):
self.requestService(self._userauth)
d.addBoth(maybeGotAgent)
def connectionLost(self, reason):
"""
When the underlying connection to the SSH server is lost, if there were
any connection setup errors, propagate them. Also, clean up the
connection to the ssh agent if one was created.
"""
if self._userauth:
self._userauth.loseAgentConnection()
if self._state == b'RUNNING' or self.connectionReady is None:
return
if self._state == b'SECURING' and self._hostKeyFailure is not None:
reason = self._hostKeyFailure
elif self._state == b'AUTHENTICATING':
reason = Failure(
AuthenticationFailed("Connection lost while authenticating"))
self.connectionReady.errback(reason)
@implementer(IStreamClientEndpoint)
class SSHCommandClientEndpoint(object):
"""
L{SSHCommandClientEndpoint} exposes the command-executing functionality of
SSH servers.
L{SSHCommandClientEndpoint} can set up a new SSH connection, authenticate
it in any one of a number of different ways (keys, passwords, agents),
launch a command over that connection and then associate its input and
output with a protocol.
It can also re-use an existing, already-authenticated SSH connection
(perhaps one which already has some SSH channels being used for other
purposes). In this case it creates a new SSH channel to use to execute the
command. Notably this means it supports multiplexing several different
command invocations over a single SSH connection.
"""
def __init__(self, creator, command):
"""
@param creator: An L{_ISSHConnectionCreator} provider which will be
used to set up the SSH connection which will be used to run a
command.
@type creator: L{_ISSHConnectionCreator} provider
@param command: The command line to execute on the SSH server. This
byte string is interpreted by a shell on the SSH server, so it may
have a value like C{"ls /"}. Take care when trying to run a
command like C{"/Volumes/My Stuff/a-program"} - spaces (and other
special bytes) may require escaping.
@type command: L{bytes}
"""
self._creator = creator
self._command = command
@classmethod
def newConnection(cls, reactor, command, username, hostname, port=None,
keys=None, password=None, agentEndpoint=None,
knownHosts=None, ui=None):
"""
Create and return a new endpoint which will try to create a new
connection to an SSH server and run a command over it. It will also
close the connection if there are problems leading up to the command
being executed, after the command finishes, or if the connection
L{Deferred} is cancelled.
@param reactor: The reactor to use to establish the connection.
@type reactor: L{IReactorTCP} provider
@param command: See L{__init__}'s C{command} argument.
@param username: The username with which to authenticate to the SSH
server.
@type username: L{bytes}
@param hostname: The hostname of the SSH server.
@type hostname: L{bytes}
@param port: The port number of the SSH server. By default, the
standard SSH port number is used.
@type port: L{int}
@param keys: Private keys with which to authenticate to the SSH server,
if key authentication is to be attempted (otherwise L{None}).
@type keys: L{list} of L{Key}
@param password: The password with which to authenticate to the SSH
server, if password authentication is to be attempted (otherwise
L{None}).
@type password: L{bytes} or L{None}
@param agentEndpoint: An L{IStreamClientEndpoint} provider which may be
used to connect to an SSH agent, if one is to be used to help with
authentication.
@type agentEndpoint: L{IStreamClientEndpoint} provider
@param knownHosts: The currently known host keys, used to check the
host key presented by the server we actually connect to.
@type knownHosts: L{KnownHostsFile}
@param ui: An object for interacting with users to make decisions about
whether to accept the server host keys. If L{None}, a L{ConsoleUI}
connected to /dev/tty will be used; if /dev/tty is unavailable, an
object which answers C{b"no"} to all prompts will be used.
@type ui: L{None} or L{ConsoleUI}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _NewConnectionHelper(
reactor, hostname, port, command, username, keys, password,
agentEndpoint, knownHosts, ui)
return cls(helper, command)
@classmethod
def existingConnection(cls, connection, command):
"""
Create and return a new endpoint which will try to open a new channel
on an existing SSH connection and run a command over it. It will
B{not} close the connection if there is a problem executing the command
or after the command finishes.
@param connection: An existing connection to an SSH server.
@type connection: L{SSHConnection}
@param command: See L{SSHCommandClientEndpoint.newConnection}'s
C{command} parameter.
@type command: L{bytes}
@return: A new instance of C{cls} (probably
L{SSHCommandClientEndpoint}).
"""
helper = _ExistingConnectionHelper(connection)
return cls(helper, command)
def connect(self, protocolFactory):
"""
Set up an SSH connection, use a channel from that connection to launch
a command, and hook the stdin and stdout of that command up as a
transport for a protocol created by the given factory.
@param protocolFactory: A L{Factory} to use to create the protocol
which will be connected to the stdin and stdout of the command on
the SSH server.
@return: A L{Deferred} which will fire with an error if the connection
cannot be set up for any reason or with the protocol instance
created by C{protocolFactory} once it has been connected to the
command.
"""
d = self._creator.secureConnection()
d.addCallback(self._executeCommand, protocolFactory)
return d
def _executeCommand(self, connection, protocolFactory):
"""
Given a secured SSH connection, try to execute a command in a new
channel created on it and associate the result with a protocol from the
given factory.
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
@param protocolFactory: See L{SSHCommandClientEndpoint.connect}'s
C{protocolFactory} parameter.
@return: See L{SSHCommandClientEndpoint.connect}'s return value.
"""
commandConnected = Deferred()
def disconnectOnFailure(passthrough):
# Close the connection immediately in case of cancellation, since
# that implies user wants it gone immediately (e.g. a timeout):
immediate = passthrough.check(CancelledError)
self._creator.cleanupConnection(connection, immediate)
return passthrough
commandConnected.addErrback(disconnectOnFailure)
channel = _CommandChannel(
self._creator, self._command, protocolFactory, commandConnected)
connection.openChannel(channel)
return commandConnected
class _ReadFile(object):
"""
A weakly file-like object which can be used with L{KnownHostsFile} to
respond in the negative to all prompts for decisions.
"""
def __init__(self, contents):
"""
@param contents: L{bytes} which will be returned from every C{readline}
call.
"""
self._contents = contents
def write(self, data):
"""
No-op.
@param data: ignored
"""
def readline(self, count=-1):
"""
Always give back the byte string that this L{_ReadFile} was initialized
with.
@param count: ignored
@return: A fixed byte-string.
@rtype: L{bytes}
"""
return self._contents
def close(self):
"""
No-op.
"""
@implementer(_ISSHConnectionCreator)
class _NewConnectionHelper(object):
"""
L{_NewConnectionHelper} implements L{_ISSHConnectionCreator} by
establishing a brand new SSH connection, securing it, and authenticating.
"""
_KNOWN_HOSTS = _KNOWN_HOSTS
port = 22
def __init__(self, reactor, hostname, port, command, username, keys,
password, agentEndpoint, knownHosts, ui,
tty=FilePath(b"/dev/tty")):
"""
@param tty: The path of the tty device to use in case C{ui} is L{None}.
@type tty: L{FilePath}
@see: L{SSHCommandClientEndpoint.newConnection}
"""
self.reactor = reactor
self.hostname = hostname
if port is not None:
self.port = port
self.command = command
self.username = username
self.keys = keys
self.password = password
self.agentEndpoint = agentEndpoint
if knownHosts is None:
knownHosts = self._knownHosts()
self.knownHosts = knownHosts
if ui is None:
ui = ConsoleUI(self._opener)
self.ui = ui
self.tty = tty
def _opener(self):
"""
Open the tty if possible, otherwise give back a file-like object from
which C{b"no"} can be read.
For use as the opener argument to L{ConsoleUI}.
"""
try:
return self.tty.open("rb+")
except:
# Give back a file-like object from which can be read a byte string
# that KnownHostsFile recognizes as rejecting some option (b"no").
return _ReadFile(b"no")
@classmethod
def _knownHosts(cls):
"""
@return: A L{KnownHostsFile} instance pointed at the user's personal
I{known hosts} file.
@type: L{KnownHostsFile}
"""
return KnownHostsFile.fromPath(FilePath(expanduser(cls._KNOWN_HOSTS)))
def secureConnection(self):
"""
Create and return a new SSH connection which has been secured and on
which authentication has already happened.
@return: A L{Deferred} which fires with the ready-to-use connection or
with a failure if something prevents the connection from being
setup, secured, or authenticated.
"""
protocol = _CommandTransport(self)
ready = protocol.connectionReady
sshClient = TCP4ClientEndpoint(
self.reactor, nativeString(self.hostname), self.port)
d = connectProtocol(sshClient, protocol)
d.addCallback(lambda ignored: ready)
return d
def cleanupConnection(self, connection, immediate):
"""
Clean up the connection by closing it. The command running on the
endpoint has ended so the connection is no longer needed.
@param connection: The L{SSHConnection} to close.
@type connection: L{SSHConnection}
@param immediate: Whether to close connection immediately.
@type immediate: L{bool}.
"""
if immediate:
# We're assuming the underlying connection is an ITCPTransport,
# which is what the current implementation is restricted to:
connection.transport.transport.abortConnection()
else:
connection.transport.loseConnection()
@implementer(_ISSHConnectionCreator)
class _ExistingConnectionHelper(object):
"""
L{_ExistingConnectionHelper} implements L{_ISSHConnectionCreator} by
handing out an existing SSH connection which is supplied to its
initializer.
"""
def __init__(self, connection):
"""
@param connection: See L{SSHCommandClientEndpoint.existingConnection}'s
C{connection} parameter.
"""
self.connection = connection
def secureConnection(self):
"""
@return: A L{Deferred} that fires synchronously with the
already-established connection object.
"""
return succeed(self.connection)
def cleanupConnection(self, connection, immediate):
"""
Do not do any cleanup on the connection. Leave that responsibility to
whatever code created it in the first place.
@param connection: The L{SSHConnection} which will not be modified in
any way.
@type connection: L{SSHConnection}
@param immediate: An argument which will be ignored.
@type immediate: L{bool}.
"""

View file

@ -0,0 +1,103 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
An error to represent bad things happening in Conch.
Maintainer: Paul Swartz
"""
from __future__ import absolute_import, division
from twisted.cred.error import UnauthorizedLogin
class ConchError(Exception):
def __init__(self, value, data = None):
Exception.__init__(self, value, data)
self.value = value
self.data = data
class NotEnoughAuthentication(Exception):
"""
This is thrown if the authentication is valid, but is not enough to
successfully verify the user. i.e. don't retry this type of
authentication, try another one.
"""
class ValidPublicKey(UnauthorizedLogin):
"""
Raised by public key checkers when they receive public key credentials
that don't contain a signature at all, but are valid in every other way.
(e.g. the public key matches one in the user's authorized_keys file).
Protocol code (eg
L{SSHUserAuthServer<twisted.conch.ssh.userauth.SSHUserAuthServer>}) which
attempts to log in using
L{ISSHPrivateKey<twisted.cred.credentials.ISSHPrivateKey>} credentials
should be prepared to handle a failure of this type by telling the user to
re-authenticate using the same key and to include a signature with the new
attempt.
See U{http://www.ietf.org/rfc/rfc4252.txt} section 7 for more details.
"""
class IgnoreAuthentication(Exception):
"""
This is thrown to let the UserAuthServer know it doesn't need to handle the
authentication anymore.
"""
class MissingKeyStoreError(Exception):
"""
Raised if an SSHAgentServer starts receiving data without its factory
providing a keys dict on which to read/write key data.
"""
class UserRejectedKey(Exception):
"""
The user interactively rejected a key.
"""
class InvalidEntry(Exception):
"""
An entry in a known_hosts file could not be interpreted as a valid entry.
"""
class HostKeyChanged(Exception):
"""
The host key of a remote host has changed.
@ivar offendingEntry: The entry which contains the persistent host key that
disagrees with the given host key.
@type offendingEntry: L{twisted.conch.interfaces.IKnownHostEntry}
@ivar path: a reference to the known_hosts file that the offending entry
was loaded from
@type path: L{twisted.python.filepath.FilePath}
@ivar lineno: The line number of the offending entry in the given path.
@type lineno: L{int}
"""
def __init__(self, offendingEntry, path, lineno):
Exception.__init__(self)
self.offendingEntry = offendingEntry
self.path = path
self.lineno = lineno

View file

@ -0,0 +1,4 @@
"""
Insults: a replacement for Curses/S-Lang.
Very basic at the moment."""

View file

@ -0,0 +1,517 @@
# -*- test-case-name: twisted.conch.test.test_helper -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Partial in-memory terminal emulator
@author: Jp Calderone
"""
from __future__ import print_function
import re, string
from zope.interface import implementer
from incremental import Version
from twisted.internet import defer, protocol, reactor
from twisted.python import log, _textattributes
from twisted.python.compat import iterbytes
from twisted.python.deprecate import deprecated, deprecatedModuleAttribute
from twisted.conch.insults import insults
FOREGROUND = 30
BACKGROUND = 40
BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, N_COLORS = range(9)
class _FormattingState(_textattributes._FormattingStateMixin):
"""
Represents the formatting state/attributes of a single character.
Character set, intensity, underlinedness, blinkitude, video
reversal, as well as foreground and background colors made up a
character's attributes.
"""
compareAttributes = (
'charset', 'bold', 'underline', 'blink', 'reverseVideo', 'foreground',
'background', '_subtracting')
def __init__(self, charset=insults.G0, bold=False, underline=False,
blink=False, reverseVideo=False, foreground=WHITE,
background=BLACK, _subtracting=False):
self.charset = charset
self.bold = bold
self.underline = underline
self.blink = blink
self.reverseVideo = reverseVideo
self.foreground = foreground
self.background = background
self._subtracting = _subtracting
@deprecated(Version('Twisted', 13, 1, 0))
def wantOne(self, **kw):
"""
Add a character attribute to a copy of this formatting state.
@param **kw: An optional attribute name and value can be provided with
a keyword argument.
@return: A formatting state instance with the new attribute.
@see: L{DefaultFormattingState._withAttribute}.
"""
k, v = kw.popitem()
return self._withAttribute(k, v)
def toVT102(self):
# Spit out a vt102 control sequence that will set up
# all the attributes set here. Except charset.
attrs = []
if self._subtracting:
attrs.append(0)
if self.bold:
attrs.append(insults.BOLD)
if self.underline:
attrs.append(insults.UNDERLINE)
if self.blink:
attrs.append(insults.BLINK)
if self.reverseVideo:
attrs.append(insults.REVERSE_VIDEO)
if self.foreground != WHITE:
attrs.append(FOREGROUND + self.foreground)
if self.background != BLACK:
attrs.append(BACKGROUND + self.background)
if attrs:
return '\x1b[' + ';'.join(map(str, attrs)) + 'm'
return ''
CharacterAttribute = _FormattingState
deprecatedModuleAttribute(
Version('Twisted', 13, 1, 0),
'Use twisted.conch.insults.text.assembleFormattedText instead.',
'twisted.conch.insults.helper',
'CharacterAttribute')
# XXX - need to support scroll regions and scroll history
@implementer(insults.ITerminalTransport)
class TerminalBuffer(protocol.Protocol):
"""
An in-memory terminal emulator.
"""
for keyID in (b'UP_ARROW', b'DOWN_ARROW', b'RIGHT_ARROW', b'LEFT_ARROW',
b'HOME', b'INSERT', b'DELETE', b'END', b'PGUP', b'PGDN',
b'F1', b'F2', b'F3', b'F4', b'F5', b'F6', b'F7', b'F8', b'F9',
b'F10', b'F11', b'F12'):
execBytes = keyID + b" = object()"
execStr = execBytes.decode("ascii")
exec(execStr)
TAB = b'\t'
BACKSPACE = b'\x7f'
width = 80
height = 24
fill = b' '
void = object()
def getCharacter(self, x, y):
return self.lines[y][x]
def connectionMade(self):
self.reset()
def write(self, data):
"""
Add the given printable bytes to the terminal.
Line feeds in L{bytes} will be replaced with carriage return / line
feed pairs.
"""
for b in iterbytes(data.replace(b'\n', b'\r\n')):
self.insertAtCursor(b)
def _currentFormattingState(self):
return _FormattingState(self.activeCharset, **self.graphicRendition)
def insertAtCursor(self, b):
"""
Add one byte to the terminal at the cursor and make consequent state
updates.
If b is a carriage return, move the cursor to the beginning of the
current row.
If b is a line feed, move the cursor to the next row or scroll down if
the cursor is already in the last row.
Otherwise, if b is printable, put it at the cursor position (inserting
or overwriting as dictated by the current mode) and move the cursor.
"""
if b == b'\r':
self.x = 0
elif b == b'\n':
self._scrollDown()
elif b in string.printable.encode("ascii"):
if self.x >= self.width:
self.nextLine()
ch = (b, self._currentFormattingState())
if self.modes.get(insults.modes.IRM):
self.lines[self.y][self.x:self.x] = [ch]
self.lines[self.y].pop()
else:
self.lines[self.y][self.x] = ch
self.x += 1
def _emptyLine(self, width):
return [(self.void, self._currentFormattingState())
for i in range(width)]
def _scrollDown(self):
self.y += 1
if self.y >= self.height:
self.y -= 1
del self.lines[0]
self.lines.append(self._emptyLine(self.width))
def _scrollUp(self):
self.y -= 1
if self.y < 0:
self.y = 0
del self.lines[-1]
self.lines.insert(0, self._emptyLine(self.width))
def cursorUp(self, n=1):
self.y = max(0, self.y - n)
def cursorDown(self, n=1):
self.y = min(self.height - 1, self.y + n)
def cursorBackward(self, n=1):
self.x = max(0, self.x - n)
def cursorForward(self, n=1):
self.x = min(self.width, self.x + n)
def cursorPosition(self, column, line):
self.x = column
self.y = line
def cursorHome(self):
self.x = self.home.x
self.y = self.home.y
def index(self):
self._scrollDown()
def reverseIndex(self):
self._scrollUp()
def nextLine(self):
"""
Update the cursor position attributes and scroll down if appropriate.
"""
self.x = 0
self._scrollDown()
def saveCursor(self):
self._savedCursor = (self.x, self.y)
def restoreCursor(self):
self.x, self.y = self._savedCursor
del self._savedCursor
def setModes(self, modes):
for m in modes:
self.modes[m] = True
def resetModes(self, modes):
for m in modes:
try:
del self.modes[m]
except KeyError:
pass
def setPrivateModes(self, modes):
"""
Enable the given modes.
Track which modes have been enabled so that the implementations of
other L{insults.ITerminalTransport} methods can be properly implemented
to respect these settings.
@see: L{resetPrivateModes}
@see: L{insults.ITerminalTransport.setPrivateModes}
"""
for m in modes:
self.privateModes[m] = True
def resetPrivateModes(self, modes):
"""
Disable the given modes.
@see: L{setPrivateModes}
@see: L{insults.ITerminalTransport.resetPrivateModes}
"""
for m in modes:
try:
del self.privateModes[m]
except KeyError:
pass
def applicationKeypadMode(self):
self.keypadMode = 'app'
def numericKeypadMode(self):
self.keypadMode = 'num'
def selectCharacterSet(self, charSet, which):
self.charsets[which] = charSet
def shiftIn(self):
self.activeCharset = insults.G0
def shiftOut(self):
self.activeCharset = insults.G1
def singleShift2(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G2
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def singleShift3(self):
oldActiveCharset = self.activeCharset
self.activeCharset = insults.G3
f = self.insertAtCursor
def insertAtCursor(b):
f(b)
del self.insertAtCursor
self.activeCharset = oldActiveCharset
self.insertAtCursor = insertAtCursor
def selectGraphicRendition(self, *attributes):
for a in attributes:
if a == insults.NORMAL:
self.graphicRendition = {
'bold': False,
'underline': False,
'blink': False,
'reverseVideo': False,
'foreground': WHITE,
'background': BLACK}
elif a == insults.BOLD:
self.graphicRendition['bold'] = True
elif a == insults.UNDERLINE:
self.graphicRendition['underline'] = True
elif a == insults.BLINK:
self.graphicRendition['blink'] = True
elif a == insults.REVERSE_VIDEO:
self.graphicRendition['reverseVideo'] = True
else:
try:
v = int(a)
except ValueError:
log.msg("Unknown graphic rendition attribute: " + repr(a))
else:
if FOREGROUND <= v <= FOREGROUND + N_COLORS:
self.graphicRendition['foreground'] = v - FOREGROUND
elif BACKGROUND <= v <= BACKGROUND + N_COLORS:
self.graphicRendition['background'] = v - BACKGROUND
else:
log.msg("Unknown graphic rendition attribute: " + repr(a))
def eraseLine(self):
self.lines[self.y] = self._emptyLine(self.width)
def eraseToLineEnd(self):
width = self.width - self.x
self.lines[self.y][self.x:] = self._emptyLine(width)
def eraseToLineBeginning(self):
self.lines[self.y][:self.x + 1] = self._emptyLine(self.x + 1)
def eraseDisplay(self):
self.lines = [self._emptyLine(self.width) for i in range(self.height)]
def eraseToDisplayEnd(self):
self.eraseToLineEnd()
height = self.height - self.y - 1
self.lines[self.y + 1:] = [self._emptyLine(self.width) for i in range(height)]
def eraseToDisplayBeginning(self):
self.eraseToLineBeginning()
self.lines[:self.y] = [self._emptyLine(self.width) for i in range(self.y)]
def deleteCharacter(self, n=1):
del self.lines[self.y][self.x:self.x+n]
self.lines[self.y].extend(self._emptyLine(min(self.width - self.x, n)))
def insertLine(self, n=1):
self.lines[self.y:self.y] = [self._emptyLine(self.width) for i in range(n)]
del self.lines[self.height:]
def deleteLine(self, n=1):
del self.lines[self.y:self.y+n]
self.lines.extend([self._emptyLine(self.width) for i in range(n)])
def reportCursorPosition(self):
return (self.x, self.y)
def reset(self):
self.home = insults.Vector(0, 0)
self.x = self.y = 0
self.modes = {}
self.privateModes = {}
self.setPrivateModes([insults.privateModes.AUTO_WRAP,
insults.privateModes.CURSOR_MODE])
self.numericKeypad = 'app'
self.activeCharset = insults.G0
self.graphicRendition = {
'bold': False,
'underline': False,
'blink': False,
'reverseVideo': False,
'foreground': WHITE,
'background': BLACK}
self.charsets = {
insults.G0: insults.CS_US,
insults.G1: insults.CS_US,
insults.G2: insults.CS_ALTERNATE,
insults.G3: insults.CS_ALTERNATE_SPECIAL}
self.eraseDisplay()
def unhandledControlSequence(self, buf):
print('Could not handle', repr(buf))
def __bytes__(self):
lines = []
for L in self.lines:
buf = []
length = 0
for (ch, attr) in L:
if ch is not self.void:
buf.append(ch)
length = len(buf)
else:
buf.append(self.fill)
lines.append(b''.join(buf[:length]))
return b'\n'.join(lines)
class ExpectationTimeout(Exception):
pass
class ExpectableBuffer(TerminalBuffer):
_mark = 0
def connectionMade(self):
TerminalBuffer.connectionMade(self)
self._expecting = []
def write(self, data):
TerminalBuffer.write(self, data)
self._checkExpected()
def cursorHome(self):
TerminalBuffer.cursorHome(self)
self._mark = 0
def _timeoutExpected(self, d):
d.errback(ExpectationTimeout())
self._checkExpected()
def _checkExpected(self):
s = self.__bytes__()[self._mark:]
while self._expecting:
expr, timer, deferred = self._expecting[0]
if timer and not timer.active():
del self._expecting[0]
continue
for match in expr.finditer(s):
if timer:
timer.cancel()
del self._expecting[0]
self._mark += match.end()
s = s[match.end():]
deferred.callback(match)
break
else:
return
def expect(self, expression, timeout=None, scheduler=reactor):
d = defer.Deferred()
timer = None
if timeout:
timer = scheduler.callLater(timeout, self._timeoutExpected, d)
self._expecting.append((re.compile(expression), timer, d))
self._checkExpected()
return d
__all__ = [
'CharacterAttribute', 'TerminalBuffer', 'ExpectableBuffer']

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,176 @@
# -*- test-case-name: twisted.conch.test.test_text -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Character attribute manipulation API.
This module provides a domain-specific language (using Python syntax)
for the creation of text with additional display attributes associated
with it. It is intended as an alternative to manually building up
strings containing ECMA 48 character attribute control codes. It
currently supports foreground and background colors (black, red,
green, yellow, blue, magenta, cyan, and white), intensity selection,
underlining, blinking and reverse video. Character set selection
support is planned.
Character attributes are specified by using two Python operations:
attribute lookup and indexing. For example, the string \"Hello
world\" with red foreground and all other attributes set to their
defaults, assuming the name twisted.conch.insults.text.attributes has
been imported and bound to the name \"A\" (with the statement C{from
twisted.conch.insults.text import attributes as A}, for example) one
uses this expression::
A.fg.red[\"Hello world\"]
Other foreground colors are set by substituting their name for
\"red\". To set both a foreground and a background color, this
expression is used::
A.fg.red[A.bg.green[\"Hello world\"]]
Note that either A.bg.green can be nested within A.fg.red or vice
versa. Also note that multiple items can be nested within a single
index operation by separating them with commas::
A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]
Other character attributes are set in a similar fashion. To specify a
blinking version of the previous expression::
A.blink[A.bg.green[A.fg.red[\"Hello\"], " ", A.fg.blue[\"world\"]]]
C{A.reverseVideo}, C{A.underline}, and C{A.bold} are also valid.
A third operation is actually supported: unary negation. This turns
off an attribute when an enclosing expression would otherwise have
caused it to be on. For example::
A.underline[A.fg.red[\"Hello\", -A.underline[\" world\"]]]
A formatting structure can then be serialized into a string containing the
necessary VT102 control codes with L{assembleFormattedText}.
@see: L{twisted.conch.insults.text._CharacterAttributes}
@author: Jp Calderone
"""
from incremental import Version
from twisted.conch.insults import helper, insults
from twisted.python import _textattributes
from twisted.python.deprecate import deprecatedModuleAttribute
flatten = _textattributes.flatten
deprecatedModuleAttribute(
Version('Twisted', 13, 1, 0),
'Use twisted.conch.insults.text.assembleFormattedText instead.',
'twisted.conch.insults.text',
'flatten')
_TEXT_COLORS = {
'black': helper.BLACK,
'red': helper.RED,
'green': helper.GREEN,
'yellow': helper.YELLOW,
'blue': helper.BLUE,
'magenta': helper.MAGENTA,
'cyan': helper.CYAN,
'white': helper.WHITE}
class _CharacterAttributes(_textattributes.CharacterAttributesMixin):
"""
Factory for character attributes, including foreground and background color
and non-color attributes such as bold, reverse video and underline.
Character attributes are applied to actual text by using object
indexing-syntax (C{obj['abc']}) after accessing a factory attribute, for
example::
attributes.bold['Some text']
These can be nested to mix attributes::
attributes.bold[attributes.underline['Some text']]
And multiple values can be passed::
attributes.normal[attributes.bold['Some'], ' text']
Non-color attributes can be accessed by attribute name, available
attributes are:
- bold
- blink
- reverseVideo
- underline
Available colors are:
0. black
1. red
2. green
3. yellow
4. blue
5. magenta
6. cyan
7. white
@ivar fg: Foreground colors accessed by attribute name, see above
for possible names.
@ivar bg: Background colors accessed by attribute name, see above
for possible names.
"""
fg = _textattributes._ColorAttribute(
_textattributes._ForegroundColorAttr, _TEXT_COLORS)
bg = _textattributes._ColorAttribute(
_textattributes._BackgroundColorAttr, _TEXT_COLORS)
attrs = {
'bold': insults.BOLD,
'blink': insults.BLINK,
'underline': insults.UNDERLINE,
'reverseVideo': insults.REVERSE_VIDEO}
def assembleFormattedText(formatted):
"""
Assemble formatted text from structured information.
Currently handled formatting includes: bold, blink, reverse, underline and
color codes.
For example::
from twisted.conch.insults.text import attributes as A
assembleFormattedText(
A.normal[A.bold['Time: '], A.fg.lightRed['Now!']])
Would produce "Time: " in bold formatting, followed by "Now!" with a
foreground color of light red and without any additional formatting.
@param formatted: Structured text and attributes.
@rtype: L{str}
@return: String containing VT102 control sequences that mimic those
specified by C{formatted}.
@see: L{twisted.conch.insults.text._CharacterAttributes}
@since: 13.1
"""
return _textattributes.flatten(
formatted, helper._FormattingState(), 'toVT102')
attributes = _CharacterAttributes()
__all__ = ['attributes', 'flatten']

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,444 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains interfaces defined for the L{twisted.conch} package.
"""
from zope.interface import Interface, Attribute
class IConchUser(Interface):
"""
A user who has been authenticated to Cred through Conch. This is
the interface between the SSH connection and the user.
"""
conn = Attribute('The SSHConnection object for this user.')
def lookupChannel(channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
C{channelType} is the type of channel being requested,
as an ssh connection protocol channel type.
C{data} is any other packet data (often nothing).
We return a subclass of L{SSHChannel<ssh.channel.SSHChannel>}. If
the channel type is unknown, we return C{None}.
For other failures, we raise an exception. If a
L{ConchError<error.ConchError>} is raised, the C{.value} will
be the message, and the C{.data} will be the error code.
@param channelType: The requested channel type
@type channelType: L{bytes}
@param windowSize: The initial size of the remote window
@type windowSize: L{int}
@param maxPacket: The largest packet we should send
@type maxPacket: L{int}
@param data: Additional request data
@type data: L{bytes}
@rtype: a subclass of L{SSHChannel} or L{None}
"""
def lookupSubsystem(subsystem, data):
"""
The other side requested a subsystem.
We return a L{Protocol} implementing the requested subsystem.
If the subsystem is not available, we return C{None}.
@param subsystem: The name of the subsystem being requested
@type subsystem: L{bytes}
@param data: Additional request data (often nothing)
@type data: L{bytes}
@rtype: L{Protocol} or L{None}
"""
def gotGlobalRequest(requestType, data):
"""
A global request was sent from the other side.
We return a true value on success or a false value on failure.
If we indicate success by returning a tuple, its second item
will be sent to the other side as additional response data.
@param requestType: The type of the request
@type requestType: L{bytes}
@param data: Additional request data
@type data: L{bytes}
@rtype: boolean or L{tuple}
"""
class ISession(Interface):
def getPty(term, windowSize, modes):
"""
Get a pseudo-terminal for use by a shell or command.
If a pseudo-terminal is not available, or the request otherwise
fails, raise an exception.
"""
def openShell(proto):
"""
Open a shell and connect it to proto.
@param proto: a L{ProcessProtocol} instance.
"""
def execCommand(proto, command):
"""
Execute a command.
@param proto: a L{ProcessProtocol} instance.
"""
def windowChanged(newWindowSize):
"""
Called when the size of the remote screen has changed.
"""
def eofReceived():
"""
Called when the other side has indicated no more data will be sent.
"""
def closed():
"""
Called when the session is closed.
"""
class ISFTPServer(Interface):
"""
SFTP subsystem for server-side communication.
Each method should check to verify that the user has permission for
their actions.
"""
avatar = Attribute(
"""
The avatar returned by the Realm that we are authenticated with,
and represents the logged-in user.
""")
def gotVersion(otherVersion, extData):
"""
Called when the client sends their version info.
otherVersion is an integer representing the version of the SFTP
protocol they are claiming.
extData is a dictionary of extended_name : extended_data items.
These items are sent by the client to indicate additional features.
This method should return a dictionary of extended_name : extended_data
items. These items are the additional features (if any) supported
by the server.
"""
return {}
def openFile(filename, flags, attrs):
"""
Called when the clients asks to open a file.
@param filename: a string representing the file to open.
@param flags: an integer of the flags to open the file with, ORed
together. The flags and their values are listed at the bottom of
L{twisted.conch.ssh.filetransfer} as FXF_*.
@param attrs: a list of attributes to open the file with. It is a
dictionary, consisting of 0 or more keys. The possible keys are::
size: the size of the file in bytes
uid: the user ID of the file as an integer
gid: the group ID of the file as an integer
permissions: the permissions of the file with as an integer.
the bit representation of this field is defined by POSIX.
atime: the access time of the file as seconds since the epoch.
mtime: the modification time of the file as seconds since the epoch.
ext_*: extended attributes. The server is not required to
understand this, but it may.
NOTE: there is no way to indicate text or binary files. it is up
to the SFTP client to deal with this.
This method returns an object that meets the ISFTPFile interface.
Alternatively, it can return a L{Deferred} that will be called back
with the object.
"""
def removeFile(filename):
"""
Remove the given file.
This method returns when the remove succeeds, or a Deferred that is
called back when it succeeds.
@param filename: the name of the file as a string.
"""
def renameFile(oldpath, newpath):
"""
Rename the given file.
This method returns when the rename succeeds, or a L{Deferred} that is
called back when it succeeds. If the rename fails, C{renameFile} will
raise an implementation-dependent exception.
@param oldpath: the current location of the file.
@param newpath: the new file name.
"""
def makeDirectory(path, attrs):
"""
Make a directory.
This method returns when the directory is created, or a Deferred that
is called back when it is created.
@param path: the name of the directory to create as a string.
@param attrs: a dictionary of attributes to create the directory with.
Its meaning is the same as the attrs in the L{openFile} method.
"""
def removeDirectory(path):
"""
Remove a directory (non-recursively)
It is an error to remove a directory that has files or directories in
it.
This method returns when the directory is removed, or a Deferred that
is called back when it is removed.
@param path: the directory to remove.
"""
def openDirectory(path):
"""
Open a directory for scanning.
This method returns an iterable object that has a close() method,
or a Deferred that is called back with same.
The close() method is called when the client is finished reading
from the directory. At this point, the iterable will no longer
be used.
The iterable should return triples of the form (filename,
longname, attrs) or Deferreds that return the same. The
sequence must support __getitem__, but otherwise may be any
'sequence-like' object.
filename is the name of the file relative to the directory.
logname is an expanded format of the filename. The recommended format
is:
-rwxr-xr-x 1 mjos staff 348911 Mar 25 14:29 t-filexfer
1234567890 123 12345678 12345678 12345678 123456789012
The first line is sample output, the second is the length of the field.
The fields are: permissions, link count, user owner, group owner,
size in bytes, modification time.
attrs is a dictionary in the format of the attrs argument to openFile.
@param path: the directory to open.
"""
def getAttrs(path, followLinks):
"""
Return the attributes for the given path.
This method returns a dictionary in the same format as the attrs
argument to openFile or a Deferred that is called back with same.
@param path: the path to return attributes for as a string.
@param followLinks: a boolean. If it is True, follow symbolic links
and return attributes for the real path at the base. If it is False,
return attributes for the specified path.
"""
def setAttrs(path, attrs):
"""
Set the attributes for the path.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param path: the path to set attributes for as a string.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""
def readLink(path):
"""
Find the root of a set of symbolic links.
This method returns the target of the link, or a Deferred that
returns the same.
@param path: the path of the symlink to read.
"""
def makeLink(linkPath, targetPath):
"""
Create a symbolic link.
This method returns when the link is made, or a Deferred that
returns the same.
@param linkPath: the pathname of the symlink as a string.
@param targetPath: the path of the target of the link as a string.
"""
def realPath(path):
"""
Convert any path to an absolute path.
This method returns the absolute path as a string, or a Deferred
that returns the same.
@param path: the path to convert as a string.
"""
def extendedRequest(extendedName, extendedData):
"""
This is the extension mechanism for SFTP. The other side can send us
arbitrary requests.
If we don't implement the request given by extendedName, raise
NotImplementedError.
The return value is a string, or a Deferred that will be called
back with a string.
@param extendedName: the name of the request as a string.
@param extendedData: the data the other side sent with the request,
as a string.
"""
class IKnownHostEntry(Interface):
"""
A L{IKnownHostEntry} is an entry in an OpenSSH-formatted C{known_hosts}
file.
@since: 8.2
"""
def matchesKey(key):
"""
Return True if this entry matches the given Key object, False
otherwise.
@param key: The key object to match against.
@type key: L{twisted.conch.ssh.keys.Key}
"""
def matchesHost(hostname):
"""
Return True if this entry matches the given hostname, False otherwise.
Note that this does no name resolution; if you want to match an IP
address, you have to resolve it yourself, and pass it in as a dotted
quad string.
@param hostname: The hostname to match against.
@type hostname: L{str}
"""
def toString():
"""
@return: a serialized string representation of this entry, suitable for
inclusion in a known_hosts file. (Newline not included.)
@rtype: L{str}
"""
class ISFTPFile(Interface):
"""
This represents an open file on the server. An object adhering to this
interface should be returned from L{openFile}().
"""
def close():
"""
Close the file.
This method returns nothing if the close succeeds immediately, or a
Deferred that is called back when the close succeeds.
"""
def readChunk(offset, length):
"""
Read from the file.
If EOF is reached before any data is read, raise EOFError.
This method returns the data as a string, or a Deferred that is
called back with same.
@param offset: an integer that is the index to start from in the file.
@param length: the maximum length of data to return. The actual amount
returned may less than this. For normal disk files, however,
this should read the requested number (up to the end of the file).
"""
def writeChunk(offset, data):
"""
Write to the file.
This method returns when the write completes, or a Deferred that is
called when it completes.
@param offset: an integer that is the index to start from in the file.
@param data: a string that is the data to write.
"""
def getAttrs():
"""
Return the attributes for the file.
This method returns a dictionary in the same format as the attrs
argument to L{openFile} or a L{Deferred} that is called back with same.
"""
def setAttrs(attrs):
"""
Set the attributes for the file.
This method returns when the attributes are set or a Deferred that is
called back when they are.
@param attrs: a dictionary in the same format as the attrs argument to
L{openFile}.
"""

View file

@ -0,0 +1,83 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import array
import stat
from time import time, strftime, localtime
from twisted.python.compat import _PY3
# Locale-independent month names to use instead of strftime's
_MONTH_NAMES = dict(list(zip(
list(range(1, 13)),
"Jan Feb Mar Apr May Jun Jul Aug Sep Oct Nov Dec".split())))
def lsLine(name, s):
"""
Build an 'ls' line for a file ('file' in its generic sense, it
can be of any type).
"""
mode = s.st_mode
perms = array.array('B', b'-'*10)
ft = stat.S_IFMT(mode)
if stat.S_ISDIR(ft): perms[0] = ord('d')
elif stat.S_ISCHR(ft): perms[0] = ord('c')
elif stat.S_ISBLK(ft): perms[0] = ord('b')
elif stat.S_ISREG(ft): perms[0] = ord('-')
elif stat.S_ISFIFO(ft): perms[0] = ord('f')
elif stat.S_ISLNK(ft): perms[0] = ord('l')
elif stat.S_ISSOCK(ft): perms[0] = ord('s')
else: perms[0] = ord('!')
# User
if mode&stat.S_IRUSR:perms[1] = ord('r')
if mode&stat.S_IWUSR:perms[2] = ord('w')
if mode&stat.S_IXUSR:perms[3] = ord('x')
# Group
if mode&stat.S_IRGRP:perms[4] = ord('r')
if mode&stat.S_IWGRP:perms[5] = ord('w')
if mode&stat.S_IXGRP:perms[6] = ord('x')
# Other
if mode&stat.S_IROTH:perms[7] = ord('r')
if mode&stat.S_IWOTH:perms[8] = ord('w')
if mode&stat.S_IXOTH:perms[9] = ord('x')
# Suid/sgid
if mode&stat.S_ISUID:
if perms[3] == ord('x'): perms[3] = ord('s')
else: perms[3] = ord('S')
if mode&stat.S_ISGID:
if perms[6] == ord('x'): perms[6] = ord('s')
else: perms[6] = ord('S')
if _PY3:
if isinstance(name, bytes):
name = name.decode("utf-8")
lsPerms = perms.tobytes()
lsPerms = lsPerms.decode("utf-8")
else:
lsPerms = perms.tostring()
lsresult = [
lsPerms,
str(s.st_nlink).rjust(5),
' ',
str(s.st_uid).ljust(9),
str(s.st_gid).ljust(9),
str(s.st_size).rjust(8),
' ',
]
# Need to specify the month manually, as strftime depends on locale
ttup = localtime(s.st_mtime)
sixmonths = 60 * 60 * 24 * 7 * 26
if s.st_mtime + sixmonths < time(): # Last edited more than 6mo ago
strtime = strftime("%%s %d %Y ", ttup)
else:
strtime = strftime("%%s %d %H:%M ", ttup)
lsresult.append(strtime % (_MONTH_NAMES[ttup[1]],))
lsresult.append(name)
return ''.join(lsresult)
__all__ = ['lsLine']

View file

@ -0,0 +1,401 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Line-input oriented interactive interpreter loop.
Provides classes for handling Python source input and arbitrary output
interactively from a Twisted application. Also included is syntax coloring
code with support for VT102 terminals, control code handling (^C, ^D, ^Q),
and reasonable handling of Deferreds.
@author: Jp Calderone
"""
import code, sys, tokenize
from io import BytesIO
from twisted.conch import recvline
from twisted.internet import defer
from twisted.python.compat import _tokenize, _get_async_param
from twisted.python.htmlizer import TokenPrinter
class FileWrapper:
"""
Minimal write-file-like object.
Writes are translated into addOutput calls on an object passed to
__init__. Newlines are also converted from network to local style.
"""
softspace = 0
state = 'normal'
def __init__(self, o):
self.o = o
def flush(self):
pass
def write(self, data):
self.o.addOutput(data.replace('\r\n', '\n'))
def writelines(self, lines):
self.write(''.join(lines))
class ManholeInterpreter(code.InteractiveInterpreter):
"""
Interactive Interpreter with special output and Deferred support.
Aside from the features provided by L{code.InteractiveInterpreter}, this
class captures sys.stdout output and redirects it to the appropriate
location (the Manhole protocol instance). It also treats Deferreds
which reach the top-level specially: each is formatted to the user with
a unique identifier and a new callback and errback added to it, each of
which will format the unique identifier and the result with which the
Deferred fires and then pass it on to the next participant in the
callback chain.
"""
numDeferreds = 0
def __init__(self, handler, locals=None, filename="<console>"):
code.InteractiveInterpreter.__init__(self, locals)
self._pendingDeferreds = {}
self.handler = handler
self.filename = filename
self.resetBuffer()
def resetBuffer(self):
"""
Reset the input buffer.
"""
self.buffer = []
def push(self, line):
"""
Push a line to the interpreter.
The line should not have a trailing newline; it may have
internal newlines. The line is appended to a buffer and the
interpreter's runsource() method is called with the
concatenated contents of the buffer as source. If this
indicates that the command was executed or invalid, the buffer
is reset; otherwise, the command is incomplete, and the buffer
is left as it was after the line was appended. The return
value is 1 if more input is required, 0 if the line was dealt
with in some way (this is the same as runsource()).
@param line: line of text
@type line: L{bytes}
@return: L{bool} from L{code.InteractiveInterpreter.runsource}
"""
self.buffer.append(line)
source = b"\n".join(self.buffer)
source = source.decode("utf-8")
more = self.runsource(source, self.filename)
if not more:
self.resetBuffer()
return more
def runcode(self, *a, **kw):
orighook, sys.displayhook = sys.displayhook, self.displayhook
try:
origout, sys.stdout = sys.stdout, FileWrapper(self.handler)
try:
code.InteractiveInterpreter.runcode(self, *a, **kw)
finally:
sys.stdout = origout
finally:
sys.displayhook = orighook
def displayhook(self, obj):
self.locals['_'] = obj
if isinstance(obj, defer.Deferred):
# XXX Ick, where is my "hasFired()" interface?
if hasattr(obj, "result"):
self.write(repr(obj))
elif id(obj) in self._pendingDeferreds:
self.write("<Deferred #%d>" % (self._pendingDeferreds[id(obj)][0],))
else:
d = self._pendingDeferreds
k = self.numDeferreds
d[id(obj)] = (k, obj)
self.numDeferreds += 1
obj.addCallbacks(self._cbDisplayDeferred, self._ebDisplayDeferred,
callbackArgs=(k, obj), errbackArgs=(k, obj))
self.write("<Deferred #%d>" % (k,))
elif obj is not None:
self.write(repr(obj))
def _cbDisplayDeferred(self, result, k, obj):
self.write("Deferred #%d called back: %r" % (k, result), True)
del self._pendingDeferreds[id(obj)]
return result
def _ebDisplayDeferred(self, failure, k, obj):
self.write("Deferred #%d failed: %r" % (k, failure.getErrorMessage()), True)
del self._pendingDeferreds[id(obj)]
return failure
def write(self, data, isAsync=None, **kwargs):
isAsync = _get_async_param(isAsync, **kwargs)
self.handler.addOutput(data, isAsync)
CTRL_C = b'\x03'
CTRL_D = b'\x04'
CTRL_BACKSLASH = b'\x1c'
CTRL_L = b'\x0c'
CTRL_A = b'\x01'
CTRL_E = b'\x05'
class Manhole(recvline.HistoricRecvLine):
"""
Mediator between a fancy line source and an interactive interpreter.
This accepts lines from its transport and passes them on to a
L{ManholeInterpreter}. Control commands (^C, ^D, ^\) are also handled
with something approximating their normal terminal-mode behavior. It
can optionally be constructed with a dict which will be used as the
local namespace for any code executed.
"""
namespace = None
def __init__(self, namespace=None):
recvline.HistoricRecvLine.__init__(self)
if namespace is not None:
self.namespace = namespace.copy()
def connectionMade(self):
recvline.HistoricRecvLine.connectionMade(self)
self.interpreter = ManholeInterpreter(self, self.namespace)
self.keyHandlers[CTRL_C] = self.handle_INT
self.keyHandlers[CTRL_D] = self.handle_EOF
self.keyHandlers[CTRL_L] = self.handle_FF
self.keyHandlers[CTRL_A] = self.handle_HOME
self.keyHandlers[CTRL_E] = self.handle_END
self.keyHandlers[CTRL_BACKSLASH] = self.handle_QUIT
def handle_INT(self):
"""
Handle ^C as an interrupt keystroke by resetting the current input
variables to their initial state.
"""
self.pn = 0
self.lineBuffer = []
self.lineBufferIndex = 0
self.interpreter.resetBuffer()
self.terminal.nextLine()
self.terminal.write(b"KeyboardInterrupt")
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
def handle_EOF(self):
if self.lineBuffer:
self.terminal.write(b'\a')
else:
self.handle_QUIT()
def handle_FF(self):
"""
Handle a 'form feed' byte - generally used to request a screen
refresh/redraw.
"""
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.drawInputLine()
def handle_QUIT(self):
self.terminal.loseConnection()
def _needsNewline(self):
w = self.terminal.lastWrite
return not w.endswith(b'\n') and not w.endswith(b'\x1bE')
def addOutput(self, data, isAsync=None, **kwargs):
isAsync = _get_async_param(isAsync, **kwargs)
if isAsync:
self.terminal.eraseLine()
self.terminal.cursorBackward(len(self.lineBuffer) +
len(self.ps[self.pn]))
self.terminal.write(data)
if isAsync:
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
if self.lineBuffer:
oldBuffer = self.lineBuffer
self.lineBuffer = []
self.lineBufferIndex = 0
self._deliverBuffer(oldBuffer)
def lineReceived(self, line):
more = self.interpreter.push(line)
self.pn = bool(more)
if self._needsNewline():
self.terminal.nextLine()
self.terminal.write(self.ps[self.pn])
class VT102Writer:
"""
Colorizer for Python tokens.
A series of tokens are written to instances of this object. Each is
colored in a particular way. The final line of the result of this is
generally added to the output.
"""
typeToColor = {
'identifier': b'\x1b[31m',
'keyword': b'\x1b[32m',
'parameter': b'\x1b[33m',
'variable': b'\x1b[1;33m',
'string': b'\x1b[35m',
'number': b'\x1b[36m',
'op': b'\x1b[37m'}
normalColor = b'\x1b[0m'
def __init__(self):
self.written = []
def color(self, type):
r = self.typeToColor.get(type, b'')
return r
def write(self, token, type=None):
if token and token != b'\r':
c = self.color(type)
if c:
self.written.append(c)
self.written.append(token)
if c:
self.written.append(self.normalColor)
def __bytes__(self):
s = b''.join(self.written)
return s.strip(b'\n').splitlines()[-1]
if bytes == str:
# Compat with Python 2.7
__str__ = __bytes__
def lastColorizedLine(source):
"""
Tokenize and colorize the given Python source.
Returns a VT102-format colorized version of the last line of C{source}.
@param source: Python source code
@type source: L{str} or L{bytes}
@return: L{bytes} of colorized source
"""
if not isinstance(source, bytes):
source = source.encode("utf-8")
w = VT102Writer()
p = TokenPrinter(w.write).printtoken
s = BytesIO(source)
for token in _tokenize(s.readline):
(tokenType, string, start, end, line) = token
p(tokenType, string, start, end, line)
return bytes(w)
class ColoredManhole(Manhole):
"""
A REPL which syntax colors input as users type it.
"""
def getSource(self):
"""
Return a string containing the currently entered source.
This is only the code which will be considered for execution
next.
"""
return (b'\n'.join(self.interpreter.buffer) +
b'\n' +
b''.join(self.lineBuffer))
def characterReceived(self, ch, moreCharactersComing):
if self.mode == 'insert':
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex:self.lineBufferIndex+1] = [ch]
self.lineBufferIndex += 1
if moreCharactersComing:
# Skip it all, we'll get called with another character in
# like 2 femtoseconds.
return
if ch == b' ':
# Don't bother to try to color whitespace
self.terminal.write(ch)
return
source = self.getSource()
# Try to write some junk
try:
coloredLine = lastColorizedLine(source)
except tokenize.TokenError:
# We couldn't do it. Strange. Oh well, just add the character.
self.terminal.write(ch)
else:
# Success! Clear the source on this line.
self.terminal.eraseLine()
self.terminal.cursorBackward(len(self.lineBuffer) + len(self.ps[self.pn]) - 1)
# And write a new, colorized one.
self.terminal.write(self.ps[self.pn] + coloredLine)
# And move the cursor to where it belongs
n = len(self.lineBuffer) - self.lineBufferIndex
if n:
self.terminal.cursorBackward(n)

View file

@ -0,0 +1,141 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
insults/SSH integration support.
@author: Jp Calderone
"""
from zope.interface import implementer
from twisted.conch import avatar, interfaces as iconch, error as econch
from twisted.conch.ssh import factory, session
from twisted.python import components
from twisted.conch.insults import insults
class _Glue:
"""
A feeble class for making one attribute look like another.
This should be replaced with a real class at some point, probably.
Try not to write new code that uses it.
"""
def __init__(self, **kw):
self.__dict__.update(kw)
def __getattr__(self, name):
raise AttributeError(self.name, "has no attribute", name)
class TerminalSessionTransport:
def __init__(self, proto, chainedProtocol, avatar, width, height):
self.proto = proto
self.avatar = avatar
self.chainedProtocol = chainedProtocol
protoSession = self.proto.session
self.proto.makeConnection(
_Glue(write=self.chainedProtocol.dataReceived,
loseConnection=lambda: avatar.conn.sendClose(protoSession),
name="SSH Proto Transport"))
def loseConnection():
self.proto.loseConnection()
self.chainedProtocol.makeConnection(
_Glue(write=self.proto.write,
loseConnection=loseConnection,
name="Chained Proto Transport"))
# XXX TODO
# chainedProtocol is supposed to be an ITerminalTransport,
# maybe. That means perhaps its terminalProtocol attribute is
# an ITerminalProtocol, it could be. So calling terminalSize
# on that should do the right thing But it'd be nice to clean
# this bit up.
self.chainedProtocol.terminalProtocol.terminalSize(width, height)
@implementer(iconch.ISession)
class TerminalSession(components.Adapter):
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def getPty(self, term, windowSize, attrs):
self.height, self.width = windowSize[:2]
def openShell(self, proto):
self.transportFactory(
proto, self.chainedProtocolFactory(),
iconch.IConchUser(self.original),
self.width, self.height)
def execCommand(self, proto, cmd):
raise econch.ConchError("Cannot execute commands")
def closed(self):
pass
class TerminalUser(avatar.ConchUser, components.Adapter):
def __init__(self, original, avatarId):
components.Adapter.__init__(self, original)
avatar.ConchUser.__init__(self)
self.channelLookup[b'session'] = session.SSHSession
class TerminalRealm:
userFactory = TerminalUser
sessionFactory = TerminalSession
transportFactory = TerminalSessionTransport
chainedProtocolFactory = insults.ServerProtocol
def _getAvatar(self, avatarId):
comp = components.Componentized()
user = self.userFactory(comp, avatarId)
sess = self.sessionFactory(comp)
sess.transportFactory = self.transportFactory
sess.chainedProtocolFactory = self.chainedProtocolFactory
comp.setComponent(iconch.IConchUser, user)
comp.setComponent(iconch.ISession, sess)
return user
def __init__(self, transportFactory=None):
if transportFactory is not None:
self.transportFactory = transportFactory
def requestAvatar(self, avatarId, mind, *interfaces):
for i in interfaces:
if i is iconch.IConchUser:
return (iconch.IConchUser,
self._getAvatar(avatarId),
lambda: None)
raise NotImplementedError()
class ConchFactory(factory.SSHFactory):
publicKeys = {}
privateKeys = {}
def __init__(self, portal):
self.portal = portal

View file

@ -0,0 +1,165 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
TAP plugin for creating telnet- and ssh-accessible manhole servers.
@author: Jp Calderone
"""
from zope.interface import implementer
from twisted.internet import protocol
from twisted.application import service, strports
from twisted.cred import portal, checkers
from twisted.python import usage, filepath
from twisted.conch import manhole, manhole_ssh, telnet
from twisted.conch.insults import insults
from twisted.conch.ssh import keys
class makeTelnetProtocol:
def __init__(self, portal):
self.portal = portal
def __call__(self):
auth = telnet.AuthenticatingTelnetProtocol
args = (self.portal,)
return telnet.TelnetTransport(auth, *args)
class chainedProtocolFactory:
def __init__(self, namespace):
self.namespace = namespace
def __call__(self):
return insults.ServerProtocol(manhole.ColoredManhole, self.namespace)
@implementer(portal.IRealm)
class _StupidRealm:
def __init__(self, proto, *a, **kw):
self.protocolFactory = proto
self.protocolArgs = a
self.protocolKwArgs = kw
def requestAvatar(self, avatarId, *interfaces):
if telnet.ITelnetProtocol in interfaces:
return (telnet.ITelnetProtocol,
self.protocolFactory(*self.protocolArgs,
**self.protocolKwArgs),
lambda: None)
raise NotImplementedError()
class Options(usage.Options):
optParameters = [
["telnetPort", "t", None,
("strports description of the address on which to listen for telnet "
"connections")],
["sshPort", "s", None,
("strports description of the address on which to listen for ssh "
"connections")],
["passwd", "p", "/etc/passwd",
"name of a passwd(5)-format username/password file"],
["sshKeyDir", None, "<USER DATA DIR>",
"Directory where the autogenerated SSH key is kept."],
["sshKeyName", None, "server.key",
"Filename of the autogenerated SSH key."],
["sshKeySize", None, 4096,
"Size of the automatically generated SSH key."],
]
def __init__(self):
usage.Options.__init__(self)
self['namespace'] = None
def postOptions(self):
if self['telnetPort'] is None and self['sshPort'] is None:
raise usage.UsageError(
"At least one of --telnetPort and --sshPort must be specified")
def makeService(options):
"""
Create a manhole server service.
@type options: L{dict}
@param options: A mapping describing the configuration of
the desired service. Recognized key/value pairs are::
"telnetPort": strports description of the address on which
to listen for telnet connections. If None,
no telnet service will be started.
"sshPort": strports description of the address on which to
listen for ssh connections. If None, no ssh
service will be started.
"namespace": dictionary containing desired initial locals
for manhole connections. If None, an empty
dictionary will be used.
"passwd": Name of a passwd(5)-format username/password file.
"sshKeyDir": The folder that the SSH server key will be kept in.
"sshKeyName": The filename of the key.
"sshKeySize": The size of the key, in bits. Default is 4096.
@rtype: L{twisted.application.service.IService}
@return: A manhole service.
"""
svc = service.MultiService()
namespace = options['namespace']
if namespace is None:
namespace = {}
checker = checkers.FilePasswordDB(options['passwd'])
if options['telnetPort']:
telnetRealm = _StupidRealm(telnet.TelnetBootstrapProtocol,
insults.ServerProtocol,
manhole.ColoredManhole,
namespace)
telnetPortal = portal.Portal(telnetRealm, [checker])
telnetFactory = protocol.ServerFactory()
telnetFactory.protocol = makeTelnetProtocol(telnetPortal)
telnetService = strports.service(options['telnetPort'],
telnetFactory)
telnetService.setServiceParent(svc)
if options['sshPort']:
sshRealm = manhole_ssh.TerminalRealm()
sshRealm.chainedProtocolFactory = chainedProtocolFactory(namespace)
sshPortal = portal.Portal(sshRealm, [checker])
sshFactory = manhole_ssh.ConchFactory(sshPortal)
if options['sshKeyDir'] != "<USER DATA DIR>":
keyDir = options['sshKeyDir']
else:
from twisted.python._appdirs import getDataDirectory
keyDir = getDataDirectory()
keyLocation = filepath.FilePath(keyDir).child(options['sshKeyName'])
sshKey = keys._getPersistentRSAKey(keyLocation,
int(options['sshKeySize']))
sshFactory.publicKeys[b"ssh-rsa"] = sshKey
sshFactory.privateKeys[b"ssh-rsa"] = sshKey
sshService = strports.service(options['sshPort'], sshFactory)
sshService.setServiceParent(svc)
return svc

View file

@ -0,0 +1,55 @@
# -*- test-case-name: twisted.conch.test.test_mixin -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Experimental optimization
This module provides a single mixin class which allows protocols to
collapse numerous small writes into a single larger one.
@author: Jp Calderone
"""
from twisted.internet import reactor
class BufferingMixin:
"""
Mixin which adds write buffering.
"""
_delayedWriteCall = None
data = None
DELAY = 0.0
def schedule(self):
return reactor.callLater(self.DELAY, self.flush)
def reschedule(self, token):
token.reset(self.DELAY)
def write(self, data):
"""
Buffer some bytes to be written soon.
Every call to this function delays the real write by C{self.DELAY}
seconds. When the delay expires, all collected bytes are written
to the underlying transport using L{ITransport.writeSequence}.
"""
if self._delayedWriteCall is None:
self.data = []
self._delayedWriteCall = self.schedule()
else:
self.reschedule(self._delayedWriteCall)
self.data.append(data)
def flush(self):
"""
Flush the buffer immediately.
"""
self._delayedWriteCall = None
self.transport.writeSequence(self.data)
self.data = None

View file

@ -0,0 +1,11 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Support for OpenSSH configuration files.
Maintainer: Paul Swartz
"""

View file

@ -0,0 +1,72 @@
# -*- test-case-name: twisted.conch.test.test_openssh_compat -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Factory for reading openssh configuration files: public keys, private keys, and
moduli file.
"""
import os, errno
from twisted.python import log
from twisted.python.util import runAsEffectiveUser
from twisted.conch.ssh import keys, factory, common
from twisted.conch.openssh_compat import primes
class OpenSSHFactory(factory.SSHFactory):
dataRoot = '/usr/local/etc'
# For openbsd which puts moduli in a different directory from keys.
moduliRoot = '/usr/local/etc'
def getPublicKeys(self):
"""
Return the server public keys.
"""
ks = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == 'ssh_host_' and filename[-8:]=='_key.pub':
try:
k = keys.Key.fromFile(
os.path.join(self.dataRoot, filename))
t = common.getNS(k.blob())[0]
ks[t] = k
except Exception as e:
log.msg('bad public key file %s: %s' % (filename, e))
return ks
def getPrivateKeys(self):
"""
Return the server private keys.
"""
privateKeys = {}
for filename in os.listdir(self.dataRoot):
if filename[:9] == 'ssh_host_' and filename[-4:]=='_key':
fullPath = os.path.join(self.dataRoot, filename)
try:
key = keys.Key.fromFile(fullPath)
except IOError as e:
if e.errno == errno.EACCES:
# Not allowed, let's switch to root
key = runAsEffectiveUser(
0, 0, keys.Key.fromFile, fullPath)
privateKeys[key.sshType()] = key
else:
raise
except Exception as e:
log.msg('bad private key file %s: %s' % (filename, e))
else:
privateKeys[key.sshType()] = key
return privateKeys
def getPrimes(self):
try:
return primes.parseModuliFile(self.moduliRoot+'/moduli')
except IOError:
return None

View file

@ -0,0 +1,30 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
Parsing for the moduli file, which contains Diffie-Hellman prime groups.
Maintainer: Paul Swartz
"""
from twisted.python.compat import long
def parseModuliFile(filename):
with open(filename) as f:
lines = f.readlines()
primes = {}
for l in lines:
l = l.strip()
if not l or l[0]=='#':
continue
tim, typ, tst, tri, size, gen, mod = l.split()
size = int(size) + 1
gen = long(gen)
mod = long(mod, 16)
if size not in primes:
primes[size] = []
primes[size].append((gen, mod))
return primes

View file

@ -0,0 +1,374 @@
# -*- test-case-name: twisted.conch.test.test_recvline -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Basic line editing support.
@author: Jp Calderone
"""
import string
from zope.interface import implementer
from twisted.conch.insults import insults, helper
from twisted.python import log, reflect
from twisted.python.compat import iterbytes
_counters = {}
class Logging(object):
"""
Wrapper which logs attribute lookups.
This was useful in debugging something, I guess. I forget what.
It can probably be deleted or moved somewhere more appropriate.
Nothing special going on here, really.
"""
def __init__(self, original):
self.original = original
key = reflect.qual(original.__class__)
count = _counters.get(key, 0)
_counters[key] = count + 1
self._logFile = open(key + '-' + str(count), 'w')
def __str__(self):
return str(super(Logging, self).__getattribute__('original'))
def __repr__(self):
return repr(super(Logging, self).__getattribute__('original'))
def __getattribute__(self, name):
original = super(Logging, self).__getattribute__('original')
logFile = super(Logging, self).__getattribute__('_logFile')
logFile.write(name + '\n')
return getattr(original, name)
@implementer(insults.ITerminalTransport)
class TransportSequence(object):
"""
An L{ITerminalTransport} implementation which forwards calls to
one or more other L{ITerminalTransport}s.
This is a cheap way for servers to keep track of the state they
expect the client to see, since all terminal manipulations can be
send to the real client and to a terminal emulator that lives in
the server process.
"""
for keyID in (b'UP_ARROW', b'DOWN_ARROW', b'RIGHT_ARROW', b'LEFT_ARROW',
b'HOME', b'INSERT', b'DELETE', b'END', b'PGUP', b'PGDN',
b'F1', b'F2', b'F3', b'F4', b'F5', b'F6', b'F7', b'F8',
b'F9', b'F10', b'F11', b'F12'):
execBytes = keyID + b" = object()"
execStr = execBytes.decode("ascii")
exec(execStr)
TAB = b'\t'
BACKSPACE = b'\x7f'
def __init__(self, *transports):
assert transports, (
"Cannot construct a TransportSequence with no transports")
self.transports = transports
for method in insults.ITerminalTransport:
exec("""\
def %s(self, *a, **kw):
for tpt in self.transports:
result = tpt.%s(*a, **kw)
return result
""" % (method, method))
class LocalTerminalBufferMixin(object):
"""
A mixin for RecvLine subclasses which records the state of the terminal.
This is accomplished by performing all L{ITerminalTransport} operations on both
the transport passed to makeConnection and an instance of helper.TerminalBuffer.
@ivar terminalCopy: A L{helper.TerminalBuffer} instance which efforts
will be made to keep up to date with the actual terminal
associated with this protocol instance.
"""
def makeConnection(self, transport):
self.terminalCopy = helper.TerminalBuffer()
self.terminalCopy.connectionMade()
return super(LocalTerminalBufferMixin, self).makeConnection(
TransportSequence(transport, self.terminalCopy))
def __str__(self):
return str(self.terminalCopy)
class RecvLine(insults.TerminalProtocol):
"""
L{TerminalProtocol} which adds line editing features.
Clients will be prompted for lines of input with all the usual
features: character echoing, left and right arrow support for
moving the cursor to different areas of the line buffer, backspace
and delete for removing characters, and insert for toggling
between typeover and insert mode. Tabs will be expanded to enough
spaces to move the cursor to the next tabstop (every four
characters by default). Enter causes the line buffer to be
cleared and the line to be passed to the lineReceived() method
which, by default, does nothing. Subclasses are responsible for
redrawing the input prompt (this will probably change).
"""
width = 80
height = 24
TABSTOP = 4
ps = (b'>>> ', b'... ')
pn = 0
_printableChars = string.printable.encode("ascii")
def connectionMade(self):
# A list containing the characters making up the current line
self.lineBuffer = []
# A zero-based (wtf else?) index into self.lineBuffer.
# Indicates the current cursor position.
self.lineBufferIndex = 0
t = self.terminal
# A map of keyIDs to bound instance methods.
self.keyHandlers = {
t.LEFT_ARROW: self.handle_LEFT,
t.RIGHT_ARROW: self.handle_RIGHT,
t.TAB: self.handle_TAB,
# Both of these should not be necessary, but figuring out
# which is necessary is a huge hassle.
b'\r': self.handle_RETURN,
b'\n': self.handle_RETURN,
t.BACKSPACE: self.handle_BACKSPACE,
t.DELETE: self.handle_DELETE,
t.INSERT: self.handle_INSERT,
t.HOME: self.handle_HOME,
t.END: self.handle_END}
self.initializeScreen()
def initializeScreen(self):
# Hmm, state sucks. Oh well.
# For now we will just take over the whole terminal.
self.terminal.reset()
self.terminal.write(self.ps[self.pn])
# XXX Note: I would prefer to default to starting in insert
# mode, however this does not seem to actually work! I do not
# know why. This is probably of interest to implementors
# subclassing RecvLine.
# XXX XXX Note: But the unit tests all expect the initial mode
# to be insert right now. Fuck, there needs to be a way to
# query the current mode or something.
# self.setTypeoverMode()
self.setInsertMode()
def currentLineBuffer(self):
s = b''.join(self.lineBuffer)
return s[:self.lineBufferIndex], s[self.lineBufferIndex:]
def setInsertMode(self):
self.mode = 'insert'
self.terminal.setModes([insults.modes.IRM])
def setTypeoverMode(self):
self.mode = 'typeover'
self.terminal.resetModes([insults.modes.IRM])
def drawInputLine(self):
"""
Write a line containing the current input prompt and the current line
buffer at the current cursor position.
"""
self.terminal.write(self.ps[self.pn] + b''.join(self.lineBuffer))
def terminalSize(self, width, height):
# XXX - Clear the previous input line, redraw it at the new
# cursor position
self.terminal.eraseDisplay()
self.terminal.cursorHome()
self.width = width
self.height = height
self.drawInputLine()
def unhandledControlSequence(self, seq):
pass
def keystrokeReceived(self, keyID, modifier):
m = self.keyHandlers.get(keyID)
if m is not None:
m()
elif keyID in self._printableChars:
self.characterReceived(keyID, False)
else:
log.msg("Received unhandled keyID: %r" % (keyID,))
def characterReceived(self, ch, moreCharactersComing):
if self.mode == 'insert':
self.lineBuffer.insert(self.lineBufferIndex, ch)
else:
self.lineBuffer[self.lineBufferIndex:self.lineBufferIndex+1] = [ch]
self.lineBufferIndex += 1
self.terminal.write(ch)
def handle_TAB(self):
n = self.TABSTOP - (len(self.lineBuffer) % self.TABSTOP)
self.terminal.cursorForward(n)
self.lineBufferIndex += n
self.lineBuffer.extend(iterbytes(b' ' * n))
def handle_LEFT(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
self.terminal.cursorBackward()
def handle_RIGHT(self):
if self.lineBufferIndex < len(self.lineBuffer):
self.lineBufferIndex += 1
self.terminal.cursorForward()
def handle_HOME(self):
if self.lineBufferIndex:
self.terminal.cursorBackward(self.lineBufferIndex)
self.lineBufferIndex = 0
def handle_END(self):
offset = len(self.lineBuffer) - self.lineBufferIndex
if offset:
self.terminal.cursorForward(offset)
self.lineBufferIndex = len(self.lineBuffer)
def handle_BACKSPACE(self):
if self.lineBufferIndex > 0:
self.lineBufferIndex -= 1
del self.lineBuffer[self.lineBufferIndex]
self.terminal.cursorBackward()
self.terminal.deleteCharacter()
def handle_DELETE(self):
if self.lineBufferIndex < len(self.lineBuffer):
del self.lineBuffer[self.lineBufferIndex]
self.terminal.deleteCharacter()
def handle_RETURN(self):
line = b''.join(self.lineBuffer)
self.lineBuffer = []
self.lineBufferIndex = 0
self.terminal.nextLine()
self.lineReceived(line)
def handle_INSERT(self):
assert self.mode in ('typeover', 'insert')
if self.mode == 'typeover':
self.setInsertMode()
else:
self.setTypeoverMode()
def lineReceived(self, line):
pass
class HistoricRecvLine(RecvLine):
"""
L{TerminalProtocol} which adds both basic line-editing features and input history.
Everything supported by L{RecvLine} is also supported by this class. In addition, the
up and down arrows traverse the input history. Each received line is automatically
added to the end of the input history.
"""
def connectionMade(self):
RecvLine.connectionMade(self)
self.historyLines = []
self.historyPosition = 0
t = self.terminal
self.keyHandlers.update({t.UP_ARROW: self.handle_UP,
t.DOWN_ARROW: self.handle_DOWN})
def currentHistoryBuffer(self):
b = tuple(self.historyLines)
return b[:self.historyPosition], b[self.historyPosition:]
def _deliverBuffer(self, buf):
if buf:
for ch in iterbytes(buf[:-1]):
self.characterReceived(ch, True)
self.characterReceived(buf[-1:], False)
def handle_UP(self):
if self.lineBuffer and self.historyPosition == len(self.historyLines):
self.historyLines.append(b''.join(self.lineBuffer))
if self.historyPosition > 0:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition -= 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
def handle_DOWN(self):
if self.historyPosition < len(self.historyLines) - 1:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition += 1
self.lineBuffer = []
self._deliverBuffer(self.historyLines[self.historyPosition])
else:
self.handle_HOME()
self.terminal.eraseToLineEnd()
self.historyPosition = len(self.historyLines)
self.lineBuffer = []
self.lineBufferIndex = 0
def handle_RETURN(self):
if self.lineBuffer:
self.historyLines.append(b''.join(self.lineBuffer))
self.historyPosition = len(self.historyLines)
return RecvLine.handle_RETURN(self)

View file

@ -0,0 +1 @@
'conch scripts'

View file

@ -0,0 +1,949 @@
# -*- test-case-name: twisted.conch.test.test_cftp -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the I{cftp} command.
"""
from __future__ import division, print_function
import os, sys, getpass, struct, tty, fcntl, stat
import fnmatch, pwd, glob
from twisted.conch.client import connect, default, options
from twisted.conch.ssh import connection, common
from twisted.conch.ssh import channel, filetransfer
from twisted.protocols import basic
from twisted.python.compat import _PY3, unicode
from twisted.internet import reactor, stdio, defer, utils
from twisted.python import log, usage, failure
from twisted.python.filepath import FilePath
class ClientOptions(options.ConchOptions):
synopsis = """Usage: cftp [options] [user@]host
cftp [options] [user@]host[:dir[/]]
cftp [options] [user@]host[:file [localfile]]
"""
longdesc = ("cftp is a client for logging into a remote machine and "
"executing commands to send and receive file information")
optParameters = [
['buffersize', 'B', 32768, 'Size of the buffer to use for sending/receiving.'],
['batchfile', 'b', None, 'File to read commands from, or \'-\' for stdin.'],
['requests', 'R', 5, 'Number of requests to make before waiting for a reply.'],
['subsystem', 's', 'sftp', 'Subsystem/server program to connect to.']]
compData = usage.Completions(
descriptions={
"buffersize": "Size of send/receive buffer (default: 32768)"},
extraActions=[usage.CompleteUserAtHost(),
usage.CompleteFiles(descr="local file")])
def parseArgs(self, host, localPath=None):
self['remotePath'] = ''
if ':' in host:
host, self['remotePath'] = host.split(':', 1)
self['remotePath'].rstrip('/')
self['host'] = host
self['localPath'] = localPath
def run():
# import hotshot
# prof = hotshot.Profile('cftp.prof')
# prof.start()
args = sys.argv[1:]
if '-l' in args: # cvs is an idiot
i = args.index('-l')
args = args[i:i+2]+args
del args[i+2:i+4]
options = ClientOptions()
try:
options.parseOptions(args)
except usage.UsageError as u:
print('ERROR: %s' % u)
sys.exit(1)
if options['log']:
realout = sys.stdout
log.startLogging(sys.stderr)
sys.stdout = realout
else:
log.discardLogs()
doConnect(options)
reactor.run()
# prof.stop()
# prof.close()
def handleError():
global exitStatus
exitStatus = 2
try:
reactor.stop()
except: pass
log.err(failure.Failure())
raise
def doConnect(options):
# log.deferr = handleError # HACK
if '@' in options['host']:
options['user'], options['host'] = options['host'].split('@',1)
host = options['host']
if not options['user']:
options['user'] = getpass.getuser()
if not options['port']:
options['port'] = 22
else:
options['port'] = int(options['port'])
host = options['host']
port = options['port']
conn = SSHConnection()
conn.options = options
vhk = default.verifyHostKey
uao = default.SSHUserAuthClient(options['user'], options, conn)
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
def _ebExit(f):
#global exitStatus
if hasattr(f.value, 'value'):
s = f.value.value
else:
s = str(f)
print(s)
#exitStatus = "conch: exiting with error %s" % f
try:
reactor.stop()
except: pass
def _ignore(*args): pass
class FileWrapper:
def __init__(self, f):
self.f = f
self.total = 0.0
f.seek(0, 2) # seek to the end
self.size = f.tell()
def __getattr__(self, attr):
return getattr(self.f, attr)
class StdioClient(basic.LineReceiver):
_pwd = pwd
ps = 'cftp> '
delimiter = b'\n'
reactor = reactor
def __init__(self, client, f = None):
self.client = client
self.currentDirectory = ''
self.file = f
self.useProgressBar = (not f and 1) or 0
def connectionMade(self):
self.client.realPath('').addCallback(self._cbSetCurDir)
def _cbSetCurDir(self, path):
self.currentDirectory = path
self._newLine()
def _writeToTransport(self, msg):
if isinstance(msg, unicode):
msg = msg.encode("utf-8")
return self.transport.write(msg)
def lineReceived(self, line):
if self.client.transport.localClosed:
return
if _PY3 and isinstance(line, bytes):
line = line.decode("utf-8")
log.msg('got line %s' % line)
line = line.lstrip()
if not line:
self._newLine()
return
if self.file and line.startswith('-'):
self.ignoreErrors = 1
line = line[1:]
else:
self.ignoreErrors = 0
d = self._dispatchCommand(line)
if d is not None:
d.addCallback(self._cbCommand)
d.addErrback(self._ebCommand)
def _dispatchCommand(self, line):
if ' ' in line:
command, rest = line.split(' ', 1)
rest = rest.lstrip()
else:
command, rest = line, ''
if command.startswith('!'): # command
f = self.cmd_EXEC
rest = (command[1:] + ' ' + rest).strip()
else:
command = command.upper()
log.msg('looking up cmd %s' % command)
f = getattr(self, 'cmd_%s' % command, None)
if f is not None:
return defer.maybeDeferred(f, rest)
else:
errMsg = "No command called `%s'" % (command)
self._ebCommand(failure.Failure(NotImplementedError(errMsg)))
self._newLine()
def _printFailure(self, f):
log.msg(f)
e = f.trap(NotImplementedError, filetransfer.SFTPError, OSError, IOError)
if e == NotImplementedError:
self._writeToTransport(self.cmd_HELP(''))
elif e == filetransfer.SFTPError:
errMsg = "remote error %i: %s\n" % (f.value.code, f.value.message)
self._writeToTransport(errMsg)
elif e in (OSError, IOError):
errMsg = "local error %i: %s\n" % (f.value.errno, f.value.strerror)
self._writeToTransport(errMsg)
def _newLine(self):
if self.client.transport.localClosed:
return
self._writeToTransport(self.ps)
self.ignoreErrors = 0
if self.file:
l = self.file.readline()
if not l:
self.client.transport.loseConnection()
else:
self._writeToTransport(l)
self.lineReceived(l.strip())
def _cbCommand(self, result):
if result is not None:
if isinstance(result, unicode):
result = result.encode("utf-8")
self._writeToTransport(result)
if not result.endswith(b'\n'):
self._writeToTransport(b'\n')
self._newLine()
def _ebCommand(self, f):
self._printFailure(f)
if self.file and not self.ignoreErrors:
self.client.transport.loseConnection()
self._newLine()
def cmd_CD(self, path):
path, rest = self._getFilename(path)
if not path.endswith('/'):
path += '/'
newPath = path and os.path.join(self.currentDirectory, path) or ''
d = self.client.openDirectory(newPath)
d.addCallback(self._cbCd)
d.addErrback(self._ebCommand)
return d
def _cbCd(self, directory):
directory.close()
d = self.client.realPath(directory.name)
d.addCallback(self._cbCurDir)
return d
def _cbCurDir(self, path):
self.currentDirectory = path
def cmd_CHGRP(self, rest):
grp, rest = rest.split(None, 1)
path, rest = self._getFilename(rest)
grp = int(grp)
d = self.client.getAttrs(path)
d.addCallback(self._cbSetUsrGrp, path, grp=grp)
return d
def cmd_CHMOD(self, rest):
mod, rest = rest.split(None, 1)
path, rest = self._getFilename(rest)
mod = int(mod, 8)
d = self.client.setAttrs(path, {'permissions':mod})
d.addCallback(_ignore)
return d
def cmd_CHOWN(self, rest):
usr, rest = rest.split(None, 1)
path, rest = self._getFilename(rest)
usr = int(usr)
d = self.client.getAttrs(path)
d.addCallback(self._cbSetUsrGrp, path, usr=usr)
return d
def _cbSetUsrGrp(self, attrs, path, usr=None, grp=None):
new = {}
new['uid'] = (usr is not None) and usr or attrs['uid']
new['gid'] = (grp is not None) and grp or attrs['gid']
d = self.client.setAttrs(path, new)
d.addCallback(_ignore)
return d
def cmd_GET(self, rest):
remote, rest = self._getFilename(rest)
if '*' in remote or '?' in remote: # wildcard
if rest:
local, rest = self._getFilename(rest)
if not os.path.isdir(local):
return "Wildcard get with non-directory target."
else:
local = b''
d = self._remoteGlob(remote)
d.addCallback(self._cbGetMultiple, local)
return d
if rest:
local, rest = self._getFilename(rest)
else:
local = os.path.split(remote)[1]
log.msg((remote, local))
lf = open(local, 'wb', 0)
path = FilePath(self.currentDirectory).child(remote)
d = self.client.openFile(path.path, filetransfer.FXF_READ, {})
d.addCallback(self._cbGetOpenFile, lf)
d.addErrback(self._ebCloseLf, lf)
return d
def _cbGetMultiple(self, files, local):
#if self._useProgressBar: # one at a time
# XXX this can be optimized for times w/o progress bar
return self._cbGetMultipleNext(None, files, local)
def _cbGetMultipleNext(self, res, files, local):
if isinstance(res, failure.Failure):
self._printFailure(res)
elif res:
self._writeToTransport(res)
if not res.endswith('\n'):
self._writeToTransport('\n')
if not files:
return
f = files.pop(0)[0]
lf = open(os.path.join(local, os.path.split(f)[1]), 'wb', 0)
path = FilePath(self.currentDirectory).child(f)
d = self.client.openFile(path.path, filetransfer.FXF_READ, {})
d.addCallback(self._cbGetOpenFile, lf)
d.addErrback(self._ebCloseLf, lf)
d.addBoth(self._cbGetMultipleNext, files, local)
return d
def _ebCloseLf(self, f, lf):
lf.close()
return f
def _cbGetOpenFile(self, rf, lf):
return rf.getAttrs().addCallback(self._cbGetFileSize, rf, lf)
def _cbGetFileSize(self, attrs, rf, lf):
if not stat.S_ISREG(attrs['permissions']):
rf.close()
lf.close()
return "Can't get non-regular file: %s" % rf.name
rf.size = attrs['size']
bufferSize = self.client.transport.conn.options['buffersize']
numRequests = self.client.transport.conn.options['requests']
rf.total = 0.0
dList = []
chunks = []
startTime = self.reactor.seconds()
for i in range(numRequests):
d = self._cbGetRead('', rf, lf, chunks, 0, bufferSize, startTime)
dList.append(d)
dl = defer.DeferredList(dList, fireOnOneErrback=1)
dl.addCallback(self._cbGetDone, rf, lf)
return dl
def _getNextChunk(self, chunks):
end = 0
for chunk in chunks:
if end == 'eof':
return # nothing more to get
if end != chunk[0]:
i = chunks.index(chunk)
chunks.insert(i, (end, chunk[0]))
return (end, chunk[0] - end)
end = chunk[1]
bufSize = int(self.client.transport.conn.options['buffersize'])
chunks.append((end, end + bufSize))
return (end, bufSize)
def _cbGetRead(self, data, rf, lf, chunks, start, size, startTime):
if data and isinstance(data, failure.Failure):
log.msg('get read err: %s' % data)
reason = data
reason.trap(EOFError)
i = chunks.index((start, start + size))
del chunks[i]
chunks.insert(i, (start, 'eof'))
elif data:
log.msg('get read data: %i' % len(data))
lf.seek(start)
lf.write(data)
if len(data) != size:
log.msg('got less than we asked for: %i < %i' %
(len(data), size))
i = chunks.index((start, start + size))
del chunks[i]
chunks.insert(i, (start, start + len(data)))
rf.total += len(data)
if self.useProgressBar:
self._printProgressBar(rf, startTime)
chunk = self._getNextChunk(chunks)
if not chunk:
return
else:
start, length = chunk
log.msg('asking for %i -> %i' % (start, start+length))
d = rf.readChunk(start, length)
d.addBoth(self._cbGetRead, rf, lf, chunks, start, length, startTime)
return d
def _cbGetDone(self, ignored, rf, lf):
log.msg('get done')
rf.close()
lf.close()
if self.useProgressBar:
self._writeToTransport('\n')
return "Transferred %s to %s" % (rf.name, lf.name)
def cmd_PUT(self, rest):
"""
Do an upload request for a single local file or a globing expression.
@param rest: Requested command line for the PUT command.
@type rest: L{str}
@return: A deferred which fires with L{None} when transfer is done.
@rtype: L{defer.Deferred}
"""
local, rest = self._getFilename(rest)
# FIXME: https://twistedmatrix.com/trac/ticket/7241
# Use a better check for globbing expression.
if '*' in local or '?' in local:
if rest:
remote, rest = self._getFilename(rest)
remote = os.path.join(self.currentDirectory, remote)
else:
remote = ''
files = glob.glob(local)
return self._putMultipleFiles(files, remote)
else:
if rest:
remote, rest = self._getFilename(rest)
else:
remote = os.path.split(local)[1]
return self._putSingleFile(local, remote)
def _putSingleFile(self, local, remote):
"""
Perform an upload for a single file.
@param local: Path to local file.
@type local: L{str}.
@param remote: Remote path for the request relative to current working
directory.
@type remote: L{str}
@return: A deferred which fires when transfer is done.
"""
return self._cbPutMultipleNext(None, [local], remote, single=True)
def _putMultipleFiles(self, files, remote):
"""
Perform an upload for a list of local files.
@param files: List of local files.
@type files: C{list} of L{str}.
@param remote: Remote path for the request relative to current working
directory.
@type remote: L{str}
@return: A deferred which fires when transfer is done.
"""
return self._cbPutMultipleNext(None, files, remote)
def _cbPutMultipleNext(
self, previousResult, files, remotePath, single=False):
"""
Perform an upload for the next file in the list of local files.
@param previousResult: Result form previous file form the list.
@type previousResult: L{str}
@param files: List of local files.
@type files: C{list} of L{str}
@param remotePath: Remote path for the request relative to current
working directory.
@type remotePath: L{str}
@param single: A flag which signals if this is a transfer for a single
file in which case we use the exact remote path
@type single: L{bool}
@return: A deferred which fires when transfer is done.
"""
if isinstance(previousResult, failure.Failure):
self._printFailure(previousResult)
elif previousResult:
if isinstance(previousResult, unicode):
previousResult = previousResult.encode("utf-8")
self._writeToTransport(previousResult)
if not previousResult.endswith(b'\n'):
self._writeToTransport(b'\n')
currentFile = None
while files and not currentFile:
try:
currentFile = files.pop(0)
localStream = open(currentFile, 'rb')
except:
self._printFailure(failure.Failure())
currentFile = None
# No more files to transfer.
if not currentFile:
return None
if single:
remote = remotePath
else:
name = os.path.split(currentFile)[1]
remote = os.path.join(remotePath, name)
log.msg((name, remote, remotePath))
d = self._putRemoteFile(localStream, remote)
d.addBoth(self._cbPutMultipleNext, files, remotePath)
return d
def _putRemoteFile(self, localStream, remotePath):
"""
Do an upload request.
@param localStream: Local stream from where data is read.
@type localStream: File like object.
@param remotePath: Remote path for the request relative to current working directory.
@type remotePath: L{str}
@return: A deferred which fires when transfer is done.
"""
remote = os.path.join(self.currentDirectory, remotePath)
flags = (
filetransfer.FXF_WRITE |
filetransfer.FXF_CREAT |
filetransfer.FXF_TRUNC
)
d = self.client.openFile(remote, flags, {})
d.addCallback(self._cbPutOpenFile, localStream)
d.addErrback(self._ebCloseLf, localStream)
return d
def _cbPutOpenFile(self, rf, lf):
numRequests = self.client.transport.conn.options['requests']
if self.useProgressBar:
lf = FileWrapper(lf)
dList = []
chunks = []
startTime = self.reactor.seconds()
for i in range(numRequests):
d = self._cbPutWrite(None, rf, lf, chunks, startTime)
if d:
dList.append(d)
dl = defer.DeferredList(dList, fireOnOneErrback=1)
dl.addCallback(self._cbPutDone, rf, lf)
return dl
def _cbPutWrite(self, ignored, rf, lf, chunks, startTime):
chunk = self._getNextChunk(chunks)
start, size = chunk
lf.seek(start)
data = lf.read(size)
if self.useProgressBar:
lf.total += len(data)
self._printProgressBar(lf, startTime)
if data:
d = rf.writeChunk(start, data)
d.addCallback(self._cbPutWrite, rf, lf, chunks, startTime)
return d
else:
return
def _cbPutDone(self, ignored, rf, lf):
lf.close()
rf.close()
if self.useProgressBar:
self._writeToTransport('\n')
return 'Transferred %s to %s' % (lf.name, rf.name)
def cmd_LCD(self, path):
os.chdir(path)
def cmd_LN(self, rest):
linkpath, rest = self._getFilename(rest)
targetpath, rest = self._getFilename(rest)
linkpath, targetpath = map(
lambda x: os.path.join(self.currentDirectory, x),
(linkpath, targetpath))
return self.client.makeLink(linkpath, targetpath).addCallback(_ignore)
def cmd_LS(self, rest):
# possible lines:
# ls current directory
# ls name_of_file that file
# ls name_of_directory that directory
# ls some_glob_string current directory, globbed for that string
options = []
rest = rest.split()
while rest and rest[0] and rest[0][0] == '-':
opts = rest.pop(0)[1:]
for o in opts:
if o == 'l':
options.append('verbose')
elif o == 'a':
options.append('all')
rest = ' '.join(rest)
path, rest = self._getFilename(rest)
if not path:
fullPath = self.currentDirectory + '/'
else:
fullPath = os.path.join(self.currentDirectory, path)
d = self._remoteGlob(fullPath)
d.addCallback(self._cbDisplayFiles, options)
return d
def _cbDisplayFiles(self, files, options):
files.sort()
if 'all' not in options:
files = [f for f in files if not f[0].startswith(b'.')]
if 'verbose' in options:
lines = [f[1] for f in files]
else:
lines = [f[0] for f in files]
if not lines:
return None
else:
return b'\n'.join(lines)
def cmd_MKDIR(self, path):
path, rest = self._getFilename(path)
path = os.path.join(self.currentDirectory, path)
return self.client.makeDirectory(path, {}).addCallback(_ignore)
def cmd_RMDIR(self, path):
path, rest = self._getFilename(path)
path = os.path.join(self.currentDirectory, path)
return self.client.removeDirectory(path).addCallback(_ignore)
def cmd_LMKDIR(self, path):
os.system("mkdir %s" % path)
def cmd_RM(self, path):
path, rest = self._getFilename(path)
path = os.path.join(self.currentDirectory, path)
return self.client.removeFile(path).addCallback(_ignore)
def cmd_LLS(self, rest):
os.system("ls %s" % rest)
def cmd_RENAME(self, rest):
oldpath, rest = self._getFilename(rest)
newpath, rest = self._getFilename(rest)
oldpath, newpath = map (
lambda x: os.path.join(self.currentDirectory, x),
(oldpath, newpath))
return self.client.renameFile(oldpath, newpath).addCallback(_ignore)
def cmd_EXIT(self, ignored):
self.client.transport.loseConnection()
cmd_QUIT = cmd_EXIT
def cmd_VERSION(self, ignored):
version = "SFTP version %i" % self.client.version
if isinstance(version, unicode):
version = version.encode("utf-8")
return version
def cmd_HELP(self, ignored):
return """Available commands:
cd path Change remote directory to 'path'.
chgrp gid path Change gid of 'path' to 'gid'.
chmod mode path Change mode of 'path' to 'mode'.
chown uid path Change uid of 'path' to 'uid'.
exit Disconnect from the server.
get remote-path [local-path] Get remote file.
help Get a list of available commands.
lcd path Change local directory to 'path'.
lls [ls-options] [path] Display local directory listing.
lmkdir path Create local directory.
ln linkpath targetpath Symlink remote file.
lpwd Print the local working directory.
ls [-l] [path] Display remote directory listing.
mkdir path Create remote directory.
progress Toggle progress bar.
put local-path [remote-path] Put local file.
pwd Print the remote working directory.
quit Disconnect from the server.
rename oldpath newpath Rename remote file.
rmdir path Remove remote directory.
rm path Remove remote file.
version Print the SFTP version.
? Synonym for 'help'.
"""
def cmd_PWD(self, ignored):
return self.currentDirectory
def cmd_LPWD(self, ignored):
return os.getcwd()
def cmd_PROGRESS(self, ignored):
self.useProgressBar = not self.useProgressBar
return "%ssing progess bar." % (self.useProgressBar and "U" or "Not u")
def cmd_EXEC(self, rest):
"""
Run C{rest} using the user's shell (or /bin/sh if they do not have
one).
"""
shell = self._pwd.getpwnam(getpass.getuser())[6]
if not shell:
shell = '/bin/sh'
if rest:
cmds = ['-c', rest]
return utils.getProcessOutput(shell, cmds, errortoo=1)
else:
os.system(shell)
# accessory functions
def _remoteGlob(self, fullPath):
log.msg('looking up %s' % fullPath)
head, tail = os.path.split(fullPath)
if '*' in tail or '?' in tail:
glob = 1
else:
glob = 0
if tail and not glob: # could be file or directory
# try directory first
d = self.client.openDirectory(fullPath)
d.addCallback(self._cbOpenList, '')
d.addErrback(self._ebNotADirectory, head, tail)
else:
d = self.client.openDirectory(head)
d.addCallback(self._cbOpenList, tail)
return d
def _cbOpenList(self, directory, glob):
files = []
d = directory.read()
d.addBoth(self._cbReadFile, files, directory, glob)
return d
def _ebNotADirectory(self, reason, path, glob):
d = self.client.openDirectory(path)
d.addCallback(self._cbOpenList, glob)
return d
def _cbReadFile(self, files, l, directory, glob):
if not isinstance(files, failure.Failure):
if glob:
if _PY3:
glob = glob.encode("utf-8")
l.extend([f for f in files if fnmatch.fnmatch(f[0], glob)])
else:
l.extend(files)
d = directory.read()
d.addBoth(self._cbReadFile, l, directory, glob)
return d
else:
reason = files
reason.trap(EOFError)
directory.close()
return l
def _abbrevSize(self, size):
# from http://mail.python.org/pipermail/python-list/1999-December/018395.html
_abbrevs = [
(1<<50, 'PB'),
(1<<40, 'TB'),
(1<<30, 'GB'),
(1<<20, 'MB'),
(1<<10, 'kB'),
(1, 'B')
]
for factor, suffix in _abbrevs:
if size > factor:
break
return '%.1f' % (size/factor) + suffix
def _abbrevTime(self, t):
if t > 3600: # 1 hour
hours = int(t / 3600)
t -= (3600 * hours)
mins = int(t / 60)
t -= (60 * mins)
return "%i:%02i:%02i" % (hours, mins, t)
else:
mins = int(t/60)
t -= (60 * mins)
return "%02i:%02i" % (mins, t)
def _printProgressBar(self, f, startTime):
"""
Update a console progress bar on this L{StdioClient}'s transport, based
on the difference between the start time of the operation and the
current time according to the reactor, and appropriate to the size of
the console window.
@param f: a wrapper around the file which is being written or read
@type f: L{FileWrapper}
@param startTime: The time at which the operation being tracked began.
@type startTime: L{float}
"""
diff = self.reactor.seconds() - startTime
total = f.total
try:
winSize = struct.unpack('4H',
fcntl.ioctl(0, tty.TIOCGWINSZ, '12345679'))
except IOError:
winSize = [None, 80]
if diff == 0.0:
speed = 0.0
else:
speed = total / diff
if speed:
timeLeft = (f.size - total) / speed
else:
timeLeft = 0
front = f.name
if f.size:
percentage = (total / f.size) * 100
else:
percentage = 100
back = '%3i%% %s %sps %s ' % (percentage,
self._abbrevSize(total),
self._abbrevSize(speed),
self._abbrevTime(timeLeft))
spaces = (winSize[1] - (len(front) + len(back) + 1)) * ' '
command = '\r%s%s%s' % (front, spaces, back)
self._writeToTransport(command)
def _getFilename(self, line):
"""
Parse line received as command line input and return first filename
together with the remaining line.
@param line: Arguments received from command line input.
@type line: L{str}
@return: Tupple with filename and rest. Return empty values when no path was not found.
@rtype: C{tupple}
"""
line = line.strip()
if not line:
return '', ''
if line[0] in '\'"':
ret = []
line = list(line)
try:
for i in range(1,len(line)):
c = line[i]
if c == line[0]:
return ''.join(ret), ''.join(line[i+1:]).lstrip()
elif c == '\\': # quoted character
del line[i]
if line[i] not in '\'"\\':
raise IndexError("bad quote: \\%s" % (line[i],))
ret.append(line[i])
else:
ret.append(line[i])
except IndexError:
raise IndexError("unterminated quote")
ret = line.split(None, 1)
if len(ret) == 1:
return ret[0], ''
else:
return ret[0], ret[1]
setattr(StdioClient, 'cmd_?', StdioClient.cmd_HELP)
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
self.openChannel(SSHSession())
class SSHSession(channel.SSHChannel):
name = b'session'
def channelOpen(self, foo):
log.msg('session %s open' % self.id)
if self.conn.options['subsystem'].startswith('/'):
request = 'exec'
else:
request = 'subsystem'
d = self.conn.sendRequest(self, request, \
common.NS(self.conn.options['subsystem']), wantReply=1)
d.addCallback(self._cbSubsystem)
d.addErrback(_ebExit)
def _cbSubsystem(self, result):
self.client = filetransfer.FileTransferClient()
self.client.makeConnection(self)
self.dataReceived = self.client.dataReceived
f = None
if self.conn.options['batchfile']:
fn = self.conn.options['batchfile']
if fn != '-':
f = open(fn)
self.stdio = stdio.StandardIO(StdioClient(self.client, f))
def extReceived(self, t, data):
if t==connection.EXTENDED_DATA_STDERR:
log.msg('got %s stderr data' % len(data))
sys.stderr.write(data)
sys.stderr.flush()
def eofReceived(self):
log.msg('got eof')
self.stdio.loseWriteConnection()
def closeReceived(self):
log.msg('remote side closed %s' % self)
self.conn.sendClose(self)
def closed(self):
try:
reactor.stop()
except:
pass
def stopWriting(self):
self.stdio.pauseProducing()
def startWriting(self):
self.stdio.resumeProducing()
if __name__ == '__main__':
run()

View file

@ -0,0 +1,317 @@
# -*- test-case-name: twisted.conch.test.test_ckeygen -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `ckeygen` command.
"""
from __future__ import print_function
import sys, os, getpass, socket
from functools import wraps
from imp import reload
if getpass.getpass == getpass.unix_getpass:
try:
import termios # hack around broken termios
termios.tcgetattr, termios.tcsetattr
except (ImportError, AttributeError):
sys.modules['termios'] = None
reload(getpass)
from twisted.conch.ssh import keys
from twisted.python import failure, filepath, log, usage
from twisted.python.compat import raw_input, _PY3
supportedKeyTypes = dict()
def _keyGenerator(keyType):
def assignkeygenerator(keygenerator):
@wraps(keygenerator)
def wrapper(*args, **kwargs):
return keygenerator(*args, **kwargs)
supportedKeyTypes[keyType] = wrapper
return wrapper
return assignkeygenerator
class GeneralOptions(usage.Options):
synopsis = """Usage: ckeygen [options]
"""
longdesc = "ckeygen manipulates public/private keys in various ways."
optParameters = [['bits', 'b', None, 'Number of bits in the key to create.'],
['filename', 'f', None, 'Filename of the key file.'],
['type', 't', None, 'Specify type of key to create.'],
['comment', 'C', None, 'Provide new comment.'],
['newpass', 'N', None, 'Provide new passphrase.'],
['pass', 'P', None, 'Provide old passphrase.'],
['format', 'o', 'sha256-base64',
'Fingerprint format of key file.'],
['private-key-subtype', None, 'PEM',
'OpenSSH private key subtype to write ("PEM" or "v1").']]
optFlags = [['fingerprint', 'l', 'Show fingerprint of key file.'],
['changepass', 'p', 'Change passphrase of private key file.'],
['quiet', 'q', 'Quiet.'],
['no-passphrase', None, "Create the key with no passphrase."],
['showpub', 'y',
'Read private key file and print public key.']]
compData = usage.Completions(
optActions={
"type": usage.CompleteList(list(supportedKeyTypes.keys())),
"private-key-subtype": usage.CompleteList(["PEM", "v1"]),
})
def run():
options = GeneralOptions()
try:
options.parseOptions(sys.argv[1:])
except usage.UsageError as u:
print('ERROR: %s' % u)
options.opt_help()
sys.exit(1)
log.discardLogs()
log.deferr = handleError # HACK
if options['type']:
if options['type'].lower() in supportedKeyTypes:
print('Generating public/private %s key pair.' % (options['type']))
supportedKeyTypes[options['type'].lower()](options)
else:
sys.exit(
'Key type was %s, must be one of %s'
% (options['type'], ', '.join(supportedKeyTypes.keys())))
elif options['fingerprint']:
printFingerprint(options)
elif options['changepass']:
changePassPhrase(options)
elif options['showpub']:
displayPublicKey(options)
else:
options.opt_help()
sys.exit(1)
def enumrepresentation(options):
if options['format'] == 'md5-hex':
options['format'] = keys.FingerprintFormats.MD5_HEX
return options
elif options['format'] == 'sha256-base64':
options['format'] = keys.FingerprintFormats.SHA256_BASE64
return options
else:
raise keys.BadFingerPrintFormat(
'Unsupported fingerprint format: %s' % (options['format'],))
def handleError():
global exitStatus
exitStatus = 2
log.err(failure.Failure())
raise
@_keyGenerator('rsa')
def generateRSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import rsa
if not options['bits']:
options['bits'] = 1024
keyPrimitive = rsa.generate_private_key(
key_size=int(options['bits']),
public_exponent=65537,
backend=default_backend(),
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator('dsa')
def generateDSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import dsa
if not options['bits']:
options['bits'] = 1024
keyPrimitive = dsa.generate_private_key(
key_size=int(options['bits']),
backend=default_backend(),
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
@_keyGenerator('ecdsa')
def generateECDSAkey(options):
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
if not options['bits']:
options['bits'] = 256
# OpenSSH supports only mandatory sections of RFC5656.
# See https://www.openssh.com/txt/release-5.7
curve = b'ecdsa-sha2-nistp' + str(options['bits']).encode('ascii')
keyPrimitive = ec.generate_private_key(
curve=keys._curveTable[curve],
backend=default_backend()
)
key = keys.Key(keyPrimitive)
_saveKey(key, options)
def printFingerprint(options):
if not options['filename']:
filename = os.path.expanduser('~/.ssh/id_rsa')
options['filename'] = raw_input('Enter file in which the key is (%s): ' % filename)
if os.path.exists(options['filename']+'.pub'):
options['filename'] += '.pub'
options = enumrepresentation(options)
try:
key = keys.Key.fromFile(options['filename'])
print('%s %s %s' % (
key.size(),
key.fingerprint(options['format']),
os.path.basename(options['filename'])))
except keys.BadKeyError:
sys.exit('bad key')
def changePassPhrase(options):
if not options['filename']:
filename = os.path.expanduser('~/.ssh/id_rsa')
options['filename'] = raw_input(
'Enter file in which the key is (%s): ' % filename)
try:
key = keys.Key.fromFile(options['filename'])
except keys.EncryptedKeyError:
# Raised if password not supplied for an encrypted key
if not options.get('pass'):
options['pass'] = getpass.getpass('Enter old passphrase: ')
try:
key = keys.Key.fromFile(
options['filename'], passphrase=options['pass'])
except keys.BadKeyError:
sys.exit('Could not change passphrase: old passphrase error')
except keys.EncryptedKeyError as e:
sys.exit('Could not change passphrase: %s' % (e,))
except keys.BadKeyError as e:
sys.exit('Could not change passphrase: %s' % (e,))
if not options.get('newpass'):
while 1:
p1 = getpass.getpass(
'Enter new passphrase (empty for no passphrase): ')
p2 = getpass.getpass('Enter same passphrase again: ')
if p1 == p2:
break
print('Passphrases do not match. Try again.')
options['newpass'] = p1
try:
newkeydata = key.toString(
'openssh', subtype=options.get('private-key-subtype'),
passphrase=options['newpass'])
except Exception as e:
sys.exit('Could not change passphrase: %s' % (e,))
try:
keys.Key.fromString(newkeydata, passphrase=options['newpass'])
except (keys.EncryptedKeyError, keys.BadKeyError) as e:
sys.exit('Could not change passphrase: %s' % (e,))
with open(options['filename'], 'wb') as fd:
fd.write(newkeydata)
print('Your identification has been saved with the new passphrase.')
def displayPublicKey(options):
if not options['filename']:
filename = os.path.expanduser('~/.ssh/id_rsa')
options['filename'] = raw_input('Enter file in which the key is (%s): ' % filename)
try:
key = keys.Key.fromFile(options['filename'])
except keys.EncryptedKeyError:
if not options.get('pass'):
options['pass'] = getpass.getpass('Enter passphrase: ')
key = keys.Key.fromFile(
options['filename'], passphrase = options['pass'])
displayKey = key.public().toString('openssh')
if _PY3:
displayKey = displayKey.decode("ascii")
print(displayKey)
def _saveKey(key, options):
"""
Persist a SSH key on local filesystem.
@param key: Key which is persisted on local filesystem.
@type key: C{keys.Key} implementation.
@param options:
@type options: L{dict}
"""
KeyTypeMapping = {'EC': 'ecdsa', 'RSA': 'rsa', 'DSA': 'dsa'}
keyTypeName = KeyTypeMapping[key.type()]
if not options['filename']:
defaultPath = os.path.expanduser(u'~/.ssh/id_%s' % (keyTypeName,))
newPath = raw_input(
'Enter file in which to save the key (%s): ' % (defaultPath,))
options['filename'] = newPath.strip() or defaultPath
if os.path.exists(options['filename']):
print('%s already exists.' % (options['filename'],))
yn = raw_input('Overwrite (y/n)? ')
if yn[0].lower() != 'y':
sys.exit()
if options.get('no-passphrase'):
options['pass'] = b''
elif not options['pass']:
while 1:
p1 = getpass.getpass(
'Enter passphrase (empty for no passphrase): ')
p2 = getpass.getpass('Enter same passphrase again: ')
if p1 == p2:
break
print('Passphrases do not match. Try again.')
options['pass'] = p1
comment = '%s@%s' % (getpass.getuser(), socket.gethostname())
filepath.FilePath(options['filename']).setContent(
key.toString(
'openssh', subtype=options.get('private-key-subtype'),
passphrase=options['pass']))
os.chmod(options['filename'], 33152)
filepath.FilePath(options['filename'] + '.pub').setContent(
key.public().toString('openssh', comment=comment))
options = enumrepresentation(options)
print('Your identification has been saved in %s' % (options['filename'],))
print('Your public key has been saved in %s.pub' % (options['filename'],))
print('The key fingerprint in %s is:' % (options['format'],))
print(key.fingerprint(options['format']))
if __name__ == '__main__':
run()

View file

@ -0,0 +1,585 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
#
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
# $Id: conch.py,v 1.65 2004/03/11 00:29:14 z3p Exp $
#""" Implementation module for the `conch` command.
#"""
from __future__ import print_function
from twisted.conch.client import connect, default, options
from twisted.conch.error import ConchError
from twisted.conch.ssh import connection, common
from twisted.conch.ssh import session, forwarding, channel
from twisted.internet import reactor, stdio, task
from twisted.python import log, usage
from twisted.python.compat import ioType, networkString, unicode
import os
import sys
import getpass
import struct
import tty
import fcntl
import signal
class ClientOptions(options.ConchOptions):
synopsis = """Usage: conch [options] host [command]
"""
longdesc = ("conch is a SSHv2 client that allows logging into a remote "
"machine and executing commands.")
optParameters = [['escape', 'e', '~'],
['localforward', 'L', None, 'listen-port:host:port Forward local port to remote address'],
['remoteforward', 'R', None, 'listen-port:host:port Forward remote port to local address'],
]
optFlags = [['null', 'n', 'Redirect input from /dev/null.'],
['fork', 'f', 'Fork to background after authentication.'],
['tty', 't', 'Tty; allocate a tty even if command is given.'],
['notty', 'T', 'Do not allocate a tty.'],
['noshell', 'N', 'Do not execute a shell or command.'],
['subsystem', 's', 'Invoke command (mandatory) as SSH2 subsystem.'],
]
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port")},
extraActions=[usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True)]
)
localForwards = []
remoteForwards = []
def opt_escape(self, esc):
"""
Set escape character; ``none'' = disable
"""
if esc == 'none':
self['escape'] = None
elif esc[0] == '^' and len(esc) == 2:
self['escape'] = chr(ord(esc[1])-64)
elif len(esc) == 1:
self['escape'] = esc
else:
sys.exit("Bad escape character '{}'.".format(esc))
def opt_localforward(self, f):
"""
Forward local port to remote address (lport:host:port)
"""
localPort, remoteHost, remotePort = f.split(':') # Doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
"""
Forward remote port to local address (rport:host:port)
"""
remotePort, connHost, connPort = f.split(':') # Doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def parseArgs(self, host, *command):
self['host'] = host
self['command'] = ' '.join(command)
# Rest of code in "run"
options = None
conn = None
exitStatus = 0
old = None
_inRawMode = 0
_savedRawMode = None
def run():
global options, old
args = sys.argv[1:]
if '-l' in args: # CVS is an idiot
i = args.index('-l')
args = args[i:i+2]+args
del args[i+2:i+4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == '-o' and args[i+1][0] != '-':
args[i:i+2] = [] # Suck on it scp
except ValueError:
pass
options = ClientOptions()
try:
options.parseOptions(args)
except usage.UsageError as u:
print('ERROR: {}'.format(u))
options.opt_help()
sys.exit(1)
if options['log']:
if options['logfile']:
if options['logfile'] == '-':
f = sys.stdout
else:
f = open(options['logfile'], 'a+')
else:
f = sys.stderr
realout = sys.stdout
log.startLogging(f)
sys.stdout = realout
else:
log.discardLogs()
doConnect()
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
except:
old = None
try:
oldUSR1 = signal.signal(signal.SIGUSR1, lambda *a: reactor.callLater(0, reConnect))
except:
oldUSR1 = None
try:
reactor.run()
finally:
if old:
tty.tcsetattr(fd, tty.TCSANOW, old)
if oldUSR1:
signal.signal(signal.SIGUSR1, oldUSR1)
if (options['command'] and options['tty']) or not options['notty']:
signal.signal(signal.SIGWINCH, signal.SIG_DFL)
if sys.stdout.isatty() and not options['command']:
print('Connection to {} closed.'.format(options['host']))
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
reactor.callLater(0.01, _stopReactor)
log.err(failure.Failure())
raise
def _stopReactor():
try:
reactor.stop()
except: pass
def doConnect():
if '@' in options['host']:
options['user'], options['host'] = options['host'].split('@', 1)
if not options.identitys:
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
host = options['host']
if not options['user']:
options['user'] = getpass.getuser()
if not options['port']:
options['port'] = 22
else:
options['port'] = int(options['port'])
host = options['host']
port = options['port']
vhk = default.verifyHostKey
if not options['host-key-algorithms']:
options['host-key-algorithms'] = default.getHostKeyAlgorithms(
host, options)
uao = default.SSHUserAuthClient(options['user'], options, SSHConnection())
connect.connect(host, port, options, vhk, uao).addErrback(_ebExit)
def _ebExit(f):
global exitStatus
exitStatus = "conch: exiting with error {}".format(f)
reactor.callLater(0.1, _stopReactor)
def onConnect():
# if keyAgent and options['agent']:
# cc = protocol.ClientCreator(reactor, SSHAgentForwardingLocal, conn)
# cc.connectUNIX(os.environ['SSH_AUTH_SOCK'])
if hasattr(conn.transport, 'sendIgnore'):
_KeepAlive(conn)
if options.localForwards:
for localPort, hostport in options.localForwards:
s = reactor.listenTCP(localPort,
forwarding.SSHListenForwardingFactory(conn,
hostport,
SSHListenClientForwardingChannel))
conn.localForwards.append(s)
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg('asking for remote forwarding for {}:{}'.format(
remotePort, hostport))
conn.requestRemoteForwarding(remotePort, hostport)
reactor.addSystemEventTrigger('before', 'shutdown', beforeShutdown)
if not options['noshell'] or options['agent']:
conn.openChannel(SSHSession())
if options['fork']:
if os.fork():
os._exit(0)
os.setsid()
for i in range(3):
try:
os.close(i)
except OSError as e:
import errno
if e.errno != errno.EBADF:
raise
def reConnect():
beforeShutdown()
conn.transport.transport.loseConnection()
def beforeShutdown():
remoteForwards = options.remoteForwards
for remotePort, hostport in remoteForwards:
log.msg('cancelling {}:{}'.format(remotePort, hostport))
conn.cancelRemoteForwarding(remotePort)
def stopConnection():
if not options['reconnect']:
reactor.callLater(0.1, _stopReactor)
class _KeepAlive:
def __init__(self, conn):
self.conn = conn
self.globalTimeout = None
self.lc = task.LoopingCall(self.sendGlobal)
self.lc.start(300)
def sendGlobal(self):
d = self.conn.sendGlobalRequest(b"conch-keep-alive@twistedmatrix.com",
b"", wantReply=1)
d.addBoth(self._cbGlobal)
self.globalTimeout = reactor.callLater(30, self._ebGlobal)
def _cbGlobal(self, res):
if self.globalTimeout:
self.globalTimeout.cancel()
self.globalTimeout = None
def _ebGlobal(self):
if self.globalTimeout:
self.globalTimeout = None
self.conn.transport.loseConnection()
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
global conn
conn = self
self.localForwards = []
self.remoteForwards = {}
if not isinstance(self, connection.SSHConnection):
# make these fall through
del self.__class__.requestRemoteForwarding
del self.__class__.cancelRemoteForwarding
onConnect()
def serviceStopped(self):
lf = self.localForwards
self.localForwards = []
for s in lf:
s.loseConnection()
stopConnection()
def requestRemoteForwarding(self, remotePort, hostport):
data = forwarding.packGlobal_tcpip_forward(('0.0.0.0', remotePort))
d = self.sendGlobalRequest(b'tcpip-forward', data,
wantReply=1)
log.msg('requesting remote forwarding {}:{}'.format(
remotePort, hostport))
d.addCallback(self._cbRemoteForwarding, remotePort, hostport)
d.addErrback(self._ebRemoteForwarding, remotePort, hostport)
def _cbRemoteForwarding(self, result, remotePort, hostport):
log.msg('accepted remote forwarding {}:{}'.format(
remotePort, hostport))
self.remoteForwards[remotePort] = hostport
log.msg(repr(self.remoteForwards))
def _ebRemoteForwarding(self, f, remotePort, hostport):
log.msg('remote forwarding {}:{} failed'.format(
remotePort, hostport))
log.msg(f)
def cancelRemoteForwarding(self, remotePort):
data = forwarding.packGlobal_tcpip_forward(('0.0.0.0', remotePort))
self.sendGlobalRequest(b'cancel-tcpip-forward', data)
log.msg('cancelling remote forwarding {}'.format(remotePort))
try:
del self.remoteForwards[remotePort]
except Exception:
pass
log.msg(repr(self.remoteForwards))
def channel_forwarded_tcpip(self, windowSize, maxPacket, data):
log.msg('FTCP {!r}'.format(data))
remoteHP, origHP = forwarding.unpackOpen_forwarded_tcpip(data)
log.msg(self.remoteForwards)
log.msg(remoteHP)
if remoteHP[1] in self.remoteForwards:
connectHP = self.remoteForwards[remoteHP[1]]
log.msg('connect forwarding {}'.format(connectHP))
return SSHConnectForwardingChannel(connectHP,
remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
conn=self)
else:
raise ConchError(connection.OPEN_CONNECT_FAILED,
"don't know about that port")
def channelClosed(self, channel):
log.msg('connection closing {}'.format(channel))
log.msg(self.channels)
if len(self.channels) == 1: # Just us left
log.msg('stopping connection')
stopConnection()
else:
# Because of the unix thing
self.__class__.__bases__[0].channelClosed(self, channel)
class SSHSession(channel.SSHChannel):
name = b'session'
def channelOpen(self, foo):
log.msg('session {} open'.format(self.id))
if options['agent']:
d = self.conn.sendRequest(self, b'auth-agent-req@openssh.com',
b'', wantReply=1)
d.addBoth(lambda x: log.msg(x))
if options['noshell']:
return
if (options['command'] and options['tty']) or not options['notty']:
_enterRawMode()
c = session.SSHSessionClient()
if options['escape'] and not options['notty']:
self.escapeMode = 1
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = lambda x: self.sendEOF()
self.stdio = stdio.StandardIO(c)
fd = 0
if options['subsystem']:
self.conn.sendRequest(self, b'subsystem',
common.NS(options['command']))
elif options['command']:
if options['tty']:
term = os.environ['TERM']
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, b'pty-req', ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, b'exec', common.NS(options['command']))
else:
if not options['notty']:
term = os.environ['TERM']
winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, b'pty-req', ptyReqData)
signal.signal(signal.SIGWINCH, self._windowResized)
self.conn.sendRequest(self, b'shell', b'')
#if hasattr(conn.transport, 'transport'):
# conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
if char in (b'\n', b'\r'):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options['escape']:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # So we can chain escapes together
if char == b'.': # Disconnect
log.msg('disconnecting from escape')
stopConnection()
return
elif char == b'\x1a': # ^Z, suspend
def _():
_leaveRawMode()
sys.stdout.flush()
sys.stdin.flush()
os.kill(os.getpid(), signal.SIGTSTP)
_enterRawMode()
reactor.callLater(0, _)
return
elif char == b'R': # Rekey connection
log.msg('rekeying connection')
self.conn.transport.sendKexInit()
return
elif char == b'#': # Display connections
self.stdio.write(
b'\r\nThe following connections are open:\r\n')
channels = self.conn.channels.keys()
channels.sort()
for channelId in channels:
self.stdio.write(networkString(' #{} {}\r\n'.format(
channelId,
self.conn.channels[channelId])))
return
self.write(b'~' + char)
else:
self.escapeMode = 0
self.write(char)
def dataReceived(self, data):
self.stdio.write(data)
def extReceived(self, t, data):
if t == connection.EXTENDED_DATA_STDERR:
log.msg('got {} stderr data'.format(len(data)))
if ioType(sys.stderr) == unicode:
sys.stderr.buffer.write(data)
else:
sys.stderr.write(data)
def eofReceived(self):
log.msg('got eof')
self.stdio.loseWriteConnection()
def closeReceived(self):
log.msg('remote side closed {}'.format(self))
self.conn.sendClose(self)
def closed(self):
global old
log.msg('closed {}'.format(self))
log.msg(repr(self.conn.channels))
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack('>L', data)[0])
log.msg('exit status: {}'.format(exitStatus))
def sendEOF(self):
self.conn.sendEOF(self)
def stopWriting(self):
self.stdio.pauseProducing()
def startWriting(self):
self.stdio.resumeProducing()
def _windowResized(self, *args):
winsz = fcntl.ioctl(0, tty.TIOCGWINSZ, '12345678')
winSize = struct.unpack('4H', winsz)
newSize = winSize[1], winSize[0], winSize[2], winSize[3]
self.conn.sendRequest(self, b'window-change', struct.pack('!4L', *newSize))
class SSHListenClientForwardingChannel(forwarding.SSHListenClientForwardingChannel): pass
class SSHConnectForwardingChannel(forwarding.SSHConnectForwardingChannel): pass
def _leaveRawMode():
global _inRawMode
if not _inRawMode:
return
fd = sys.stdin.fileno()
tty.tcsetattr(fd, tty.TCSANOW, _savedRawMode)
_inRawMode = 0
def _enterRawMode():
global _inRawMode, _savedRawMode
if _inRawMode:
return
fd = sys.stdin.fileno()
try:
old = tty.tcgetattr(fd)
new = old[:]
except:
log.msg('not a typewriter!')
else:
# iflage
new[0] = new[0] | tty.IGNPAR
new[0] = new[0] & ~(tty.ISTRIP | tty.INLCR | tty.IGNCR | tty.ICRNL |
tty.IXON | tty.IXANY | tty.IXOFF)
if hasattr(tty, 'IUCLC'):
new[0] = new[0] & ~tty.IUCLC
# lflag
new[3] = new[3] & ~(tty.ISIG | tty.ICANON | tty.ECHO | tty.ECHO |
tty.ECHOE | tty.ECHOK | tty.ECHONL)
if hasattr(tty, 'IEXTEN'):
new[3] = new[3] & ~tty.IEXTEN
#oflag
new[1] = new[1] & ~tty.OPOST
new[6][tty.VMIN] = 1
new[6][tty.VTIME] = 0
_savedRawMode = old
tty.tcsetattr(fd, tty.TCSANOW, new)
#tty.setraw(fd)
_inRawMode = 1
if __name__ == '__main__':
run()

View file

@ -0,0 +1,586 @@
# -*- test-case-name: twisted.conch.test.test_scripts -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation module for the `tkconch` command.
"""
from __future__ import print_function
from twisted.conch import error
from twisted.conch.ui import tkvt100
from twisted.conch.ssh import transport, userauth, connection, common, keys
from twisted.conch.ssh import session, forwarding, channel
from twisted.conch.client.default import isInKnownHosts
from twisted.internet import reactor, defer, protocol, tksupport
from twisted.python import usage, log
from twisted.python.compat import _PY3
import os, sys, getpass, struct, base64, signal
if _PY3:
import tkinter as Tkinter
import tkinter.filedialog as tkFileDialog
import tkinter.messagebox as tkMessageBox
else:
import Tkinter, tkFileDialog, tkMessageBox
class TkConchMenu(Tkinter.Frame):
def __init__(self, *args, **params):
## Standard heading: initialization
Tkinter.Frame.__init__(self, *args, **params)
self.master.title('TkConch')
self.localRemoteVar = Tkinter.StringVar()
self.localRemoteVar.set('local')
Tkinter.Label(self, anchor='w', justify='left', text='Hostname').grid(column=1, row=1, sticky='w')
self.host = Tkinter.Entry(self)
self.host.grid(column=2, columnspan=2, row=1, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Port').grid(column=1, row=2, sticky='w')
self.port = Tkinter.Entry(self)
self.port.grid(column=2, columnspan=2, row=2, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Username').grid(column=1, row=3, sticky='w')
self.user = Tkinter.Entry(self)
self.user.grid(column=2, columnspan=2, row=3, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Command').grid(column=1, row=4, sticky='w')
self.command = Tkinter.Entry(self)
self.command.grid(column=2, columnspan=2, row=4, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Identity').grid(column=1, row=5, sticky='w')
self.identity = Tkinter.Entry(self)
self.identity.grid(column=2, row=5, sticky='nesw')
Tkinter.Button(self, command=self.getIdentityFile, text='Browse').grid(column=3, row=5, sticky='nesw')
Tkinter.Label(self, text='Port Forwarding').grid(column=1, row=6, sticky='w')
self.forwards = Tkinter.Listbox(self, height=0, width=0)
self.forwards.grid(column=2, columnspan=2, row=6, sticky='nesw')
Tkinter.Button(self, text='Add', command=self.addForward).grid(column=1, row=7)
Tkinter.Button(self, text='Remove', command=self.removeForward).grid(column=1, row=8)
self.forwardPort = Tkinter.Entry(self)
self.forwardPort.grid(column=2, row=7, sticky='nesw')
Tkinter.Label(self, text='Port').grid(column=3, row=7, sticky='nesw')
self.forwardHost = Tkinter.Entry(self)
self.forwardHost.grid(column=2, row=8, sticky='nesw')
Tkinter.Label(self, text='Host').grid(column=3, row=8, sticky='nesw')
self.localForward = Tkinter.Radiobutton(self, text='Local', variable=self.localRemoteVar, value='local')
self.localForward.grid(column=2, row=9)
self.remoteForward = Tkinter.Radiobutton(self, text='Remote', variable=self.localRemoteVar, value='remote')
self.remoteForward.grid(column=3, row=9)
Tkinter.Label(self, text='Advanced Options').grid(column=1, columnspan=3, row=10, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Cipher').grid(column=1, row=11, sticky='w')
self.cipher = Tkinter.Entry(self, name='cipher')
self.cipher.grid(column=2, columnspan=2, row=11, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='MAC').grid(column=1, row=12, sticky='w')
self.mac = Tkinter.Entry(self, name='mac')
self.mac.grid(column=2, columnspan=2, row=12, sticky='nesw')
Tkinter.Label(self, anchor='w', justify='left', text='Escape Char').grid(column=1, row=13, sticky='w')
self.escape = Tkinter.Entry(self, name='escape')
self.escape.grid(column=2, columnspan=2, row=13, sticky='nesw')
Tkinter.Button(self, text='Connect!', command=self.doConnect).grid(column=1, columnspan=3, row=14, sticky='nesw')
# Resize behavior(s)
self.grid_rowconfigure(6, weight=1, minsize=64)
self.grid_columnconfigure(2, weight=1, minsize=2)
self.master.protocol("WM_DELETE_WINDOW", sys.exit)
def getIdentityFile(self):
r = tkFileDialog.askopenfilename()
if r:
self.identity.delete(0, Tkinter.END)
self.identity.insert(Tkinter.END, r)
def addForward(self):
port = self.forwardPort.get()
self.forwardPort.delete(0, Tkinter.END)
host = self.forwardHost.get()
self.forwardHost.delete(0, Tkinter.END)
if self.localRemoteVar.get() == 'local':
self.forwards.insert(Tkinter.END, 'L:%s:%s' % (port, host))
else:
self.forwards.insert(Tkinter.END, 'R:%s:%s' % (port, host))
def removeForward(self):
cur = self.forwards.curselection()
if cur:
self.forwards.remove(cur[0])
def doConnect(self):
finished = 1
options['host'] = self.host.get()
options['port'] = self.port.get()
options['user'] = self.user.get()
options['command'] = self.command.get()
cipher = self.cipher.get()
mac = self.mac.get()
escape = self.escape.get()
if cipher:
if cipher in SSHClientTransport.supportedCiphers:
SSHClientTransport.supportedCiphers = [cipher]
else:
tkMessageBox.showerror('TkConch', 'Bad cipher.')
finished = 0
if mac:
if mac in SSHClientTransport.supportedMACs:
SSHClientTransport.supportedMACs = [mac]
elif finished:
tkMessageBox.showerror('TkConch', 'Bad MAC.')
finished = 0
if escape:
if escape == 'none':
options['escape'] = None
elif escape[0] == '^' and len(escape) == 2:
options['escape'] = chr(ord(escape[1])-64)
elif len(escape) == 1:
options['escape'] = escape
elif finished:
tkMessageBox.showerror('TkConch', "Bad escape character '%s'." % escape)
finished = 0
if self.identity.get():
options.identitys.append(self.identity.get())
for line in self.forwards.get(0,Tkinter.END):
if line[0]=='L':
options.opt_localforward(line[2:])
else:
options.opt_remoteforward(line[2:])
if '@' in options['host']:
options['user'], options['host'] = options['host'].split('@',1)
if (not options['host'] or not options['user']) and finished:
tkMessageBox.showerror('TkConch', 'Missing host or username.')
finished = 0
if finished:
self.master.quit()
self.master.destroy()
if options['log']:
realout = sys.stdout
log.startLogging(sys.stderr)
sys.stdout = realout
else:
log.discardLogs()
log.deferr = handleError # HACK
if not options.identitys:
options.identitys = ['~/.ssh/id_rsa', '~/.ssh/id_dsa']
host = options['host']
port = int(options['port'] or 22)
log.msg((host,port))
reactor.connectTCP(host, port, SSHClientFactory())
frame.master.deiconify()
frame.master.title('%s@%s - TkConch' % (options['user'], options['host']))
else:
self.focus()
class GeneralOptions(usage.Options):
synopsis = """Usage: tkconch [options] host [command]
"""
optParameters = [['user', 'l', None, 'Log in using this user name.'],
['identity', 'i', '~/.ssh/identity', 'Identity for public key authentication'],
['escape', 'e', '~', "Set escape character; ``none'' = disable"],
['cipher', 'c', None, 'Select encryption algorithm.'],
['macs', 'm', None, 'Specify MAC algorithms for protocol version 2.'],
['port', 'p', None, 'Connect to this port. Server must be on the same port.'],
['localforward', 'L', None, 'listen-port:host:port Forward local port to remote address'],
['remoteforward', 'R', None, 'listen-port:host:port Forward remote port to local address'],
]
optFlags = [['tty', 't', 'Tty; allocate a tty even if command is given.'],
['notty', 'T', 'Do not allocate a tty.'],
['version', 'V', 'Display version number only.'],
['compress', 'C', 'Enable compression.'],
['noshell', 'N', 'Do not execute a shell or command.'],
['subsystem', 's', 'Invoke command (mandatory) as SSH2 subsystem.'],
['log', 'v', 'Log to stderr'],
['ansilog', 'a', 'Print the received data to stdout']]
_ciphers = transport.SSHClientTransport.supportedCiphers
_macs = transport.SSHClientTransport.supportedMACs
compData = usage.Completions(
mutuallyExclusive=[("tty", "notty")],
optActions={
"cipher": usage.CompleteList(_ciphers),
"macs": usage.CompleteList(_macs),
"localforward": usage.Completer(descr="listen-port:host:port"),
"remoteforward": usage.Completer(descr="listen-port:host:port")},
extraActions=[usage.CompleteUserAtHost(),
usage.Completer(descr="command"),
usage.Completer(descr="argument", repeat=True)]
)
identitys = []
localForwards = []
remoteForwards = []
def opt_identity(self, i):
self.identitys.append(i)
def opt_localforward(self, f):
localPort, remoteHost, remotePort = f.split(':') # doesn't do v6 yet
localPort = int(localPort)
remotePort = int(remotePort)
self.localForwards.append((localPort, (remoteHost, remotePort)))
def opt_remoteforward(self, f):
remotePort, connHost, connPort = f.split(':') # doesn't do v6 yet
remotePort = int(remotePort)
connPort = int(connPort)
self.remoteForwards.append((remotePort, (connHost, connPort)))
def opt_compress(self):
SSHClientTransport.supportedCompressions[0:1] = ['zlib']
def parseArgs(self, *args):
if args:
self['host'] = args[0]
self['command'] = ' '.join(args[1:])
else:
self['host'] = ''
self['command'] = ''
# Rest of code in "run"
options = None
menu = None
exitStatus = 0
frame = None
def deferredAskFrame(question, echo):
if frame.callback:
raise ValueError("can't ask 2 questions at once!")
d = defer.Deferred()
resp = []
def gotChar(ch, resp=resp):
if not ch: return
if ch=='\x03': # C-c
reactor.stop()
if ch=='\r':
frame.write('\r\n')
stresp = ''.join(resp)
del resp
frame.callback = None
d.callback(stresp)
return
elif 32 <= ord(ch) < 127:
resp.append(ch)
if echo:
frame.write(ch)
elif ord(ch) == 8 and resp: # BS
if echo: frame.write('\x08 \x08')
resp.pop()
frame.callback = gotChar
frame.write(question)
frame.canvas.focus_force()
return d
def run():
global menu, options, frame
args = sys.argv[1:]
if '-l' in args: # cvs is an idiot
i = args.index('-l')
args = args[i:i+2]+args
del args[i+2:i+4]
for arg in args[:]:
try:
i = args.index(arg)
if arg[:2] == '-o' and args[i+1][0]!='-':
args[i:i+2] = [] # suck on it scp
except ValueError:
pass
root = Tkinter.Tk()
root.withdraw()
top = Tkinter.Toplevel()
menu = TkConchMenu(top)
menu.pack(side=Tkinter.TOP, fill=Tkinter.BOTH, expand=1)
options = GeneralOptions()
try:
options.parseOptions(args)
except usage.UsageError as u:
print('ERROR: %s' % u)
options.opt_help()
sys.exit(1)
for k,v in options.items():
if v and hasattr(menu, k):
getattr(menu,k).insert(Tkinter.END, v)
for (p, (rh, rp)) in options.localForwards:
menu.forwards.insert(Tkinter.END, 'L:%s:%s:%s' % (p, rh, rp))
options.localForwards = []
for (p, (rh, rp)) in options.remoteForwards:
menu.forwards.insert(Tkinter.END, 'R:%s:%s:%s' % (p, rh, rp))
options.remoteForwards = []
frame = tkvt100.VT100Frame(root, callback=None)
root.geometry('%dx%d'%(tkvt100.fontWidth*frame.width+3, tkvt100.fontHeight*frame.height+3))
frame.pack(side = Tkinter.TOP)
tksupport.install(root)
root.withdraw()
if (options['host'] and options['user']) or '@' in options['host']:
menu.doConnect()
else:
top.mainloop()
reactor.run()
sys.exit(exitStatus)
def handleError():
from twisted.python import failure
global exitStatus
exitStatus = 2
log.err(failure.Failure())
reactor.stop()
raise
class SSHClientFactory(protocol.ClientFactory):
noisy = 1
def stopFactory(self):
reactor.stop()
def buildProtocol(self, addr):
return SSHClientTransport()
def clientConnectionFailed(self, connector, reason):
tkMessageBox.showwarning('TkConch','Connection Failed, Reason:\n %s: %s' % (reason.type, reason.value))
class SSHClientTransport(transport.SSHClientTransport):
def receiveError(self, code, desc):
global exitStatus
exitStatus = 'conch:\tRemote side disconnected with error code %i\nconch:\treason: %s' % (code, desc)
def sendDisconnect(self, code, reason):
global exitStatus
exitStatus = 'conch:\tSending disconnect with error code %i\nconch:\treason: %s' % (code, reason)
transport.SSHClientTransport.sendDisconnect(self, code, reason)
def receiveDebug(self, alwaysDisplay, message, lang):
global options
if alwaysDisplay or options['log']:
log.msg('Received Debug Message: %s' % message)
def verifyHostKey(self, pubKey, fingerprint):
#d = defer.Deferred()
#d.addCallback(lambda x:defer.succeed(1))
#d.callback(2)
#return d
goodKey = isInKnownHosts(options['host'], pubKey, {'known-hosts': None})
if goodKey == 1: # good key
return defer.succeed(1)
elif goodKey == 2: # AAHHHHH changed
return defer.fail(error.ConchError('bad host key'))
else:
if options['host'] == self.transport.getPeer().host:
host = options['host']
khHost = options['host']
else:
host = '%s (%s)' % (options['host'],
self.transport.getPeer().host)
khHost = '%s,%s' % (options['host'],
self.transport.getPeer().host)
keyType = common.getNS(pubKey)[0]
ques = """The authenticity of host '%s' can't be established.\r
%s key fingerprint is %s.""" % (host,
{b'ssh-dss':'DSA', b'ssh-rsa':'RSA'}[keyType],
fingerprint)
ques+='\r\nAre you sure you want to continue connecting (yes/no)? '
return deferredAskFrame(ques, 1).addCallback(self._cbVerifyHostKey, pubKey, khHost, keyType)
def _cbVerifyHostKey(self, ans, pubKey, khHost, keyType):
if ans.lower() not in ('yes', 'no'):
return deferredAskFrame("Please type 'yes' or 'no': ",1).addCallback(self._cbVerifyHostKey, pubKey, khHost, keyType)
if ans.lower() == 'no':
frame.write('Host key verification failed.\r\n')
raise error.ConchError('bad host key')
try:
frame.write(
"Warning: Permanently added '%s' (%s) to the list of "
"known hosts.\r\n" %
(khHost, {b'ssh-dss':'DSA', b'ssh-rsa':'RSA'}[keyType]))
with open(os.path.expanduser('~/.ssh/known_hosts'), 'a') as known_hosts:
encodedKey = base64.encodestring(pubKey).replace(b'\n', b'')
known_hosts.write('\n%s %s %s' % (khHost, keyType, encodedKey))
except:
log.deferr()
raise error.ConchError
def connectionSecure(self):
if options['user']:
user = options['user']
else:
user = getpass.getuser()
self.requestService(SSHUserAuthClient(user, SSHConnection()))
class SSHUserAuthClient(userauth.SSHUserAuthClient):
usedFiles = []
def getPassword(self, prompt = None):
if not prompt:
prompt = "%s@%s's password: " % (self.user, options['host'])
return deferredAskFrame(prompt,0)
def getPublicKey(self):
files = [x for x in options.identitys if x not in self.usedFiles]
if not files:
return None
file = files[0]
log.msg(file)
self.usedFiles.append(file)
file = os.path.expanduser(file)
file += '.pub'
if not os.path.exists(file):
return
try:
return keys.Key.fromFile(file).blob()
except:
return self.getPublicKey() # try again
def getPrivateKey(self):
file = os.path.expanduser(self.usedFiles[-1])
if not os.path.exists(file):
return None
try:
return defer.succeed(keys.Key.fromFile(file).keyObject)
except keys.BadKeyError as e:
if e.args[0] == 'encrypted key with no password':
prompt = "Enter passphrase for key '%s': " % \
self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, 0)
def _cbGetPrivateKey(self, ans, count):
file = os.path.expanduser(self.usedFiles[-1])
try:
return keys.Key.fromFile(file, password = ans).keyObject
except keys.BadKeyError:
if count == 2:
raise
prompt = "Enter passphrase for key '%s': " % \
self.usedFiles[-1]
return deferredAskFrame(prompt, 0).addCallback(self._cbGetPrivateKey, count+1)
class SSHConnection(connection.SSHConnection):
def serviceStarted(self):
if not options['noshell']:
self.openChannel(SSHSession())
if options.localForwards:
for localPort, hostport in options.localForwards:
reactor.listenTCP(localPort,
forwarding.SSHListenForwardingFactory(self,
hostport,
forwarding.SSHListenClientForwardingChannel))
if options.remoteForwards:
for remotePort, hostport in options.remoteForwards:
log.msg('asking for remote forwarding for %s:%s' %
(remotePort, hostport))
data = forwarding.packGlobal_tcpip_forward(
('0.0.0.0', remotePort))
self.sendGlobalRequest('tcpip-forward', data)
self.remoteForwards[remotePort] = hostport
class SSHSession(channel.SSHChannel):
name = b'session'
def channelOpen(self, foo):
#global globalSession
#globalSession = self
# turn off local echo
self.escapeMode = 1
c = session.SSHSessionClient()
if options['escape']:
c.dataReceived = self.handleInput
else:
c.dataReceived = self.write
c.connectionLost = self.sendEOF
frame.callback = c.dataReceived
frame.canvas.focus_force()
if options['subsystem']:
self.conn.sendRequest(self, b'subsystem', \
common.NS(options['command']))
elif options['command']:
if options['tty']:
term = os.environ.get('TERM', 'xterm')
#winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25,80,0,0) #struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, b'pty-req', ptyReqData)
self.conn.sendRequest(self, 'exec', \
common.NS(options['command']))
else:
if not options['notty']:
term = os.environ.get('TERM', 'xterm')
#winsz = fcntl.ioctl(fd, tty.TIOCGWINSZ, '12345678')
winSize = (25,80,0,0) #struct.unpack('4H', winsz)
ptyReqData = session.packRequest_pty_req(term, winSize, '')
self.conn.sendRequest(self, b'pty-req', ptyReqData)
self.conn.sendRequest(self, b'shell', b'')
self.conn.transport.transport.setTcpNoDelay(1)
def handleInput(self, char):
#log.msg('handling %s' % repr(char))
if char in ('\n', '\r'):
self.escapeMode = 1
self.write(char)
elif self.escapeMode == 1 and char == options['escape']:
self.escapeMode = 2
elif self.escapeMode == 2:
self.escapeMode = 1 # so we can chain escapes together
if char == '.': # disconnect
log.msg('disconnecting from escape')
reactor.stop()
return
elif char == '\x1a': # ^Z, suspend
# following line courtesy of Erwin@freenode
os.kill(os.getpid(), signal.SIGSTOP)
return
elif char == 'R': # rekey connection
log.msg('rekeying connection')
self.conn.transport.sendKexInit()
return
self.write('~' + char)
else:
self.escapeMode = 0
self.write(char)
def dataReceived(self, data):
if _PY3 and isinstance(data, bytes):
data = data.decode("utf-8")
if options['ansilog']:
print(repr(data))
frame.write(data)
def extReceived(self, t, data):
if t==connection.EXTENDED_DATA_STDERR:
log.msg('got %s stderr data' % len(data))
sys.stderr.write(data)
sys.stderr.flush()
def eofReceived(self):
log.msg('got eof')
sys.stdin.close()
def closed(self):
log.msg('closed %s' % self)
if len(self.conn.channels) == 1: # just us left
reactor.stop()
def request_exit_status(self, data):
global exitStatus
exitStatus = int(struct.unpack('>L', data)[0])
log.msg('exit status: %s' % exitStatus)
def sendEOF(self):
self.conn.sendEOF(self)
if __name__=="__main__":
run()

View file

@ -0,0 +1,10 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
#
"""
An SSHv2 implementation for Twisted. Part of the Twisted.Conch package.
Maintainer: Paul Swartz
"""

View file

@ -0,0 +1,294 @@
# -*- test-case-name: twisted.conch.test.test_transport -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
SSH key exchange handling.
"""
from __future__ import absolute_import, division
from hashlib import sha1, sha256, sha384, sha512
from zope.interface import Attribute, implementer, Interface
from twisted.conch import error
from twisted.python.compat import long
class _IKexAlgorithm(Interface):
"""
An L{_IKexAlgorithm} describes a key exchange algorithm.
"""
preference = Attribute(
"An L{int} giving the preference of the algorithm when negotiating "
"key exchange. Algorithms with lower precedence values are more "
"preferred.")
hashProcessor = Attribute(
"A callable hash algorithm constructor (e.g. C{hashlib.sha256}) "
"suitable for use with this key exchange algorithm.")
class _IFixedGroupKexAlgorithm(_IKexAlgorithm):
"""
An L{_IFixedGroupKexAlgorithm} describes a key exchange algorithm with a
fixed prime / generator group.
"""
prime = Attribute(
"A L{long} giving the prime number used in Diffie-Hellman key "
"exchange, or L{None} if not applicable.")
generator = Attribute(
"A L{long} giving the generator number used in Diffie-Hellman key "
"exchange, or L{None} if not applicable. (This is not related to "
"Python generator functions.)")
class _IEllipticCurveExchangeKexAlgorithm(_IKexAlgorithm):
"""
An L{_IEllipticCurveExchangeKexAlgorithm} describes a key exchange algorithm
that uses an elliptic curve exchange between the client and server.
"""
class _IGroupExchangeKexAlgorithm(_IKexAlgorithm):
"""
An L{_IGroupExchangeKexAlgorithm} describes a key exchange algorithm
that uses group exchange between the client and server.
A prime / generator group should be chosen at run time based on the
requested size. See RFC 4419.
"""
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _Curve25519SHA256(object):
"""
Elliptic Curve Key Exchange using Curve25519 and SHA256. Defined in
U{https://datatracker.ietf.org/doc/draft-ietf-curdle-ssh-curves/}.
"""
preference = 1
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _Curve25519SHA256LibSSH(object):
"""
As L{_Curve25519SHA256}, but with a pre-standardized algorithm name.
"""
preference = 2
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH256(object):
"""
Elliptic Curve Key Exchange with SHA-256 as HASH. Defined in
RFC 5656.
"""
preference = 3
hashProcessor = sha256
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH384(object):
"""
Elliptic Curve Key Exchange with SHA-384 as HASH. Defined in
RFC 5656.
"""
preference = 4
hashProcessor = sha384
@implementer(_IEllipticCurveExchangeKexAlgorithm)
class _ECDH512(object):
"""
Elliptic Curve Key Exchange with SHA-512 as HASH. Defined in
RFC 5656.
"""
preference = 5
hashProcessor = sha512
@implementer(_IGroupExchangeKexAlgorithm)
class _DHGroupExchangeSHA256(object):
"""
Diffie-Hellman Group and Key Exchange with SHA-256 as HASH. Defined in
RFC 4419, 4.2.
"""
preference = 6
hashProcessor = sha256
@implementer(_IGroupExchangeKexAlgorithm)
class _DHGroupExchangeSHA1(object):
"""
Diffie-Hellman Group and Key Exchange with SHA-1 as HASH. Defined in
RFC 4419, 4.1.
"""
preference = 7
hashProcessor = sha1
@implementer(_IFixedGroupKexAlgorithm)
class _DHGroup14SHA1(object):
"""
Diffie-Hellman key exchange with SHA-1 as HASH and Oakley Group 14
(2048-bit MODP Group). Defined in RFC 4253, 8.2.
"""
preference = 8
hashProcessor = sha1
# Diffie-Hellman primes from Oakley Group 14 (RFC 3526, 3).
prime = long('32317006071311007300338913926423828248817941241140239112842'
'00975140074170663435422261968941736356934711790173790970419175460587'
'32091950288537589861856221532121754125149017745202702357960782362488'
'84246189477587641105928646099411723245426622522193230540919037680524'
'23551912567971587011700105805587765103886184728025797605490356973256'
'15261670813393617995413364765591603683178967290731783845896806396719'
'00977202194168647225871031411336429319536193471636533209717077448227'
'98858856536920864529663607725026895550592836275112117409697299806841'
'05543595848665832916421362182310789909994486524682624169720359118525'
'07045361090559')
generator = 2
# Which ECDH hash function to use is dependent on the size.
_kexAlgorithms = {
b"curve25519-sha256": _Curve25519SHA256(),
b"curve25519-sha256@libssh.org": _Curve25519SHA256LibSSH(),
b"diffie-hellman-group-exchange-sha256": _DHGroupExchangeSHA256(),
b"diffie-hellman-group-exchange-sha1": _DHGroupExchangeSHA1(),
b"diffie-hellman-group14-sha1": _DHGroup14SHA1(),
b"ecdh-sha2-nistp256": _ECDH256(),
b"ecdh-sha2-nistp384": _ECDH384(),
b"ecdh-sha2-nistp521": _ECDH512(),
}
def getKex(kexAlgorithm):
"""
Get a description of a named key exchange algorithm.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A description of the key exchange algorithm named by
C{kexAlgorithm}.
@rtype: L{_IKexAlgorithm}
@raises ConchError: if the key exchange algorithm is not found.
"""
if kexAlgorithm not in _kexAlgorithms:
raise error.ConchError(
"Unsupported key exchange algorithm: %s" % (kexAlgorithm,))
return _kexAlgorithms[kexAlgorithm]
def isEllipticCurve(kexAlgorithm):
"""
Returns C{True} if C{kexAlgorithm} is an elliptic curve.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: C{str}
@return: C{True} if C{kexAlgorithm} is an elliptic curve,
otherwise C{False}.
@rtype: C{bool}
"""
return _IEllipticCurveExchangeKexAlgorithm.providedBy(getKex(kexAlgorithm))
def isFixedGroup(kexAlgorithm):
"""
Returns C{True} if C{kexAlgorithm} has a fixed prime / generator group.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: C{True} if C{kexAlgorithm} has a fixed prime / generator group,
otherwise C{False}.
@rtype: L{bool}
"""
return _IFixedGroupKexAlgorithm.providedBy(getKex(kexAlgorithm))
def getHashProcessor(kexAlgorithm):
"""
Get the hash algorithm callable to use in key exchange.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A callable hash algorithm constructor (e.g. C{hashlib.sha256}).
@rtype: C{callable}
"""
kex = getKex(kexAlgorithm)
return kex.hashProcessor
def getDHGeneratorAndPrime(kexAlgorithm):
"""
Get the generator and the prime to use in key exchange.
@param kexAlgorithm: The key exchange algorithm name.
@type kexAlgorithm: L{bytes}
@return: A L{tuple} containing L{long} generator and L{long} prime.
@rtype: L{tuple}
"""
kex = getKex(kexAlgorithm)
return kex.generator, kex.prime
def getSupportedKeyExchanges():
"""
Get a list of supported key exchange algorithm names in order of
preference.
@return: A C{list} of supported key exchange algorithm names.
@rtype: C{list} of L{bytes}
"""
from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives.asymmetric import ec
from twisted.conch.ssh.keys import _curveTable
backend = default_backend()
kexAlgorithms = _kexAlgorithms.copy()
for keyAlgorithm in list(kexAlgorithms):
if keyAlgorithm.startswith(b"ecdh"):
keyAlgorithmDsa = keyAlgorithm.replace(b"ecdh", b"ecdsa")
supported = backend.elliptic_curve_exchange_algorithm_supported(
ec.ECDH(), _curveTable[keyAlgorithmDsa])
elif keyAlgorithm.startswith(b"curve25519-sha256"):
supported = backend.x25519_supported()
else:
supported = True
if not supported:
kexAlgorithms.pop(keyAlgorithm)
return sorted(
kexAlgorithms,
key=lambda kexAlgorithm: kexAlgorithms[kexAlgorithm].preference)

View file

@ -0,0 +1,47 @@
# -*- test-case-name: twisted.conch.test.test_address -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Address object for SSH network connections.
Maintainer: Paul Swartz
@since: 12.1
"""
from __future__ import division, absolute_import
from zope.interface import implementer
from twisted.internet.interfaces import IAddress
from twisted.python import util
@implementer(IAddress)
class SSHTransportAddress(util.FancyEqMixin, object):
"""
Object representing an SSH Transport endpoint.
This is used to ensure that any code inspecting this address and
attempting to construct a similar connection based upon it is not
mislead into creating a transport which is not similar to the one it is
indicating.
@ivar address: An instance of an object which implements I{IAddress} to
which this transport address is connected.
"""
compareAttributes = ('address',)
def __init__(self, address):
self.address = address
def __repr__(self):
return 'SSHTransportAddress(%r)' % (self.address,)
def __hash__(self):
return hash(('SSH', self.address))

View file

@ -0,0 +1,296 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implements the SSH v2 key agent protocol. This protocol is documented in the
SSH source code, in the file
U{PROTOCOL.agent<http://www.openbsd.org/cgi-bin/cvsweb/src/usr.bin/ssh/PROTOCOL.agent>}.
Maintainer: Paul Swartz
"""
from __future__ import absolute_import, division
import struct
from twisted.conch.ssh.common import NS, getNS, getMP
from twisted.conch.error import ConchError, MissingKeyStoreError
from twisted.conch.ssh import keys
from twisted.internet import defer, protocol
from twisted.python.compat import itervalues
class SSHAgentClient(protocol.Protocol):
"""
The client side of the SSH agent protocol. This is equivalent to
ssh-add(1) and can be used with either ssh-agent(1) or the SSHAgentServer
protocol, also in this package.
"""
def __init__(self):
self.buf = b''
self.deferreds = []
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack('!L', self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
reqType = ord(packet[0:1])
d = self.deferreds.pop(0)
if reqType == AGENT_FAILURE:
d.errback(ConchError('agent failure'))
elif reqType == AGENT_SUCCESS:
d.callback(b'')
else:
d.callback(packet)
def sendRequest(self, reqType, data):
pack = struct.pack('!LB',len(data) + 1, reqType) + data
self.transport.write(pack)
d = defer.Deferred()
self.deferreds.append(d)
return d
def requestIdentities(self):
"""
@return: A L{Deferred} which will fire with a list of all keys found in
the SSH agent. The list of keys is comprised of (public key blob,
comment) tuples.
"""
d = self.sendRequest(AGENTC_REQUEST_IDENTITIES, b'')
d.addCallback(self._cbRequestIdentities)
return d
def _cbRequestIdentities(self, data):
"""
Unpack a collection of identities into a list of tuples comprised of
public key blobs and comments.
"""
if ord(data[0:1]) != AGENT_IDENTITIES_ANSWER:
raise ConchError('unexpected response: %i' % ord(data[0:1]))
numKeys = struct.unpack('!L', data[1:5])[0]
result = []
data = data[5:]
for i in range(numKeys):
blob, data = getNS(data)
comment, data = getNS(data)
result.append((blob, comment))
return result
def addIdentity(self, blob, comment = b''):
"""
Add a private key blob to the agent's collection of keys.
"""
req = blob
req += NS(comment)
return self.sendRequest(AGENTC_ADD_IDENTITY, req)
def signData(self, blob, data):
"""
Request that the agent sign the given C{data} with the private key
which corresponds to the public key given by C{blob}. The private
key should have been added to the agent already.
@type blob: L{bytes}
@type data: L{bytes}
@return: A L{Deferred} which fires with a signature for given data
created with the given key.
"""
req = NS(blob)
req += NS(data)
req += b'\000\000\000\000' # flags
return self.sendRequest(AGENTC_SIGN_REQUEST, req).addCallback(self._cbSignData)
def _cbSignData(self, data):
if ord(data[0:1]) != AGENT_SIGN_RESPONSE:
raise ConchError('unexpected data: %i' % ord(data[0:1]))
signature = getNS(data[1:])[0]
return signature
def removeIdentity(self, blob):
"""
Remove the private key corresponding to the public key in blob from the
running agent.
"""
req = NS(blob)
return self.sendRequest(AGENTC_REMOVE_IDENTITY, req)
def removeAllIdentities(self):
"""
Remove all keys from the running agent.
"""
return self.sendRequest(AGENTC_REMOVE_ALL_IDENTITIES, b'')
class SSHAgentServer(protocol.Protocol):
"""
The server side of the SSH agent protocol. This is equivalent to
ssh-agent(1) and can be used with either ssh-add(1) or the SSHAgentClient
protocol, also in this package.
"""
def __init__(self):
self.buf = b''
def dataReceived(self, data):
self.buf += data
while 1:
if len(self.buf) <= 4:
return
packLen = struct.unpack('!L', self.buf[:4])[0]
if len(self.buf) < 4 + packLen:
return
packet, self.buf = self.buf[4:4 + packLen], self.buf[4 + packLen:]
reqType = ord(packet[0:1])
reqName = messages.get(reqType, None)
if not reqName:
self.sendResponse(AGENT_FAILURE, b'')
else:
f = getattr(self, 'agentc_%s' % reqName)
if getattr(self.factory, 'keys', None) is None:
self.sendResponse(AGENT_FAILURE, b'')
raise MissingKeyStoreError()
f(packet[1:])
def sendResponse(self, reqType, data):
pack = struct.pack('!LB', len(data) + 1, reqType) + data
self.transport.write(pack)
def agentc_REQUEST_IDENTITIES(self, data):
"""
Return all of the identities that have been added to the server
"""
assert data == b''
numKeys = len(self.factory.keys)
resp = []
resp.append(struct.pack('!L', numKeys))
for key, comment in itervalues(self.factory.keys):
resp.append(NS(key.blob())) # yes, wrapped in an NS
resp.append(NS(comment))
self.sendResponse(AGENT_IDENTITIES_ANSWER, b''.join(resp))
def agentc_SIGN_REQUEST(self, data):
"""
Data is a structure with a reference to an already added key object and
some data that the clients wants signed with that key. If the key
object wasn't loaded, return AGENT_FAILURE, else return the signature.
"""
blob, data = getNS(data)
if blob not in self.factory.keys:
return self.sendResponse(AGENT_FAILURE, b'')
signData, data = getNS(data)
assert data == b'\000\000\000\000'
self.sendResponse(AGENT_SIGN_RESPONSE, NS(self.factory.keys[blob][0].sign(signData)))
def agentc_ADD_IDENTITY(self, data):
"""
Adds a private key to the agent's collection of identities. On
subsequent interactions, the private key can be accessed using only the
corresponding public key.
"""
# need to pre-read the key data so we can get past it to the comment string
keyType, rest = getNS(data)
if keyType == b'ssh-rsa':
nmp = 6
elif keyType == b'ssh-dss':
nmp = 5
else:
raise keys.BadKeyError('unknown blob type: %s' % keyType)
rest = getMP(rest, nmp)[-1] # ignore the key data for now, we just want the comment
comment, rest = getNS(rest) # the comment, tacked onto the end of the key blob
k = keys.Key.fromString(data, type='private_blob') # not wrapped in NS here
self.factory.keys[k.blob()] = (k, comment)
self.sendResponse(AGENT_SUCCESS, b'')
def agentc_REMOVE_IDENTITY(self, data):
"""
Remove a specific key from the agent's collection of identities.
"""
blob, _ = getNS(data)
k = keys.Key.fromString(blob, type='blob')
del self.factory.keys[k.blob()]
self.sendResponse(AGENT_SUCCESS, b'')
def agentc_REMOVE_ALL_IDENTITIES(self, data):
"""
Remove all keys from the agent's collection of identities.
"""
assert data == b''
self.factory.keys = {}
self.sendResponse(AGENT_SUCCESS, b'')
# v1 messages that we ignore because we don't keep v1 keys
# open-ssh sends both v1 and v2 commands, so we have to
# do no-ops for v1 commands or we'll get "bad request" errors
def agentc_REQUEST_RSA_IDENTITIES(self, data):
"""
v1 message for listing RSA1 keys; superseded by
agentc_REQUEST_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_RSA_IDENTITIES_ANSWER, struct.pack('!L', 0))
def agentc_REMOVE_RSA_IDENTITY(self, data):
"""
v1 message for removing RSA1 keys; superseded by
agentc_REMOVE_IDENTITY, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, b'')
def agentc_REMOVE_ALL_RSA_IDENTITIES(self, data):
"""
v1 message for removing all RSA1 keys; superseded by
agentc_REMOVE_ALL_IDENTITIES, which handles different key types.
"""
self.sendResponse(AGENT_SUCCESS, b'')
AGENTC_REQUEST_RSA_IDENTITIES = 1
AGENT_RSA_IDENTITIES_ANSWER = 2
AGENT_FAILURE = 5
AGENT_SUCCESS = 6
AGENTC_REMOVE_RSA_IDENTITY = 8
AGENTC_REMOVE_ALL_RSA_IDENTITIES = 9
AGENTC_REQUEST_IDENTITIES = 11
AGENT_IDENTITIES_ANSWER = 12
AGENTC_SIGN_REQUEST = 13
AGENT_SIGN_RESPONSE = 14
AGENTC_ADD_IDENTITY = 17
AGENTC_REMOVE_IDENTITY = 18
AGENTC_REMOVE_ALL_IDENTITIES = 19
messages = {}
for name, value in locals().copy().items():
if name[:7] == 'AGENTC_':
messages[value] = name[7:] # doesn't handle doubles

View file

@ -0,0 +1,320 @@
# -*- test-case-name: twisted.conch.test.test_channel -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The parent class for all the SSH Channels. Currently implemented channels
are session, direct-tcp, and forwarded-tcp.
Maintainer: Paul Swartz
"""
from __future__ import division, absolute_import
from zope.interface import implementer
from twisted.python import log
from twisted.python.compat import nativeString, intToBytes
from twisted.internet import interfaces
@implementer(interfaces.ITransport)
class SSHChannel(log.Logger):
"""
A class that represents a multiplexed channel over an SSH connection.
The channel has a local window which is the maximum amount of data it will
receive, and a remote which is the maximum amount of data the remote side
will accept. There is also a maximum packet size for any individual data
packet going each way.
@ivar name: the name of the channel.
@type name: L{bytes}
@ivar localWindowSize: the maximum size of the local window in bytes.
@type localWindowSize: L{int}
@ivar localWindowLeft: how many bytes are left in the local window.
@type localWindowLeft: L{int}
@ivar localMaxPacket: the maximum size of packet we will accept in bytes.
@type localMaxPacket: L{int}
@ivar remoteWindowLeft: how many bytes are left in the remote window.
@type remoteWindowLeft: L{int}
@ivar remoteMaxPacket: the maximum size of a packet the remote side will
accept in bytes.
@type remoteMaxPacket: L{int}
@ivar conn: the connection this channel is multiplexed through.
@type conn: L{SSHConnection}
@ivar data: any data to send to the other side when the channel is
requested.
@type data: L{bytes}
@ivar avatar: an avatar for the logged-in user (if a server channel)
@ivar localClosed: True if we aren't accepting more data.
@type localClosed: L{bool}
@ivar remoteClosed: True if the other side isn't accepting more data.
@type remoteClosed: L{bool}
"""
name = None # only needed for client channels
def __init__(self, localWindow = 0, localMaxPacket = 0,
remoteWindow = 0, remoteMaxPacket = 0,
conn = None, data=None, avatar = None):
self.localWindowSize = localWindow or 131072
self.localWindowLeft = self.localWindowSize
self.localMaxPacket = localMaxPacket or 32768
self.remoteWindowLeft = remoteWindow
self.remoteMaxPacket = remoteMaxPacket
self.areWriting = 1
self.conn = conn
self.data = data
self.avatar = avatar
self.specificData = b''
self.buf = b''
self.extBuf = []
self.closing = 0
self.localClosed = 0
self.remoteClosed = 0
self.id = None # gets set later by SSHConnection
def __str__(self):
return nativeString(self.__bytes__())
def __bytes__(self):
"""
Return a byte string representation of the channel
"""
name = self.name
if not name:
name = b'None'
return (b'<SSHChannel ' + name +
b' (lw ' + intToBytes(self.localWindowLeft) +
b' rw ' + intToBytes(self.remoteWindowLeft) +
b')>')
def logPrefix(self):
id = (self.id is not None and str(self.id)) or "unknown"
name = self.name
if name:
name = nativeString(name)
return "SSHChannel %s (%s) on %s" % (name, id,
self.conn.logPrefix())
def channelOpen(self, specificData):
"""
Called when the channel is opened. specificData is any data that the
other side sent us when opening the channel.
@type specificData: L{bytes}
"""
log.msg('channel open')
def openFailed(self, reason):
"""
Called when the open failed for some reason.
reason.desc is a string descrption, reason.code the SSH error code.
@type reason: L{error.ConchError}
"""
log.msg('other side refused open\nreason: %s'% reason)
def addWindowBytes(self, data):
"""
Called when bytes are added to the remote window. By default it clears
the data buffers.
@type data: L{bytes}
"""
self.remoteWindowLeft = self.remoteWindowLeft+data
if not self.areWriting and not self.closing:
self.areWriting = True
self.startWriting()
if self.buf:
b = self.buf
self.buf = b''
self.write(b)
if self.extBuf:
b = self.extBuf
self.extBuf = []
for (type, data) in b:
self.writeExtended(type, data)
def requestReceived(self, requestType, data):
"""
Called when a request is sent to this channel. By default it delegates
to self.request_<requestType>.
If this function returns true, the request succeeded, otherwise it
failed.
@type requestType: L{bytes}
@type data: L{bytes}
@rtype: L{bool}
"""
foo = nativeString(requestType.replace(b'-', b'_'))
f = getattr(self, 'request_%s'%foo, None)
if f:
return f(data)
log.msg('unhandled request for %s'%requestType)
return 0
def dataReceived(self, data):
"""
Called when we receive data.
@type data: L{bytes}
"""
log.msg('got data %s'%repr(data))
def extReceived(self, dataType, data):
"""
Called when we receive extended data (usually standard error).
@type dataType: L{int}
@type data: L{str}
"""
log.msg('got extended data %s %s'%(dataType, repr(data)))
def eofReceived(self):
"""
Called when the other side will send no more data.
"""
log.msg('remote eof')
def closeReceived(self):
"""
Called when the other side has closed the channel.
"""
log.msg('remote close')
self.loseConnection()
def closed(self):
"""
Called when the channel is closed. This means that both our side and
the remote side have closed the channel.
"""
log.msg('closed')
def write(self, data):
"""
Write some data to the channel. If there is not enough remote window
available, buffer until it is. Otherwise, split the data into
packets of length remoteMaxPacket and send them.
@type data: L{bytes}
"""
if self.buf:
self.buf += data
return
top = len(data)
if top > self.remoteWindowLeft:
data, self.buf = (data[:self.remoteWindowLeft],
data[self.remoteWindowLeft:])
self.areWriting = 0
self.stopWriting()
top = self.remoteWindowLeft
rmp = self.remoteMaxPacket
write = self.conn.sendData
r = range(0, top, rmp)
for offset in r:
write(self, data[offset: offset+rmp])
self.remoteWindowLeft -= top
if self.closing and not self.buf:
self.loseConnection() # try again
def writeExtended(self, dataType, data):
"""
Send extended data to this channel. If there is not enough remote
window available, buffer until there is. Otherwise, split the data
into packets of length remoteMaxPacket and send them.
@type dataType: L{int}
@type data: L{bytes}
"""
if self.extBuf:
if self.extBuf[-1][0] == dataType:
self.extBuf[-1][1] += data
else:
self.extBuf.append([dataType, data])
return
if len(data) > self.remoteWindowLeft:
data, self.extBuf = (data[:self.remoteWindowLeft],
[[dataType, data[self.remoteWindowLeft:]]])
self.areWriting = 0
self.stopWriting()
while len(data) > self.remoteMaxPacket:
self.conn.sendExtendedData(self, dataType,
data[:self.remoteMaxPacket])
data = data[self.remoteMaxPacket:]
self.remoteWindowLeft -= self.remoteMaxPacket
if data:
self.conn.sendExtendedData(self, dataType, data)
self.remoteWindowLeft -= len(data)
if self.closing:
self.loseConnection() # try again
def writeSequence(self, data):
"""
Part of the Transport interface. Write a list of strings to the
channel.
@type data: C{list} of L{str}
"""
self.write(b''.join(data))
def loseConnection(self):
"""
Close the channel if there is no buferred data. Otherwise, note the
request and return.
"""
self.closing = 1
if not self.buf and not self.extBuf:
self.conn.sendClose(self)
def getPeer(self):
"""
See: L{ITransport.getPeer}
@return: The remote address of this connection.
@rtype: L{SSHTransportAddress}.
"""
return self.conn.transport.getPeer()
def getHost(self):
"""
See: L{ITransport.getHost}
@return: An address describing this side of the connection.
@rtype: L{SSHTransportAddress}.
"""
return self.conn.transport.getHost()
def stopWriting(self):
"""
Called when the remote buffer is full, as a hint to stop writing.
This can be ignored, but it can be helpful.
"""
def startWriting(self):
"""
Called when the remote buffer has more room, as a hint to continue
writing.
"""

View file

@ -0,0 +1,93 @@
# -*- test-case-name: twisted.conch.test.test_ssh -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Common functions for the SSH classes.
Maintainer: Paul Swartz
"""
from __future__ import absolute_import, division
import struct
from cryptography.utils import int_from_bytes, int_to_bytes
from twisted.python.compat import unicode
from twisted.python.deprecate import deprecated
from twisted.python.versions import Version
__all__ = ["NS", "getNS", "MP", "getMP", "ffs"]
def NS(t):
"""
net string
"""
if isinstance(t, unicode):
t = t.encode("utf-8")
return struct.pack('!L', len(t)) + t
def getNS(s, count=1):
"""
get net string
"""
ns = []
c = 0
for i in range(count):
l, = struct.unpack('!L', s[c:c + 4])
ns.append(s[c + 4:4 + l + c])
c += 4 + l
return tuple(ns) + (s[c:],)
def MP(number):
if number == 0:
return b'\000' * 4
assert number > 0
bn = int_to_bytes(number)
if ord(bn[0:1]) & 128:
bn = b'\000' + bn
return struct.pack('>L', len(bn)) + bn
def getMP(data, count=1):
"""
Get multiple precision integer out of the string. A multiple precision
integer is stored as a 4-byte length followed by length bytes of the
integer. If count is specified, get count integers out of the string.
The return value is a tuple of count integers followed by the rest of
the data.
"""
mp = []
c = 0
for i in range(count):
length, = struct.unpack('>L', data[c:c + 4])
mp.append(int_from_bytes(data[c + 4:c + 4 + length], 'big'))
c += 4 + length
return tuple(mp) + (data[c:],)
def ffs(c, s):
"""
first from second
goes through the first list, looking for items in the second, returns the first one
"""
for i in c:
if i in s:
return i
@deprecated(Version("Twisted", 16, 5, 0))
def install():
# This used to install gmpy, but is technically public API, so just do
# nothing.
pass

View file

@ -0,0 +1,653 @@
# -*- test-case-name: twisted.conch.test.test_connection -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the ssh-connection service, which
allows access to the shell and port-forwarding.
Maintainer: Paul Swartz
"""
from __future__ import division, absolute_import
import string
import struct
import twisted.internet.error
from twisted.conch.ssh import service, common
from twisted.conch import error
from twisted.internet import defer
from twisted.python import log
from twisted.python.compat import (
nativeString, networkString, long, _bytesChr as chr)
class SSHConnection(service.SSHService):
"""
An implementation of the 'ssh-connection' service. It is used to
multiplex multiple channels over the single SSH connection.
@ivar localChannelID: the next number to use as a local channel ID.
@type localChannelID: L{int}
@ivar channels: a L{dict} mapping a local channel ID to C{SSHChannel}
subclasses.
@type channels: L{dict}
@ivar localToRemoteChannel: a L{dict} mapping a local channel ID to a
remote channel ID.
@type localToRemoteChannel: L{dict}
@ivar channelsToRemoteChannel: a L{dict} mapping a C{SSHChannel} subclass
to remote channel ID.
@type channelsToRemoteChannel: L{dict}
@ivar deferreds: a L{dict} mapping a local channel ID to a C{list} of
C{Deferreds} for outstanding channel requests. Also, the 'global'
key stores the C{list} of pending global request C{Deferred}s.
"""
name = b'ssh-connection'
def __init__(self):
self.localChannelID = 0 # this is the current # to use for channel ID
self.localToRemoteChannel = {} # local channel ID -> remote channel ID
self.channels = {} # local channel ID -> subclass of SSHChannel
self.channelsToRemoteChannel = {} # subclass of SSHChannel ->
# remote channel ID
self.deferreds = {"global": []} # local channel -> list of deferreds
# for pending requests or 'global' -> list of
# deferreds for global requests
self.transport = None # gets set later
def serviceStarted(self):
if hasattr(self.transport, 'avatar'):
self.transport.avatar.conn = self
def serviceStopped(self):
"""
Called when the connection is stopped.
"""
# Close any fully open channels
for channel in list(self.channelsToRemoteChannel.keys()):
self.channelClosed(channel)
# Indicate failure to any channels that were in the process of
# opening but not yet open.
while self.channels:
(_, channel) = self.channels.popitem()
log.callWithLogger(channel, channel.openFailed,
twisted.internet.error.ConnectionLost())
# Errback any unfinished global requests.
self._cleanupGlobalDeferreds()
def _cleanupGlobalDeferreds(self):
"""
All pending requests that have returned a deferred must be errbacked
when this service is stopped, otherwise they might be left uncalled and
uncallable.
"""
for d in self.deferreds["global"]:
d.errback(error.ConchError("Connection stopped."))
del self.deferreds["global"][:]
# packet methods
def ssh_GLOBAL_REQUEST(self, packet):
"""
The other side has made a global request. Payload::
string request type
bool want reply
<request specific data>
This dispatches to self.gotGlobalRequest.
"""
requestType, rest = common.getNS(packet)
wantReply, rest = ord(rest[0:1]), rest[1:]
ret = self.gotGlobalRequest(requestType, rest)
if wantReply:
reply = MSG_REQUEST_FAILURE
data = b''
if ret:
reply = MSG_REQUEST_SUCCESS
if isinstance(ret, (tuple, list)):
data = ret[1]
self.transport.sendPacket(reply, data)
def ssh_REQUEST_SUCCESS(self, packet):
"""
Our global request succeeded. Get the appropriate Deferred and call
it back with the packet we received.
"""
log.msg('RS')
self.deferreds['global'].pop(0).callback(packet)
def ssh_REQUEST_FAILURE(self, packet):
"""
Our global request failed. Get the appropriate Deferred and errback
it with the packet we received.
"""
log.msg('RF')
self.deferreds['global'].pop(0).errback(
error.ConchError('global request failed', packet))
def ssh_CHANNEL_OPEN(self, packet):
"""
The other side wants to get a channel. Payload::
string channel name
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
We get a channel from self.getChannel(), give it a local channel number
and notify the other side. Then notify the channel by calling its
channelOpen method.
"""
channelType, rest = common.getNS(packet)
senderChannel, windowSize, maxPacket = struct.unpack('>3L', rest[:12])
packet = rest[12:]
try:
channel = self.getChannel(channelType, windowSize, maxPacket,
packet)
localChannel = self.localChannelID
self.localChannelID += 1
channel.id = localChannel
self.channels[localChannel] = channel
self.channelsToRemoteChannel[channel] = senderChannel
self.localToRemoteChannel[localChannel] = senderChannel
self.transport.sendPacket(MSG_CHANNEL_OPEN_CONFIRMATION,
struct.pack('>4L', senderChannel, localChannel,
channel.localWindowSize,
channel.localMaxPacket)+channel.specificData)
log.callWithLogger(channel, channel.channelOpen, packet)
except Exception as e:
log.err(e, 'channel open failed')
if isinstance(e, error.ConchError):
textualInfo, reason = e.args
if isinstance(textualInfo, (int, long)):
# See #3657 and #3071
textualInfo, reason = reason, textualInfo
else:
reason = OPEN_CONNECT_FAILED
textualInfo = "unknown failure"
self.transport.sendPacket(
MSG_CHANNEL_OPEN_FAILURE,
struct.pack('>2L', senderChannel, reason) +
common.NS(networkString(textualInfo)) + common.NS(b''))
def ssh_CHANNEL_OPEN_CONFIRMATION(self, packet):
"""
The other side accepted our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 remote channel number
uint32 remote window size
uint32 remote maximum packet size
<channel specific data>
Find the channel using the local channel number and notify its
channelOpen method.
"""
(localChannel, remoteChannel, windowSize,
maxPacket) = struct.unpack('>4L', packet[: 16])
specificData = packet[16:]
channel = self.channels[localChannel]
channel.conn = self
self.localToRemoteChannel[localChannel] = remoteChannel
self.channelsToRemoteChannel[channel] = remoteChannel
channel.remoteWindowLeft = windowSize
channel.remoteMaxPacket = maxPacket
log.callWithLogger(channel, channel.channelOpen, specificData)
def ssh_CHANNEL_OPEN_FAILURE(self, packet):
"""
The other side did not accept our MSG_CHANNEL_OPEN request. Payload::
uint32 local channel number
uint32 reason code
string reason description
Find the channel using the local channel number and notify it by
calling its openFailed() method.
"""
localChannel, reasonCode = struct.unpack('>2L', packet[:8])
reasonDesc = common.getNS(packet[8:])[0]
channel = self.channels[localChannel]
del self.channels[localChannel]
channel.conn = self
reason = error.ConchError(reasonDesc, reasonCode)
log.callWithLogger(channel, channel.openFailed, reason)
def ssh_CHANNEL_WINDOW_ADJUST(self, packet):
"""
The other side is adding bytes to its window. Payload::
uint32 local channel number
uint32 bytes to add
Call the channel's addWindowBytes() method to add new bytes to the
remote window.
"""
localChannel, bytesToAdd = struct.unpack('>2L', packet[:8])
channel = self.channels[localChannel]
log.callWithLogger(channel, channel.addWindowBytes, bytesToAdd)
def ssh_CHANNEL_DATA(self, packet):
"""
The other side is sending us data. Payload::
uint32 local channel number
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or more than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data to the channel's dataReceived().
"""
localChannel, dataLength = struct.unpack('>2L', packet[:8])
channel = self.channels[localChannel]
# XXX should this move to dataReceived to put client in charge?
if (dataLength > channel.localWindowLeft or
dataLength > channel.localMaxPacket): # more data than we want
log.callWithLogger(channel, log.msg, 'too much data')
self.sendClose(channel)
return
#packet = packet[:channel.localWindowLeft+4]
data = common.getNS(packet[4:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(channel, channel.localWindowSize - \
channel.localWindowLeft)
#log.msg('local window left: %s/%s' % (channel.localWindowLeft,
# channel.localWindowSize))
log.callWithLogger(channel, channel.dataReceived, data)
def ssh_CHANNEL_EXTENDED_DATA(self, packet):
"""
The other side is sending us exteneded data. Payload::
uint32 local channel number
uint32 type code
string data
Check to make sure the other side hasn't sent too much data (more
than what's in the window, or than the maximum packet size). If
they have, close the channel. Otherwise, decrease the available
window and pass the data and type code to the channel's
extReceived().
"""
localChannel, typeCode, dataLength = struct.unpack('>3L', packet[:12])
channel = self.channels[localChannel]
if (dataLength > channel.localWindowLeft or
dataLength > channel.localMaxPacket):
log.callWithLogger(channel, log.msg, 'too much extdata')
self.sendClose(channel)
return
data = common.getNS(packet[8:])[0]
channel.localWindowLeft -= dataLength
if channel.localWindowLeft < channel.localWindowSize // 2:
self.adjustWindow(channel, channel.localWindowSize -
channel.localWindowLeft)
log.callWithLogger(channel, channel.extReceived, typeCode, data)
def ssh_CHANNEL_EOF(self, packet):
"""
The other side is not sending any more data. Payload::
uint32 local channel number
Notify the channel by calling its eofReceived() method.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
channel = self.channels[localChannel]
log.callWithLogger(channel, channel.eofReceived)
def ssh_CHANNEL_CLOSE(self, packet):
"""
The other side is closing its end; it does not want to receive any
more data. Payload::
uint32 local channel number
Notify the channnel by calling its closeReceived() method. If
the channel has also sent a close message, call self.channelClosed().
"""
localChannel = struct.unpack('>L', packet[:4])[0]
channel = self.channels[localChannel]
log.callWithLogger(channel, channel.closeReceived)
channel.remoteClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
def ssh_CHANNEL_REQUEST(self, packet):
"""
The other side is sending a request to a channel. Payload::
uint32 local channel number
string request name
bool want reply
<request specific data>
Pass the message to the channel's requestReceived method. If the
other side wants a reply, add callbacks which will send the
reply.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
requestType, rest = common.getNS(packet[4:])
wantReply = ord(rest[0:1])
channel = self.channels[localChannel]
d = defer.maybeDeferred(log.callWithLogger, channel,
channel.requestReceived, requestType, rest[1:])
if wantReply:
d.addCallback(self._cbChannelRequest, localChannel)
d.addErrback(self._ebChannelRequest, localChannel)
return d
def _cbChannelRequest(self, result, localChannel):
"""
Called back if the other side wanted a reply to a channel request. If
the result is true, send a MSG_CHANNEL_SUCCESS. Otherwise, raise
a C{error.ConchError}
@param result: the value returned from the channel's requestReceived()
method. If it's False, the request failed.
@type result: L{bool}
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: L{int}
@raises ConchError: if the result is False.
"""
if not result:
raise error.ConchError('failed request')
self.transport.sendPacket(MSG_CHANNEL_SUCCESS, struct.pack('>L',
self.localToRemoteChannel[localChannel]))
def _ebChannelRequest(self, result, localChannel):
"""
Called if the other wisde wanted a reply to the channel requeset and
the channel request failed.
@param result: a Failure, but it's not used.
@param localChannel: the local channel ID of the channel to which the
request was made.
@type localChannel: L{int}
"""
self.transport.sendPacket(MSG_CHANNEL_FAILURE, struct.pack('>L',
self.localToRemoteChannel[localChannel]))
def ssh_CHANNEL_SUCCESS(self, packet):
"""
Our channel request to the other side succeeded. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and call it back.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
log.callWithLogger(self.channels[localChannel],
d.callback, '')
def ssh_CHANNEL_FAILURE(self, packet):
"""
Our channel request to the other side failed. Payload::
uint32 local channel number
Get the C{Deferred} out of self.deferreds and errback it with a
C{error.ConchError}.
"""
localChannel = struct.unpack('>L', packet[:4])[0]
if self.deferreds.get(localChannel):
d = self.deferreds[localChannel].pop(0)
log.callWithLogger(self.channels[localChannel],
d.errback,
error.ConchError('channel request failed'))
# methods for users of the connection to call
def sendGlobalRequest(self, request, data, wantReply=0):
"""
Send a global request for this connection. Current this is only used
for remote->local TCP forwarding.
@type request: L{bytes}
@type data: L{bytes}
@type wantReply: L{bool}
@rtype C{Deferred}/L{None}
"""
self.transport.sendPacket(MSG_GLOBAL_REQUEST,
common.NS(request)
+ (wantReply and b'\xff' or b'\x00')
+ data)
if wantReply:
d = defer.Deferred()
self.deferreds['global'].append(d)
return d
def openChannel(self, channel, extra=b''):
"""
Open a new channel on this connection.
@type channel: subclass of C{SSHChannel}
@type extra: L{bytes}
"""
log.msg('opening channel %s with %s %s'%(self.localChannelID,
channel.localWindowSize, channel.localMaxPacket))
self.transport.sendPacket(MSG_CHANNEL_OPEN, common.NS(channel.name)
+ struct.pack('>3L', self.localChannelID,
channel.localWindowSize, channel.localMaxPacket)
+ extra)
channel.id = self.localChannelID
self.channels[self.localChannelID] = channel
self.localChannelID += 1
def sendRequest(self, channel, requestType, data, wantReply=0):
"""
Send a request to a channel.
@type channel: subclass of C{SSHChannel}
@type requestType: L{bytes}
@type data: L{bytes}
@type wantReply: L{bool}
@rtype C{Deferred}/L{None}
"""
if channel.localClosed:
return
log.msg('sending request %r' % (requestType))
self.transport.sendPacket(MSG_CHANNEL_REQUEST, struct.pack('>L',
self.channelsToRemoteChannel[channel])
+ common.NS(requestType)+chr(wantReply)
+ data)
if wantReply:
d = defer.Deferred()
self.deferreds.setdefault(channel.id, []).append(d)
return d
def adjustWindow(self, channel, bytesToAdd):
"""
Tell the other side that we will receive more data. This should not
normally need to be called as it is managed automatically.
@type channel: subclass of L{SSHChannel}
@type bytesToAdd: L{int}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(MSG_CHANNEL_WINDOW_ADJUST, struct.pack('>2L',
self.channelsToRemoteChannel[channel],
bytesToAdd))
log.msg('adding %i to %i in channel %i' % (bytesToAdd,
channel.localWindowLeft, channel.id))
channel.localWindowLeft += bytesToAdd
def sendData(self, channel, data):
"""
Send data to a channel. This should not normally be used: instead use
channel.write(data) as it manages the window automatically.
@type channel: subclass of L{SSHChannel}
@type data: L{bytes}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(MSG_CHANNEL_DATA, struct.pack('>L',
self.channelsToRemoteChannel[channel]) +
common.NS(data))
def sendExtendedData(self, channel, dataType, data):
"""
Send extended data to a channel. This should not normally be used:
instead use channel.writeExtendedData(data, dataType) as it manages
the window automatically.
@type channel: subclass of L{SSHChannel}
@type dataType: L{int}
@type data: L{bytes}
"""
if channel.localClosed:
return # we're already closed
self.transport.sendPacket(MSG_CHANNEL_EXTENDED_DATA, struct.pack('>2L',
self.channelsToRemoteChannel[channel],dataType) \
+ common.NS(data))
def sendEOF(self, channel):
"""
Send an EOF (End of File) for a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
log.msg('sending eof')
self.transport.sendPacket(MSG_CHANNEL_EOF, struct.pack('>L',
self.channelsToRemoteChannel[channel]))
def sendClose(self, channel):
"""
Close a channel.
@type channel: subclass of L{SSHChannel}
"""
if channel.localClosed:
return # we're already closed
log.msg('sending close %i' % channel.id)
self.transport.sendPacket(MSG_CHANNEL_CLOSE, struct.pack('>L',
self.channelsToRemoteChannel[channel]))
channel.localClosed = True
if channel.localClosed and channel.remoteClosed:
self.channelClosed(channel)
# methods to override
def getChannel(self, channelType, windowSize, maxPacket, data):
"""
The other side requested a channel of some sort.
channelType is the type of channel being requested,
windowSize is the initial size of the remote window,
maxPacket is the largest packet we should send,
data is any other packet data (often nothing).
We return a subclass of L{SSHChannel}.
By default, this dispatches to a method 'channel_channelType' with any
non-alphanumerics in the channelType replace with _'s. If it cannot
find a suitable method, it returns an OPEN_UNKNOWN_CHANNEL_TYPE error.
The method is called with arguments of windowSize, maxPacket, data.
@type channelType: L{bytes}
@type windowSize: L{int}
@type maxPacket: L{int}
@type data: L{bytes}
@rtype: subclass of L{SSHChannel}/L{tuple}
"""
log.msg('got channel %r request' % (channelType))
if hasattr(self.transport, "avatar"): # this is a server!
chan = self.transport.avatar.lookupChannel(channelType,
windowSize,
maxPacket,
data)
else:
channelType = channelType.translate(TRANSLATE_TABLE)
attr = 'channel_%s' % nativeString(channelType)
f = getattr(self, attr, None)
if f is not None:
chan = f(windowSize, maxPacket, data)
else:
chan = None
if chan is None:
raise error.ConchError('unknown channel',
OPEN_UNKNOWN_CHANNEL_TYPE)
else:
chan.conn = self
return chan
def gotGlobalRequest(self, requestType, data):
"""
We got a global request. pretty much, this is just used by the client
to request that we forward a port from the server to the client.
Returns either:
- 1: request accepted
- 1, <data>: request accepted with request specific data
- 0: request denied
By default, this dispatches to a method 'global_requestType' with
-'s in requestType replaced with _'s. The found method is passed data.
If this method cannot be found, this method returns 0. Otherwise, it
returns the return value of that method.
@type requestType: L{bytes}
@type data: L{bytes}
@rtype: L{int}/L{tuple}
"""
log.msg('got global %s request' % requestType)
if hasattr(self.transport, 'avatar'): # this is a server!
return self.transport.avatar.gotGlobalRequest(requestType, data)
requestType = nativeString(requestType.replace(b'-',b'_'))
f = getattr(self, 'global_%s' % requestType, None)
if not f:
return 0
return f(data)
def channelClosed(self, channel):
"""
Called when a channel is closed.
It clears the local state related to the channel, and calls
channel.closed().
MAKE SURE YOU CALL THIS METHOD, even if you subclass L{SSHConnection}.
If you don't, things will break mysteriously.
@type channel: L{SSHChannel}
"""
if channel in self.channelsToRemoteChannel: # actually open
channel.localClosed = channel.remoteClosed = True
del self.localToRemoteChannel[channel.id]
del self.channels[channel.id]
del self.channelsToRemoteChannel[channel]
for d in self.deferreds.pop(channel.id, []):
d.errback(error.ConchError("Channel closed."))
log.callWithLogger(channel, channel.closed)
MSG_GLOBAL_REQUEST = 80
MSG_REQUEST_SUCCESS = 81
MSG_REQUEST_FAILURE = 82
MSG_CHANNEL_OPEN = 90
MSG_CHANNEL_OPEN_CONFIRMATION = 91
MSG_CHANNEL_OPEN_FAILURE = 92
MSG_CHANNEL_WINDOW_ADJUST = 93
MSG_CHANNEL_DATA = 94
MSG_CHANNEL_EXTENDED_DATA = 95
MSG_CHANNEL_EOF = 96
MSG_CHANNEL_CLOSE = 97
MSG_CHANNEL_REQUEST = 98
MSG_CHANNEL_SUCCESS = 99
MSG_CHANNEL_FAILURE = 100
OPEN_ADMINISTRATIVELY_PROHIBITED = 1
OPEN_CONNECT_FAILED = 2
OPEN_UNKNOWN_CHANNEL_TYPE = 3
OPEN_RESOURCE_SHORTAGE = 4
EXTENDED_DATA_STDERR = 1
messages = {}
for name, value in locals().copy().items():
if name[:4] == 'MSG_':
messages[value] = name # Doesn't handle doubles
alphanums = networkString(string.ascii_letters + string.digits)
TRANSLATE_TABLE = b''.join([chr(i) in alphanums and chr(i) or b'_'
for i in range(256)])
SSHConnection.protocolMessages = messages

View file

@ -0,0 +1,123 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
A Factory for SSH servers.
See also L{twisted.conch.openssh_compat.factory} for OpenSSH compatibility.
Maintainer: Paul Swartz
"""
from __future__ import division, absolute_import
from twisted.internet import protocol
from twisted.python import log
from twisted.conch import error
from twisted.conch.ssh import (_kex, transport, userauth, connection)
import random
class SSHFactory(protocol.Factory):
"""
A Factory for SSH servers.
"""
protocol = transport.SSHServerTransport
services = {
b'ssh-userauth':userauth.SSHUserAuthServer,
b'ssh-connection':connection.SSHConnection
}
def startFactory(self):
"""
Check for public and private keys.
"""
if not hasattr(self,'publicKeys'):
self.publicKeys = self.getPublicKeys()
if not hasattr(self,'privateKeys'):
self.privateKeys = self.getPrivateKeys()
if not self.publicKeys or not self.privateKeys:
raise error.ConchError('no host keys, failing')
if not hasattr(self,'primes'):
self.primes = self.getPrimes()
def buildProtocol(self, addr):
"""
Create an instance of the server side of the SSH protocol.
@type addr: L{twisted.internet.interfaces.IAddress} provider
@param addr: The address at which the server will listen.
@rtype: L{twisted.conch.ssh.transport.SSHServerTransport}
@return: The built transport.
"""
t = protocol.Factory.buildProtocol(self, addr)
t.supportedPublicKeys = self.privateKeys.keys()
if not self.primes:
log.msg('disabling non-fixed-group key exchange algorithms '
'because we cannot find moduli file')
t.supportedKeyExchanges = [
kexAlgorithm for kexAlgorithm in t.supportedKeyExchanges
if _kex.isFixedGroup(kexAlgorithm) or
_kex.isEllipticCurve(kexAlgorithm)]
return t
def getPublicKeys(self):
"""
Called when the factory is started to get the public portions of the
servers host keys. Returns a dictionary mapping SSH key types to
public key strings.
@rtype: L{dict}
"""
raise NotImplementedError('getPublicKeys unimplemented')
def getPrivateKeys(self):
"""
Called when the factory is started to get the private portions of the
servers host keys. Returns a dictionary mapping SSH key types to
L{twisted.conch.ssh.keys.Key} objects.
@rtype: L{dict}
"""
raise NotImplementedError('getPrivateKeys unimplemented')
def getPrimes(self):
"""
Called when the factory is started to get Diffie-Hellman generators and
primes to use. Returns a dictionary mapping number of bits to lists
of tuple of (generator, prime).
@rtype: L{dict}
"""
def getDHPrime(self, bits):
"""
Return a tuple of (g, p) for a Diffe-Hellman process, with p being as
close to bits bits as possible.
@type bits: L{int}
@rtype: L{tuple}
"""
primesKeys = sorted(self.primes.keys(), key=lambda i: abs(i - bits))
realBits = primesKeys[0]
return random.choice(self.primes[realBits])
def getService(self, transport, service):
"""
Return a class to use as a service for the given transport.
@type transport: L{transport.SSHServerTransport}
@type service: L{bytes}
@rtype: subclass of L{service.SSHService}
"""
if service == b'ssh-userauth' or hasattr(transport, 'avatar'):
return self.services[service]

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,269 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of the TCP forwarding, which allows
clients and servers to forward arbitrary TCP data across the connection.
Maintainer: Paul Swartz
"""
from __future__ import division, absolute_import
import struct
from twisted.internet import protocol, reactor
from twisted.internet.endpoints import HostnameEndpoint, connectProtocol
from twisted.python import log
from twisted.python.compat import _PY3, unicode
from twisted.conch.ssh import common, channel
class SSHListenForwardingFactory(protocol.Factory):
def __init__(self, connection, hostport, klass):
self.conn = connection
self.hostport = hostport # tuple
self.klass = klass
def buildProtocol(self, addr):
channel = self.klass(conn = self.conn)
client = SSHForwardingClient(channel)
channel.client = client
addrTuple = (addr.host, addr.port)
channelOpenData = packOpen_direct_tcpip(self.hostport, addrTuple)
self.conn.openChannel(channel, channelOpenData)
return client
class SSHListenForwardingChannel(channel.SSHChannel):
def channelOpen(self, specificData):
log.msg('opened forwarding channel %s' % self.id)
if len(self.client.buf)>1:
b = self.client.buf[1:]
self.write(b)
self.client.buf = b''
def openFailed(self, reason):
self.closed()
def dataReceived(self, data):
self.client.transport.write(data)
def eofReceived(self):
self.client.transport.loseConnection()
def closed(self):
if hasattr(self, 'client'):
log.msg('closing local forwarding channel %s' % self.id)
self.client.transport.loseConnection()
del self.client
class SSHListenClientForwardingChannel(SSHListenForwardingChannel):
name = b'direct-tcpip'
class SSHListenServerForwardingChannel(SSHListenForwardingChannel):
name = b'forwarded-tcpip'
class SSHConnectForwardingChannel(channel.SSHChannel):
"""
Channel used for handling server side forwarding request.
It acts as a client for the remote forwarding destination.
@ivar hostport: C{(host, port)} requested by client as forwarding
destination.
@type hostport: L{tuple} or a C{sequence}
@ivar client: Protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
@ivar clientBuf: Data received while forwarding channel is not yet
connected.
@type clientBuf: L{bytes}
@var _reactor: Reactor used for TCP connections.
@type _reactor: A reactor.
@ivar _channelOpenDeferred: Deferred used in testing to check the
result of C{channelOpen}.
@type _channelOpenDeferred: L{twisted.internet.defer.Deferred}
"""
_reactor = reactor
def __init__(self, hostport, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.hostport = hostport
self.client = None
self.clientBuf = b''
def channelOpen(self, specificData):
"""
See: L{channel.SSHChannel}
"""
log.msg("connecting to %s:%i" % self.hostport)
ep = HostnameEndpoint(
self._reactor, self.hostport[0], self.hostport[1])
d = connectProtocol(ep, SSHForwardingClient(self))
d.addCallbacks(self._setClient, self._close)
self._channelOpenDeferred = d
def _setClient(self, client):
"""
Called when the connection was established to the forwarding
destination.
@param client: Client protocol connected to the forwarding destination.
@type client: L{protocol.Protocol}
"""
self.client = client
log.msg("connected to %s:%i" % self.hostport)
if self.clientBuf:
self.client.transport.write(self.clientBuf)
self.clientBuf = None
if self.client.buf[1:]:
self.write(self.client.buf[1:])
self.client.buf = b''
def _close(self, reason):
"""
Called when failed to connect to the forwarding destination.
@param reason: Reason why connection failed.
@type reason: L{twisted.python.failure.Failure}
"""
log.msg("failed to connect: %s" % reason)
self.loseConnection()
def dataReceived(self, data):
"""
See: L{channel.SSHChannel}
"""
if self.client:
self.client.transport.write(data)
else:
self.clientBuf += data
def closed(self):
"""
See: L{channel.SSHChannel}
"""
if self.client:
log.msg('closed remote forwarding channel %s' % self.id)
if self.client.channel:
self.loseConnection()
self.client.transport.loseConnection()
del self.client
def openConnectForwardingClient(remoteWindow, remoteMaxPacket, data, avatar):
remoteHP, origHP = unpackOpen_direct_tcpip(data)
return SSHConnectForwardingChannel(remoteHP,
remoteWindow=remoteWindow,
remoteMaxPacket=remoteMaxPacket,
avatar=avatar)
class SSHForwardingClient(protocol.Protocol):
def __init__(self, channel):
self.channel = channel
self.buf = b'\000'
def dataReceived(self, data):
if self.buf:
self.buf += data
else:
self.channel.write(data)
def connectionLost(self, reason):
if self.channel:
self.channel.loseConnection()
self.channel = None
def packOpen_direct_tcpip(destination, source):
"""
Pack the data suitable for sending in a CHANNEL_OPEN packet.
@type destination: L{tuple}
@param destination: A tuple of the (host, port) of the destination host.
@type source: L{tuple}
@param source: A tuple of the (host, port) of the source host.
"""
(connHost, connPort) = destination
(origHost, origPort) = source
if isinstance(connHost, unicode):
connHost = connHost.encode("utf-8")
if isinstance(origHost, unicode):
origHost = origHost.encode("utf-8")
conn = common.NS(connHost) + struct.pack('>L', connPort)
orig = common.NS(origHost) + struct.pack('>L', origPort)
return conn + orig
packOpen_forwarded_tcpip = packOpen_direct_tcpip
def unpackOpen_direct_tcpip(data):
"""Unpack the data to a usable format.
"""
connHost, rest = common.getNS(data)
if _PY3 and isinstance(connHost, bytes):
connHost = connHost.decode("utf-8")
connPort = int(struct.unpack('>L', rest[:4])[0])
origHost, rest = common.getNS(rest[4:])
if _PY3 and isinstance(origHost, bytes):
origHost = origHost.decode("utf-8")
origPort = int(struct.unpack('>L', rest[:4])[0])
return (connHost, connPort), (origHost, origPort)
unpackOpen_forwarded_tcpip = unpackOpen_direct_tcpip
def packGlobal_tcpip_forward(peer):
"""
Pack the data for tcpip forwarding.
@param peer: A tuple of the (host, port) .
@type peer: L{tuple}
"""
(host, port) = peer
return common.NS(host) + struct.pack('>L', port)
def unpackGlobal_tcpip_forward(data):
host, rest = common.getNS(data)
if _PY3 and isinstance(host, bytes):
host = host.decode("utf-8")
port = int(struct.unpack('>L', rest[:4])[0])
return host, port
"""This is how the data -> eof -> close stuff /should/ work.
debug3: channel 1: waiting for connection
debug1: channel 1: connected
debug1: channel 1: read<=0 rfd 7 len 0
debug1: channel 1: read failed
debug1: channel 1: close_read
debug1: channel 1: input open -> drain
debug1: channel 1: ibuf empty
debug1: channel 1: send eof
debug1: channel 1: input drain -> closed
debug1: channel 1: rcvd eof
debug1: channel 1: output open -> drain
debug1: channel 1: obuf empty
debug1: channel 1: close_write
debug1: channel 1: output drain -> closed
debug1: channel 1: rcvd close
debug3: channel 1: will not send data after close
debug1: channel 1: send close
debug1: channel 1: is dead
"""

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,48 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
The parent class for all the SSH services. Currently implemented services
are ssh-userauth and ssh-connection.
Maintainer: Paul Swartz
"""
from __future__ import division, absolute_import
from twisted.python import log
class SSHService(log.Logger):
name = None # this is the ssh name for the service
protocolMessages = {} # these map #'s -> protocol names
transport = None # gets set later
def serviceStarted(self):
"""
called when the service is active on the transport.
"""
def serviceStopped(self):
"""
called when the service is stopped, either by the connection ending
or by another service being started
"""
def logPrefix(self):
return "SSHService %r on %s" % (self.name,
self.transport.transport.logPrefix())
def packetReceived(self, messageNum, packet):
"""
called when we receive a packet on the transport
"""
#print self.protocolMessages
if messageNum in self.protocolMessages:
messageType = self.protocolMessages[messageNum]
f = getattr(self,'ssh_%s' % messageType[4:],
None)
if f is not None:
return f(packet)
log.msg("couldn't handle %r" % messageNum)
log.msg(repr(packet))
self.transport.sendUnimplemented()

View file

@ -0,0 +1,362 @@
# -*- test-case-name: twisted.conch.test.test_session -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
This module contains the implementation of SSHSession, which (by default)
allows access to a shell and a python interpreter over SSH.
Maintainer: Paul Swartz
"""
from __future__ import division, absolute_import
import struct
import signal
import sys
import os
from zope.interface import implementer
from twisted.internet import interfaces, protocol
from twisted.python import log
from twisted.python.compat import _bytesChr as chr, networkString
from twisted.conch.interfaces import ISession
from twisted.conch.ssh import common, channel, connection
class SSHSession(channel.SSHChannel):
name = b'session'
def __init__(self, *args, **kw):
channel.SSHChannel.__init__(self, *args, **kw)
self.buf = b''
self.client = None
self.session = None
def request_subsystem(self, data):
subsystem, ignored= common.getNS(data)
log.msg('asking for subsystem "%s"' % subsystem)
client = self.avatar.lookupSubsystem(subsystem, data)
if client:
pp = SSHSessionProcessProtocol(self)
proto = wrapProcessProtocol(pp)
client.makeConnection(proto)
pp.makeConnection(wrapProtocol(client))
self.client = pp
return 1
else:
log.msg('failed to get subsystem')
return 0
def request_shell(self, data):
log.msg('getting shell')
if not self.session:
self.session = ISession(self.avatar)
try:
pp = SSHSessionProcessProtocol(self)
self.session.openShell(pp)
except:
log.deferr()
return 0
else:
self.client = pp
return 1
def request_exec(self, data):
if not self.session:
self.session = ISession(self.avatar)
f,data = common.getNS(data)
log.msg('executing command "%s"' % f)
try:
pp = SSHSessionProcessProtocol(self)
self.session.execCommand(pp, f)
except:
log.deferr()
return 0
else:
self.client = pp
return 1
def request_pty_req(self, data):
if not self.session:
self.session = ISession(self.avatar)
term, windowSize, modes = parseRequest_pty_req(data)
log.msg('pty request: %r %r' % (term, windowSize))
try:
self.session.getPty(term, windowSize, modes)
except:
log.err()
return 0
else:
return 1
def request_window_change(self, data):
if not self.session:
self.session = ISession(self.avatar)
winSize = parseRequest_window_change(data)
try:
self.session.windowChanged(winSize)
except:
log.msg('error changing window size')
log.err()
return 0
else:
return 1
def dataReceived(self, data):
if not self.client:
#self.conn.sendClose(self)
self.buf += data
return
self.client.transport.write(data)
def extReceived(self, dataType, data):
if dataType == connection.EXTENDED_DATA_STDERR:
if self.client and hasattr(self.client.transport, 'writeErr'):
self.client.transport.writeErr(data)
else:
log.msg('weird extended data: %s'%dataType)
def eofReceived(self):
if self.session:
self.session.eofReceived()
elif self.client:
self.conn.sendClose(self)
def closed(self):
if self.session:
self.session.closed()
elif self.client:
self.client.transport.loseConnection()
#def closeReceived(self):
# self.loseConnection() # don't know what to do with this
def loseConnection(self):
if self.client:
self.client.transport.loseConnection()
channel.SSHChannel.loseConnection(self)
class _ProtocolWrapper(protocol.ProcessProtocol):
"""
This class wraps a L{Protocol} instance in a L{ProcessProtocol} instance.
"""
def __init__(self, proto):
self.proto = proto
def connectionMade(self): self.proto.connectionMade()
def outReceived(self, data): self.proto.dataReceived(data)
def processEnded(self, reason): self.proto.connectionLost(reason)
class _DummyTransport:
def __init__(self, proto):
self.proto = proto
def dataReceived(self, data):
self.proto.transport.write(data)
def write(self, data):
self.proto.dataReceived(data)
def writeSequence(self, seq):
self.write(b''.join(seq))
def loseConnection(self):
self.proto.connectionLost(protocol.connectionDone)
def wrapProcessProtocol(inst):
if isinstance(inst, protocol.Protocol):
return _ProtocolWrapper(inst)
else:
return inst
def wrapProtocol(proto):
return _DummyTransport(proto)
# SUPPORTED_SIGNALS is a list of signals that every session channel is supposed
# to accept. See RFC 4254
SUPPORTED_SIGNALS = ["ABRT", "ALRM", "FPE", "HUP", "ILL", "INT", "KILL",
"PIPE", "QUIT", "SEGV", "TERM", "USR1", "USR2"]
@implementer(interfaces.ITransport)
class SSHSessionProcessProtocol(protocol.ProcessProtocol):
"""I am both an L{IProcessProtocol} and an L{ITransport}.
I am a transport to the remote endpoint and a process protocol to the
local subsystem.
"""
# once initialized, a dictionary mapping signal values to strings
# that follow RFC 4254.
_signalValuesToNames = None
def __init__(self, session):
self.session = session
self.lostOutOrErrFlag = False
def connectionMade(self):
if self.session.buf:
self.transport.write(self.session.buf)
self.session.buf = None
def outReceived(self, data):
self.session.write(data)
def errReceived(self, err):
self.session.writeExtended(connection.EXTENDED_DATA_STDERR, err)
def outConnectionLost(self):
"""
EOF should only be sent when both STDOUT and STDERR have been closed.
"""
if self.lostOutOrErrFlag:
self.session.conn.sendEOF(self.session)
else:
self.lostOutOrErrFlag = True
def errConnectionLost(self):
"""
See outConnectionLost().
"""
self.outConnectionLost()
def connectionLost(self, reason = None):
self.session.loseConnection()
def _getSignalName(self, signum):
"""
Get a signal name given a signal number.
"""
if self._signalValuesToNames is None:
self._signalValuesToNames = {}
# make sure that the POSIX ones are the defaults
for signame in SUPPORTED_SIGNALS:
signame = 'SIG' + signame
sigvalue = getattr(signal, signame, None)
if sigvalue is not None:
self._signalValuesToNames[sigvalue] = signame
for k, v in signal.__dict__.items():
# Check for platform specific signals, ignoring Python specific
# SIG_DFL and SIG_IGN
if k.startswith('SIG') and not k.startswith('SIG_'):
if v not in self._signalValuesToNames:
self._signalValuesToNames[v] = k + '@' + sys.platform
return self._signalValuesToNames[signum]
def processEnded(self, reason=None):
"""
When we are told the process ended, try to notify the other side about
how the process ended using the exit-signal or exit-status requests.
Also, close the channel.
"""
if reason is not None:
err = reason.value
if err.signal is not None:
signame = self._getSignalName(err.signal)
if (getattr(os, 'WCOREDUMP', None) is not None and
os.WCOREDUMP(err.status)):
log.msg('exitSignal: %s (core dumped)' % (signame,))
coreDumped = 1
else:
log.msg('exitSignal: %s' % (signame,))
coreDumped = 0
self.session.conn.sendRequest(
self.session, b'exit-signal',
common.NS(networkString(signame[3:])) + chr(coreDumped) +
common.NS(b'') + common.NS(b''))
elif err.exitCode is not None:
log.msg('exitCode: %r' % (err.exitCode,))
self.session.conn.sendRequest(self.session, b'exit-status',
struct.pack('>L', err.exitCode))
self.session.loseConnection()
def getHost(self):
"""
Return the host from my session's transport.
"""
return self.session.conn.transport.getHost()
def getPeer(self):
"""
Return the peer from my session's transport.
"""
return self.session.conn.transport.getPeer()
def write(self, data):
self.session.write(data)
def writeSequence(self, seq):
self.session.write(b''.join(seq))
def loseConnection(self):
self.session.loseConnection()
class SSHSessionClient(protocol.Protocol):
def dataReceived(self, data):
if self.transport:
self.transport.write(data)
# methods factored out to make live easier on server writers
def parseRequest_pty_req(data):
"""Parse the data from a pty-req request into usable data.
@returns: a tuple of (terminal type, (rows, cols, xpixel, ypixel), modes)
"""
term, rest = common.getNS(data)
cols, rows, xpixel, ypixel = struct.unpack('>4L', rest[: 16])
modes, ignored= common.getNS(rest[16:])
winSize = (rows, cols, xpixel, ypixel)
modes = [(ord(modes[i:i+1]), struct.unpack('>L', modes[i+1: i+5])[0])
for i in range(0, len(modes)-1, 5)]
return term, winSize, modes
def packRequest_pty_req(term, geometry, modes):
"""
Pack a pty-req request so that it is suitable for sending.
NOTE: modes must be packed before being sent here.
@type geometry: L{tuple}
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
"""
(rows, cols, xpixel, ypixel) = geometry
termPacked = common.NS(term)
winSizePacked = struct.pack('>4L', cols, rows, xpixel, ypixel)
modesPacked = common.NS(modes) # depend on the client packing modes
return termPacked + winSizePacked + modesPacked
def parseRequest_window_change(data):
"""Parse the data from a window-change request into usuable data.
@returns: a tuple of (rows, cols, xpixel, ypixel)
"""
cols, rows, xpixel, ypixel = struct.unpack('>4L', data)
return rows, cols, xpixel, ypixel
def packRequest_window_change(geometry):
"""
Pack a window-change request so that it is suitable for sending.
@type geometry: L{tuple}
@param geometry: A tuple of (rows, columns, xpixel, ypixel)
"""
(rows, cols, xpixel, ypixel) = geometry
return struct.pack('>4L', cols, rows, xpixel, ypixel)

View file

@ -0,0 +1,45 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
from __future__ import absolute_import, division
from twisted.python.compat import intToBytes
def parse(s):
s = s.strip()
expr = []
while s:
if s[0:1] == b'(':
newSexp = []
if expr:
expr[-1].append(newSexp)
expr.append(newSexp)
s = s[1:]
continue
if s[0:1] == b')':
aList = expr.pop()
s=s[1:]
if not expr:
assert not s
return aList
continue
i = 0
while s[i:i+1].isdigit(): i+=1
assert i
length = int(s[:i])
data = s[i+1:i+1+length]
expr[-1].append(data)
s=s[i+1+length:]
assert 0, "this should not happen"
def pack(sexp):
s = b""
for o in sexp:
if type(o) in (type(()), type([])):
s+=b'('
s+=pack(o)
s+=b')'
else:
s+=intToBytes(len(o)) + b":" + o
return s

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,770 @@
# -*- test-case-name: twisted.conch.test.test_userauth -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Implementation of the ssh-userauth service.
Currently implemented authentication types are public-key and password.
Maintainer: Paul Swartz
"""
from __future__ import absolute_import, division
import struct
from twisted.conch import error, interfaces
from twisted.conch.ssh import keys, transport, service
from twisted.conch.ssh.common import NS, getNS
from twisted.cred import credentials
from twisted.cred.error import UnauthorizedLogin
from twisted.internet import defer, reactor
from twisted.python import failure, log
from twisted.python.compat import nativeString, _bytesChr as chr
class SSHUserAuthServer(service.SSHService):
"""
A service implementing the server side of the 'ssh-userauth' service. It
is used to authenticate the user on the other side as being able to access
this server.
@ivar name: the name of this service: 'ssh-userauth'
@type name: L{bytes}
@ivar authenticatedWith: a list of authentication methods that have
already been used.
@type authenticatedWith: L{list}
@ivar loginTimeout: the number of seconds we wait before disconnecting
the user for taking too long to authenticate
@type loginTimeout: L{int}
@ivar attemptsBeforeDisconnect: the number of failed login attempts we
allow before disconnecting.
@type attemptsBeforeDisconnect: L{int}
@ivar loginAttempts: the number of login attempts that have been made
@type loginAttempts: L{int}
@ivar passwordDelay: the number of seconds to delay when the user gives
an incorrect password
@type passwordDelay: L{int}
@ivar interfaceToMethod: a L{dict} mapping credential interfaces to
authentication methods. The server checks to see which of the
cred interfaces have checkers and tells the client that those methods
are valid for authentication.
@type interfaceToMethod: L{dict}
@ivar supportedAuthentications: A list of the supported authentication
methods.
@type supportedAuthentications: L{list} of L{bytes}
@ivar user: the last username the client tried to authenticate with
@type user: L{bytes}
@ivar method: the current authentication method
@type method: L{bytes}
@ivar nextService: the service the user wants started after authentication
has been completed.
@type nextService: L{bytes}
@ivar portal: the L{twisted.cred.portal.Portal} we are using for
authentication
@type portal: L{twisted.cred.portal.Portal}
@ivar clock: an object with a callLater method. Stubbed out for testing.
"""
name = b'ssh-userauth'
loginTimeout = 10 * 60 * 60
# 10 minutes before we disconnect them
attemptsBeforeDisconnect = 20
# 20 login attempts before a disconnect
passwordDelay = 1 # number of seconds to delay on a failed password
clock = reactor
interfaceToMethod = {
credentials.ISSHPrivateKey : b'publickey',
credentials.IUsernamePassword : b'password',
}
def serviceStarted(self):
"""
Called when the userauth service is started. Set up instance
variables, check if we should allow password authentication (only
allow if the outgoing connection is encrypted) and set up a login
timeout.
"""
self.authenticatedWith = []
self.loginAttempts = 0
self.user = None
self.nextService = None
self.portal = self.transport.factory.portal
self.supportedAuthentications = []
for i in self.portal.listCredentialsInterfaces():
if i in self.interfaceToMethod:
self.supportedAuthentications.append(self.interfaceToMethod[i])
if not self.transport.isEncrypted('in'):
# don't let us transport password in plaintext
if b'password' in self.supportedAuthentications:
self.supportedAuthentications.remove(b'password')
self._cancelLoginTimeout = self.clock.callLater(
self.loginTimeout,
self.timeoutAuthentication)
def serviceStopped(self):
"""
Called when the userauth service is stopped. Cancel the login timeout
if it's still going.
"""
if self._cancelLoginTimeout:
self._cancelLoginTimeout.cancel()
self._cancelLoginTimeout = None
def timeoutAuthentication(self):
"""
Called when the user has timed out on authentication. Disconnect
with a DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE message.
"""
self._cancelLoginTimeout = None
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
b'you took too long')
def tryAuth(self, kind, user, data):
"""
Try to authenticate the user with the given method. Dispatches to a
auth_* method.
@param kind: the authentication method to try.
@type kind: L{bytes}
@param user: the username the client is authenticating with.
@type user: L{bytes}
@param data: authentication specific data sent by the client.
@type data: L{bytes}
@return: A Deferred called back if the method succeeded, or erred back
if it failed.
@rtype: C{defer.Deferred}
"""
log.msg('%r trying auth %r' % (user, kind))
if kind not in self.supportedAuthentications:
return defer.fail(
error.ConchError('unsupported authentication, failing'))
kind = nativeString(kind.replace(b'-', b'_'))
f = getattr(self, 'auth_%s' % (kind,), None)
if f:
ret = f(data)
if not ret:
return defer.fail(
error.ConchError(
'%s return None instead of a Deferred'
% (kind, )))
else:
return ret
return defer.fail(error.ConchError('bad auth type: %s' % (kind,)))
def ssh_USERAUTH_REQUEST(self, packet):
"""
The client has requested authentication. Payload::
string user
string next service
string method
<authentication specific data>
@type packet: L{bytes}
"""
user, nextService, method, rest = getNS(packet, 3)
if user != self.user or nextService != self.nextService:
self.authenticatedWith = [] # clear auth state
self.user = user
self.nextService = nextService
self.method = method
d = self.tryAuth(method, user, rest)
if not d:
self._ebBadAuth(
failure.Failure(error.ConchError('auth returned none')))
return
d.addCallback(self._cbFinishedAuth)
d.addErrback(self._ebMaybeBadAuth)
d.addErrback(self._ebBadAuth)
return d
def _cbFinishedAuth(self, result):
"""
The callback when user has successfully been authenticated. For a
description of the arguments, see L{twisted.cred.portal.Portal.login}.
We start the service requested by the user.
"""
(interface, avatar, logout) = result
self.transport.avatar = avatar
self.transport.logoutFunction = logout
service = self.transport.factory.getService(self.transport,
self.nextService)
if not service:
raise error.ConchError('could not get next service: %s'
% self.nextService)
log.msg('%r authenticated with %r' % (self.user, self.method))
self.transport.sendPacket(MSG_USERAUTH_SUCCESS, b'')
self.transport.setService(service())
def _ebMaybeBadAuth(self, reason):
"""
An intermediate errback. If the reason is
error.NotEnoughAuthentication, we send a MSG_USERAUTH_FAILURE, but
with the partial success indicator set.
@type reason: L{twisted.python.failure.Failure}
"""
reason.trap(error.NotEnoughAuthentication)
self.transport.sendPacket(MSG_USERAUTH_FAILURE,
NS(b','.join(self.supportedAuthentications)) + b'\xff')
def _ebBadAuth(self, reason):
"""
The final errback in the authentication chain. If the reason is
error.IgnoreAuthentication, we simply return; the authentication
method has sent its own response. Otherwise, send a failure message
and (if the method is not 'none') increment the number of login
attempts.
@type reason: L{twisted.python.failure.Failure}
"""
if reason.check(error.IgnoreAuthentication):
return
if self.method != b'none':
log.msg('%r failed auth %r' % (self.user, self.method))
if reason.check(UnauthorizedLogin):
log.msg('unauthorized login: %s' % reason.getErrorMessage())
elif reason.check(error.ConchError):
log.msg('reason: %s' % reason.getErrorMessage())
else:
log.msg(reason.getTraceback())
self.loginAttempts += 1
if self.loginAttempts > self.attemptsBeforeDisconnect:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
b'too many bad auths')
return
self.transport.sendPacket(
MSG_USERAUTH_FAILURE,
NS(b','.join(self.supportedAuthentications)) + b'\x00')
def auth_publickey(self, packet):
"""
Public key authentication. Payload::
byte has signature
string algorithm name
string key blob
[string signature] (if has signature is True)
Create a SSHPublicKey credential and verify it using our portal.
"""
hasSig = ord(packet[0:1])
algName, blob, rest = getNS(packet[1:], 2)
try:
pubKey = keys.Key.fromString(blob)
except keys.BadKeyError:
error = "Unsupported key type %s or bad key" % (
algName.decode('ascii'),)
log.msg(error)
return defer.fail(UnauthorizedLogin(error))
signature = hasSig and getNS(rest)[0] or None
if hasSig:
b = (NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) +
NS(self.user) + NS(self.nextService) + NS(b'publickey') +
chr(hasSig) + NS(pubKey.sshType()) + NS(blob))
c = credentials.SSHPrivateKey(self.user, algName, blob, b,
signature)
return self.portal.login(c, None, interfaces.IConchUser)
else:
c = credentials.SSHPrivateKey(self.user, algName, blob, None, None)
return self.portal.login(c, None,
interfaces.IConchUser).addErrback(self._ebCheckKey,
packet[1:])
def _ebCheckKey(self, reason, packet):
"""
Called back if the user did not sent a signature. If reason is
error.ValidPublicKey then this key is valid for the user to
authenticate with. Send MSG_USERAUTH_PK_OK.
"""
reason.trap(error.ValidPublicKey)
# if we make it here, it means that the publickey is valid
self.transport.sendPacket(MSG_USERAUTH_PK_OK, packet)
return failure.Failure(error.IgnoreAuthentication())
def auth_password(self, packet):
"""
Password authentication. Payload::
string password
Make a UsernamePassword credential and verify it with our portal.
"""
password = getNS(packet[1:])[0]
c = credentials.UsernamePassword(self.user, password)
return self.portal.login(c, None, interfaces.IConchUser).addErrback(
self._ebPassword)
def _ebPassword(self, f):
"""
If the password is invalid, wait before sending the failure in order
to delay brute-force password guessing.
"""
d = defer.Deferred()
self.clock.callLater(self.passwordDelay, d.callback, f)
return d
class SSHUserAuthClient(service.SSHService):
"""
A service implementing the client side of 'ssh-userauth'.
This service will try all authentication methods provided by the server,
making callbacks for more information when necessary.
@ivar name: the name of this service: 'ssh-userauth'
@type name: L{str}
@ivar preferredOrder: a list of authentication methods that should be used
first, in order of preference, if supported by the server
@type preferredOrder: L{list}
@ivar user: the name of the user to authenticate as
@type user: L{bytes}
@ivar instance: the service to start after authentication has finished
@type instance: L{service.SSHService}
@ivar authenticatedWith: a list of strings of authentication methods we've tried
@type authenticatedWith: L{list} of L{bytes}
@ivar triedPublicKeys: a list of public key objects that we've tried to
authenticate with
@type triedPublicKeys: L{list} of L{Key}
@ivar lastPublicKey: the last public key object we've tried to authenticate
with
@type lastPublicKey: L{Key}
"""
name = b'ssh-userauth'
preferredOrder = [b'publickey', b'password', b'keyboard-interactive']
def __init__(self, user, instance):
self.user = user
self.instance = instance
def serviceStarted(self):
self.authenticatedWith = []
self.triedPublicKeys = []
self.lastPublicKey = None
self.askForAuth(b'none', b'')
def askForAuth(self, kind, extraData):
"""
Send a MSG_USERAUTH_REQUEST.
@param kind: the authentication method to try.
@type kind: L{bytes}
@param extraData: method-specific data to go in the packet
@type extraData: L{bytes}
"""
self.lastAuth = kind
self.transport.sendPacket(MSG_USERAUTH_REQUEST, NS(self.user) +
NS(self.instance.name) + NS(kind) + extraData)
def tryAuth(self, kind):
"""
Dispatch to an authentication method.
@param kind: the authentication method
@type kind: L{bytes}
"""
kind = nativeString(kind.replace(b'-', b'_'))
log.msg('trying to auth with %s' % (kind,))
f = getattr(self,'auth_%s' % (kind,), None)
if f:
return f()
def _ebAuth(self, ignored, *args):
"""
Generic callback for a failed authentication attempt. Respond by
asking for the list of accepted methods (the 'none' method)
"""
self.askForAuth(b'none', b'')
def ssh_USERAUTH_SUCCESS(self, packet):
"""
We received a MSG_USERAUTH_SUCCESS. The server has accepted our
authentication, so start the next service.
"""
self.transport.setService(self.instance)
def ssh_USERAUTH_FAILURE(self, packet):
"""
We received a MSG_USERAUTH_FAILURE. Payload::
string methods
byte partial success
If partial success is C{True}, then the previous method succeeded but is
not sufficient for authentication. C{methods} is a comma-separated list
of accepted authentication methods.
We sort the list of methods by their position in C{self.preferredOrder},
removing methods that have already succeeded. We then call
C{self.tryAuth} with the most preferred method.
@param packet: the C{MSG_USERAUTH_FAILURE} payload.
@type packet: L{bytes}
@return: a L{defer.Deferred} that will be callbacked with L{None} as
soon as all authentication methods have been tried, or L{None} if no
more authentication methods are available.
@rtype: C{defer.Deferred} or L{None}
"""
canContinue, partial = getNS(packet)
partial = ord(partial)
if partial:
self.authenticatedWith.append(self.lastAuth)
def orderByPreference(meth):
"""
Invoked once per authentication method in order to extract a
comparison key which is then used for sorting.
@param meth: the authentication method.
@type meth: L{bytes}
@return: the comparison key for C{meth}.
@rtype: L{int}
"""
if meth in self.preferredOrder:
return self.preferredOrder.index(meth)
else:
# put the element at the end of the list.
return len(self.preferredOrder)
canContinue = sorted([meth for meth in canContinue.split(b',')
if meth not in self.authenticatedWith],
key=orderByPreference)
log.msg('can continue with: %s' % canContinue)
return self._cbUserauthFailure(None, iter(canContinue))
def _cbUserauthFailure(self, result, iterator):
if result:
return
try:
method = next(iterator)
except StopIteration:
self.transport.sendDisconnect(
transport.DISCONNECT_NO_MORE_AUTH_METHODS_AVAILABLE,
b'no more authentication methods available')
else:
d = defer.maybeDeferred(self.tryAuth, method)
d.addCallback(self._cbUserauthFailure, iterator)
return d
def ssh_USERAUTH_PK_OK(self, packet):
"""
This message (number 60) can mean several different messages depending
on the current authentication type. We dispatch to individual methods
in order to handle this request.
"""
func = getattr(self, 'ssh_USERAUTH_PK_OK_%s' %
nativeString(self.lastAuth.replace(b'-', b'_')), None)
if func is not None:
return func(packet)
else:
self.askForAuth(b'none', b'')
def ssh_USERAUTH_PK_OK_publickey(self, packet):
"""
This is MSG_USERAUTH_PK. Our public key is valid, so we create a
signature and try to authenticate with it.
"""
publicKey = self.lastPublicKey
b = (NS(self.transport.sessionID) + chr(MSG_USERAUTH_REQUEST) +
NS(self.user) + NS(self.instance.name) + NS(b'publickey') +
b'\x01' + NS(publicKey.sshType()) + NS(publicKey.blob()))
d = self.signData(publicKey, b)
if not d:
self.askForAuth(b'none', b'')
# this will fail, we'll move on
return
d.addCallback(self._cbSignedData)
d.addErrback(self._ebAuth)
def ssh_USERAUTH_PK_OK_password(self, packet):
"""
This is MSG_USERAUTH_PASSWD_CHANGEREQ. The password given has expired.
We ask for an old password and a new password, then send both back to
the server.
"""
prompt, language, rest = getNS(packet, 2)
self._oldPass = self._newPass = None
d = self.getPassword(b'Old Password: ')
d = d.addCallbacks(self._setOldPass, self._ebAuth)
d.addCallback(lambda ignored: self.getPassword(prompt))
d.addCallbacks(self._setNewPass, self._ebAuth)
def ssh_USERAUTH_PK_OK_keyboard_interactive(self, packet):
"""
This is MSG_USERAUTH_INFO_RESPONSE. The server has sent us the
questions it wants us to answer, so we ask the user and sent the
responses.
"""
name, instruction, lang, data = getNS(packet, 3)
numPrompts = struct.unpack('!L', data[:4])[0]
data = data[4:]
prompts = []
for i in range(numPrompts):
prompt, data = getNS(data)
echo = bool(ord(data[0:1]))
data = data[1:]
prompts.append((prompt, echo))
d = self.getGenericAnswers(name, instruction, prompts)
d.addCallback(self._cbGenericAnswers)
d.addErrback(self._ebAuth)
def _cbSignedData(self, signedData):
"""
Called back out of self.signData with the signed data. Send the
authentication request with the signature.
@param signedData: the data signed by the user's private key.
@type signedData: L{bytes}
"""
publicKey = self.lastPublicKey
self.askForAuth(b'publickey', b'\x01' + NS(publicKey.sshType()) +
NS(publicKey.blob()) + NS(signedData))
def _setOldPass(self, op):
"""
Called back when we are choosing a new password. Simply store the old
password for now.
@param op: the old password as entered by the user
@type op: L{bytes}
"""
self._oldPass = op
def _setNewPass(self, np):
"""
Called back when we are choosing a new password. Get the old password
and send the authentication message with both.
@param np: the new password as entered by the user
@type np: L{bytes}
"""
op = self._oldPass
self._oldPass = None
self.askForAuth(b'password', b'\xff' + NS(op) + NS(np))
def _cbGenericAnswers(self, responses):
"""
Called back when we are finished answering keyboard-interactive
questions. Send the info back to the server in a
MSG_USERAUTH_INFO_RESPONSE.
@param responses: a list of L{bytes} responses
@type responses: L{list}
"""
data = struct.pack('!L', len(responses))
for r in responses:
data += NS(r.encode('UTF8'))
self.transport.sendPacket(MSG_USERAUTH_INFO_RESPONSE, data)
def auth_publickey(self):
"""
Try to authenticate with a public key. Ask the user for a public key;
if the user has one, send the request to the server and return True.
Otherwise, return False.
@rtype: L{bool}
"""
d = defer.maybeDeferred(self.getPublicKey)
d.addBoth(self._cbGetPublicKey)
return d
def _cbGetPublicKey(self, publicKey):
if not isinstance(publicKey, keys.Key): # failure or None
publicKey = None
if publicKey is not None:
self.lastPublicKey = publicKey
self.triedPublicKeys.append(publicKey)
log.msg('using key of type %s' % publicKey.type())
self.askForAuth(b'publickey', b'\x00' + NS(publicKey.sshType()) +
NS(publicKey.blob()))
return True
else:
return False
def auth_password(self):
"""
Try to authenticate with a password. Ask the user for a password.
If the user will return a password, return True. Otherwise, return
False.
@rtype: L{bool}
"""
d = self.getPassword()
if d:
d.addCallbacks(self._cbPassword, self._ebAuth)
return True
else: # returned None, don't do password auth
return False
def auth_keyboard_interactive(self):
"""
Try to authenticate with keyboard-interactive authentication. Send
the request to the server and return True.
@rtype: L{bool}
"""
log.msg('authing with keyboard-interactive')
self.askForAuth(b'keyboard-interactive', NS(b'') + NS(b''))
return True
def _cbPassword(self, password):
"""
Called back when the user gives a password. Send the request to the
server.
@param password: the password the user entered
@type password: L{bytes}
"""
self.askForAuth(b'password', b'\x00' + NS(password))
def signData(self, publicKey, signData):
"""
Sign the given data with the given public key.
By default, this will call getPrivateKey to get the private key,
then sign the data using Key.sign().
This method is factored out so that it can be overridden to use
alternate methods, such as a key agent.
@param publicKey: The public key object returned from L{getPublicKey}
@type publicKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: L{bytes}
@return: a Deferred that's called back with the signature
@rtype: L{defer.Deferred}
"""
key = self.getPrivateKey()
if not key:
return
return key.addCallback(self._cbSignData, signData)
def _cbSignData(self, privateKey, signData):
"""
Called back when the private key is returned. Sign the data and
return the signature.
@param privateKey: the private key object
@type publicKey: L{keys.Key}
@param signData: the data to be signed by the private key.
@type signData: L{bytes}
@return: the signature
@rtype: L{bytes}
"""
return privateKey.sign(signData)
def getPublicKey(self):
"""
Return a public key for the user. If no more public keys are
available, return L{None}.
This implementation always returns L{None}. Override it in a
subclass to actually find and return a public key object.
@rtype: L{Key} or L{None}
"""
return None
def getPrivateKey(self):
"""
Return a L{Deferred} that will be called back with the private key
object corresponding to the last public key from getPublicKey().
If the private key is not available, errback on the Deferred.
@rtype: L{Deferred} called back with L{Key}
"""
return defer.fail(NotImplementedError())
def getPassword(self, prompt = None):
"""
Return a L{Deferred} that will be called back with a password.
prompt is a string to display for the password, or None for a generic
'user@hostname's password: '.
@type prompt: L{bytes}/L{None}
@rtype: L{defer.Deferred}
"""
return defer.fail(NotImplementedError())
def getGenericAnswers(self, name, instruction, prompts):
"""
Returns a L{Deferred} with the responses to the promopts.
@param name: The name of the authentication currently in progress.
@param instruction: Describes what the authentication wants.
@param prompts: A list of (prompt, echo) pairs, where prompt is a
string to display and echo is a boolean indicating whether the
user's response should be echoed as they type it.
"""
return defer.fail(NotImplementedError())
MSG_USERAUTH_REQUEST = 50
MSG_USERAUTH_FAILURE = 51
MSG_USERAUTH_SUCCESS = 52
MSG_USERAUTH_BANNER = 53
MSG_USERAUTH_INFO_RESPONSE = 61
MSG_USERAUTH_PK_OK = 60
messages = {}
for k, v in list(locals().items()):
if k[:4] == 'MSG_':
messages[v] = k
SSHUserAuthServer.protocolMessages = messages
SSHUserAuthClient.protocolMessages = messages
del messages
del v
# Doubles, not included in the protocols' mappings
MSG_USERAUTH_PASSWD_CHANGEREQ = 60
MSG_USERAUTH_INFO_REQUEST = 60

View file

@ -0,0 +1,120 @@
# -*- test-case-name: twisted.conch.test.test_manhole -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Asynchronous local terminal input handling
@author: Jp Calderone
"""
import os, tty, sys, termios
from twisted.internet import reactor, stdio, protocol, defer
from twisted.python import failure, reflect, log
from twisted.conch.insults.insults import ServerProtocol
from twisted.conch.manhole import ColoredManhole
class UnexpectedOutputError(Exception):
pass
class TerminalProcessProtocol(protocol.ProcessProtocol):
def __init__(self, proto):
self.proto = proto
self.onConnection = defer.Deferred()
def connectionMade(self):
self.proto.makeConnection(self)
self.onConnection.callback(None)
self.onConnection = None
def write(self, data):
"""
Write to the terminal.
@param data: Data to write.
@type data: L{bytes}
"""
self.transport.write(data)
def outReceived(self, data):
"""
Receive data from the terminal.
@param data: Data received.
@type data: L{bytes}
"""
self.proto.dataReceived(data)
def errReceived(self, data):
"""
Report an error.
@param data: Data to include in L{Failure}.
@type data: L{bytes}
"""
self.transport.loseConnection()
if self.proto is not None:
self.proto.connectionLost(failure.Failure(UnexpectedOutputError(data)))
self.proto = None
def childConnectionLost(self, childFD):
if self.proto is not None:
self.proto.childConnectionLost(childFD)
def processEnded(self, reason):
if self.proto is not None:
self.proto.connectionLost(reason)
self.proto = None
class ConsoleManhole(ColoredManhole):
"""
A manhole protocol specifically for use with L{stdio.StandardIO}.
"""
def connectionLost(self, reason):
"""
When the connection is lost, there is nothing more to do. Stop the
reactor so that the process can exit.
"""
reactor.stop()
def runWithProtocol(klass):
fd = sys.__stdin__.fileno()
oldSettings = termios.tcgetattr(fd)
tty.setraw(fd)
try:
stdio.StandardIO(ServerProtocol(klass))
reactor.run()
finally:
termios.tcsetattr(fd, termios.TCSANOW, oldSettings)
os.write(fd, b"\r\x1bc\r")
def main(argv=None):
log.startLogging(open('child.log', 'w'))
if argv is None:
argv = sys.argv[1:]
if argv:
klass = reflect.namedClass(argv[0])
else:
klass = ConsoleManhole
runWithProtocol(klass)
if __name__ == '__main__':
main()

View file

@ -0,0 +1,86 @@
# -*- test-case-name: twisted.conch.test.test_tap -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Support module for making SSH servers with twistd.
"""
from twisted.conch import unix
from twisted.conch import checkers as conch_checkers
from twisted.conch.openssh_compat import factory
from twisted.cred import portal, strcred
from twisted.python import usage
from twisted.application import strports
class Options(usage.Options, strcred.AuthOptionMixin):
synopsis = "[-i <interface>] [-p <port>] [-d <dir>] "
longdesc = ("Makes a Conch SSH server. If no authentication methods are "
"specified, the default authentication methods are UNIX passwords "
"and SSH public keys. If --auth options are "
"passed, only the measures specified will be used.")
optParameters = [
["interface", "i", "", "local interface to which we listen"],
["port", "p", "tcp:22", "Port on which to listen"],
["data", "d", "/etc", "directory to look for host keys in"],
["moduli", "", None, "directory to look for moduli in "
"(if different from --data)"]
]
compData = usage.Completions(
optActions={"data": usage.CompleteDirs(descr="data directory"),
"moduli": usage.CompleteDirs(descr="moduli directory"),
"interface": usage.CompleteNetInterfaces()}
)
def __init__(self, *a, **kw):
usage.Options.__init__(self, *a, **kw)
# Call the default addCheckers (for backwards compatibility) that will
# be used if no --auth option is provided - note that conch's
# UNIXPasswordDatabase is used, instead of twisted.plugins.cred_unix's
# checker
super(Options, self).addChecker(conch_checkers.UNIXPasswordDatabase())
super(Options, self).addChecker(conch_checkers.SSHPublicKeyChecker(
conch_checkers.UNIXAuthorizedKeysFiles()))
self._usingDefaultAuth = True
def addChecker(self, checker):
"""
Add the checker specified. If any checkers are added, the default
checkers are automatically cleared and the only checkers will be the
specified one(s).
"""
if self._usingDefaultAuth:
self['credCheckers'] = []
self['credInterfaces'] = {}
self._usingDefaultAuth = False
super(Options, self).addChecker(checker)
def makeService(config):
"""
Construct a service for operating a SSH server.
@param config: An L{Options} instance specifying server options, including
where server keys are stored and what authentication methods to use.
@return: A L{twisted.application.service.IService} provider which contains
the requested SSH server.
"""
t = factory.OpenSSHFactory()
r = unix.UnixSSHRealm()
t.portal = portal.Portal(r, config.get('credCheckers', []))
t.dataRoot = config['data']
t.moduliRoot = config['moduli'] or config['data']
port = config['port']
if config['interface']:
# Add warning here
port += ':interface=' + config['interface']
return strports.service(port, t)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1 @@
'conch tests'

View file

@ -0,0 +1,576 @@
# -*- test-case-name: twisted.conch.test.test_keys -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
# pylint: disable=I0011,C0103,W9401,W9402
"""
Data used by test_keys as well as others.
"""
from __future__ import absolute_import, division
from twisted.python.compat import long, _b64decodebytes as decodebytes
RSAData = {
'n': long('269413617238113438198661010376758399219880277968382122687862697'
'296942471209955603071120391975773283844560230371884389952067978'
'789684135947515341209478065209455427327369102356204259106807047'
'964139525310539133073743116175821417513079706301100600025815509'
'786721808719302671068052414466483676821987505720384645561708425'
'794379383191274856941628512616355437197560712892001107828247792'
'561858327085521991407807015047750218508971611590850575870321007'
'991909043252470730134547038841839367764074379439843108550888709'
'430958143271417044750314742880542002948053835745429446485015316'
'60749404403945254975473896534482849256068133525751'),
'e': long(65537),
'd': long('420335724286999695680502438485489819800002417295071059780489811'
'840828351636754206234982682752076205397047218449504537476523960'
'987613148307573487322720481066677105211155388802079519869249746'
'774085882219244493290663802569201213676433159425782937159766786'
'329742053214957933941260042101377175565683849732354700525628975'
'239000548651346620826136200952740446562751690924335365940810658'
'931238410612521441739702170503547025018016868116037053013935451'
'477930426013703886193016416453215950072147440344656137718959053'
'897268663969428680144841987624962928576808352739627262941675617'
'7724661940425316604626522633351193810751757014073'),
'p': long('152689878451107675391723141129365667732639179427453246378763774'
'448531436802867910180261906924087589684175595016060014593521649'
'964959248408388984465569934780790357826811592229318702991401054'
'226302790395714901636384511513449977061729214247279176398290513'
'085108930550446985490864812445551198848562639933888780317'),
'q': long('176444974592327996338888725079951900172097062203378367409936859'
'072670162290963119826394224277287608693818012745872307600855894'
'647300295516866118620024751601329775653542084052616260193174546'
'400544176890518564317596334518015173606460860373958663673307503'
'231977779632583864454001476729233959405710696795574874403'),
'u': long('936018002388095842969518498561007090965136403384715613439364803'
'229386793506402222847415019772053080458257034241832795210460612'
'924445085372678524176842007912276654532773301546269997020970818'
'155956828553418266110329867222673040098885651348225673298948529'
'93885224775891490070400861134282266967852120152546563278')
}
DSAData = {
'g': long("10253261326864117157640690761723586967382334319435778695"
"29171533815411392477819921538350732400350395446211982054"
"96512489289702949127531056893725702005035043292195216541"
"11525058911428414042792836395195432445511200566318251789"
"10575695836669396181746841141924498545494149998282951407"
"18645344764026044855941864175"),
'p': long("10292031726231756443208850082191198787792966516790381991"
"77502076899763751166291092085666022362525614129374702633"
"26262930887668422949051881895212412718444016917144560705"
"45675251775747156453237145919794089496168502517202869160"
"78674893099371444940800865897607102159386345313384716752"
"18590012064772045092956919481"),
'q': long(1393384845225358996250882900535419012502712821577),
'x': long(1220877188542930584999385210465204342686893855021),
'y': long("14604423062661947579790240720337570315008549983452208015"
"39426429789435409684914513123700756086453120500041882809"
"10283610277194188071619191739512379408443695946763554493"
"86398594314468629823767964702559709430618263927529765769"
"10270265745700231533660131769648708944711006508965764877"
"684264272082256183140297951")
}
ECDatanistp256 = {
'x': long('762825130203920963171185031449647317742997734817505505433829043'
'45687059013883'),
'y': long('815431978646028526322656647694416475343443758943143196810611371'
'59310646683104'),
'privateValue': long('3463874347721034170096400845565569825355565567882605'
'9678074967909361042656500'),
'curve': b'ecdsa-sha2-nistp256'
}
ECDatanistp384 = {
'privateValue': long('280814107134858470598753916394807521398239633534281633982576099083'
'35787109896602102090002196616273211495718603965098'),
'x': long('10036914308591746758780165503819213553101287571902957054148542'
'504671046744460374996612408381962208627004841444205030'),
'y': long('17337335659928075994560513699823544906448896792102247714689323'
'575406618073069185107088229463828921069465902299522926'),
'curve': b'ecdsa-sha2-nistp384'
}
ECDatanistp521 = {
'x': long('12944742826257420846659527752683763193401384271391513286022917'
'29910013082920512632908350502247952686156279140016049549948975'
'670668730618745449113644014505462'),
'y': long('10784108810271976186737587749436295782985563640368689081052886'
'16296815984553198866894145509329328086635278430266482551941240'
'591605833440825557820439734509311'),
'privateValue': long('662751235215460886290293902658128847495347691199214706697089140769'
'672273950767961331442265530524063943548846724348048614239791498442'
'5997823106818915698960565'),
'curve': b'ecdsa-sha2-nistp521'
}
privateECDSA_openssh521 = b"""-----BEGIN EC PRIVATE KEY-----
MIHcAgEBBEIAjn0lSVF6QweS4bjOGP9RHwqxUiTastSE0MVuLtFvkxygZqQ712oZ
ewMvqKkxthMQgxzSpGtRBcmkL7RqZ94+18qgBwYFK4EEACOhgYkDgYYABAFpX/6B
mxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB
j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8X
f09ETdku/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw==
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh521_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAArAAAABNlY2RzYS
1zaGEyLW5pc3RwNTIxAAAACG5pc3RwNTIxAAAAhQQBaV/+gZscYJcA/laRL8NIXMsVcy8T
ZzBs8WRhe8ZjY/J5RezPk0HLXZUdl7/yfgICHlRpgY9qfOG5FfYwc61DSRAArCY7bWcPDW
1K4cY5C0/96zDbbsIxOLy42tD+wzExNBFfF39PRE3ZLv8/9bTkkq0r0cJlHDPZ0FCbRxwG
Za+UNH8AAAEAeRISlnkSEpYAAAATZWNkc2Etc2hhMi1uaXN0cDUyMQAAAAhuaXN0cDUyMQ
AAAIUEAWlf/oGbHGCXAP5WkS/DSFzLFXMvE2cwbPFkYXvGY2PyeUXsz5NBy12VHZe/8n4C
Ah5UaYGPanzhuRX2MHOtQ0kQAKwmO21nDw1tSuHGOQtP/esw227CMTi8uNrQ/sMxMTQRXx
d/T0RN2S7/P/W05JKtK9HCZRwz2dBQm0ccBmWvlDR/AAAAQgCOfSVJUXpDB5LhuM4Y/1Ef
CrFSJNqy1ITQxW4u0W+THKBmpDvXahl7Ay+oqTG2ExCDHNKka1EFyaQvtGpn3j7XygAAAA
ABAg==
-----END OPENSSH PRIVATE KEY-----
"""
publicECDSA_openssh521 = (
b"ecdsa-sha2-nistp521 AAAAE2VjZHNhLXNoYTItbmlzdHA1MjEAAAAIbmlzdHA1MjEAAACF"
b"BAFpX/6BmxxglwD+VpEvw0hcyxVzLxNnMGzxZGF7xmNj8nlF7M+TQctdlR2Xv/J+AgIeVGmB"
b"j2p84bkV9jBzrUNJEACsJjttZw8NbUrhxjkLT/3rMNtuwjE4vLja0P7DMTE0EV8Xf09ETdku"
b"/z/1tOSSrSvRwmUcM9nQUJtHHAZlr5Q0fw== comment"
)
privateECDSA_openssh384 = b"""-----BEGIN EC PRIVATE KEY-----
MIGkAgEBBDAtAi7I8j73WCX20qUM5hhHwHuFzYWYYILs2Sh8UZ+awNkARZ/Fu2LU
LLl5RtOQpbWgBwYFK4EEACKhZANiAATU17sA9P5FRwSknKcFsjjsk0+E3CeXPYX0
Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCW
G0RqpQ+np31aKmeJshkcYALEchnU+tQ=
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh384_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAiAAAABNlY2RzYS
1zaGEyLW5pc3RwMzg0AAAACG5pc3RwMzg0AAAAYQTU17sA9P5FRwSknKcFsjjsk0+E3CeX
PYX0Tk/M0HK3PpWQWgrO8JdRHP9eFE9O/23P8BumwFt7F/AvPlCzVd35VfraFT0o4cCWG0
RqpQ+np31aKmeJshkcYALEchnU+tQAAADIiktpWIpLaVgAAAATZWNkc2Etc2hhMi1uaXN0
cDM4NAAAAAhuaXN0cDM4NAAAAGEE1Ne7APT+RUcEpJynBbI47JNPhNwnlz2F9E5PzNBytz
6VkFoKzvCXURz/XhRPTv9tz/AbpsBbexfwLz5Qs1Xd+VX62hU9KOHAlhtEaqUPp6d9Wipn
ibIZHGACxHIZ1PrUAAAAMC0CLsjyPvdYJfbSpQzmGEfAe4XNhZhgguzZKHxRn5rA2QBFn8
W7YtQsuXlG05CltQAAAAA=
-----END OPENSSH PRIVATE KEY-----
"""
publicECDSA_openssh384 = (
b"ecdsa-sha2-nistp384 AAAAE2VjZHNhLXNoYTItbmlzdHAzODQAAAAIbmlzdHAzODQAAABh"
b"BNTXuwD0/kVHBKScpwWyOOyTT4TcJ5c9hfROT8zQcrc+lZBaCs7wl1Ec/14UT07/bc/wG6bA"
b"W3sX8C8+ULNV3flV+toVPSjhwJYbRGqlD6enfVoqZ4myGRxgAsRyGdT61A== comment"
)
publicECDSA_openssh = (
b"ecdsa-sha2-nistp256 AAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABB"
b"BKimX1DZ7+Qj0SpfePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBY"
b"gkN/34n42F4vpeA= comment"
)
privateECDSA_openssh = b"""-----BEGIN EC PRIVATE KEY-----
MHcCAQEEIEyU1YOT2JxxofwbJXIjGftdNcJK55aQdNrhIt2xYQz0oAoGCCqGSM49
AwEHoUQDQgAEqKZfUNnv5CPRKl948xujWlvrIaQBvmXt24LWXznnIPu0R9B+qTtt
zu/jpZ7WEszLPo5tQFiCQ3/fifjYXi+l4A==
-----END EC PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateECDSA_openssh_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAAAaAAAABNlY2RzYS
1zaGEyLW5pc3RwMjU2AAAACG5pc3RwMjU2AAAAQQSopl9Q2e/kI9EqX3jzG6NaW+shpAG+
Ze3bgtZfOecg+7RH0H6pO23O7+OlntYSzMs+jm1AWIJDf9+J+NheL6XgAAAAmCKU4hcilO
IXAAAAE2VjZHNhLXNoYTItbmlzdHAyNTYAAAAIbmlzdHAyNTYAAABBBKimX1DZ7+Qj0Spf
ePMbo1pb6yGkAb5l7duC1l855yD7tEfQfqk7bc7v46We1hLMyz6ObUBYgkN/34n42F4vpe
AAAAAgTJTVg5PYnHGh/BslciMZ+101wkrnlpB02uEi3bFhDPQAAAAA
-----END OPENSSH PRIVATE KEY-----
"""
publicRSA_openssh = (
b"ssh-rsa AAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bWG+wloVDEd2NQhEUBVUIUKirg"
b"0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgYwBGTJAkMgUyP"
b"95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm/9NNN9b0b/h9qp"
b"KSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6RKXCpCnd1bqcPUWz"
b"xiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1"
b"foNfICZgptyti8ZseZj3 comment"
)
privateRSA_openssh = b'''-----BEGIN RSA PRIVATE KEY-----
MIIEogIBAAKCAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkG
XoRVdV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMW
aqQE6Ul3w+RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4fa
qSknqv+t9YXmPhq4eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp
3dW6nD1Fs8YsGGTVuj3fq3/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28
OyluAgtX33ToE7Q3NX6DXyAmYKbcrYvGbHmY9wIDAQABAoIBACFMCGaiKNW0+44P
chuFCQC58k438BxXS+NRf54jp+Q6mFUb6ot6mB682Lqx+YkSGGCs6MwLTglaQGq6
L5n4syRghLnOaZWa+eL8H1FNJxXbKyet77RprL59EOuGR3BztACHlRU7N/nnFOeA
u2geG+bdu3NjuWfmsid/z88wm8KY/dkYNi82LvE9gXqf4QMtR9s0UWI53U/prKiL
2dbzhMQXuXGdBghCeE27xSr0w1jNVSvtvjNfBOp75gQkY/It1z0bbNWcY0MvkoiN
Pm7aGDfYDyVniR25RjReyc7Ei+2SWjMHD9+GCPmS6dvrOAg2yc3NCgFIWzk+esrG
gKnc1DkCgYEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6D
MaIVokQ9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0CgYEA+0QX
i6Q2vh43Haf2YWwExKrdeD4HjB4zAq4DFIeDeuWefQhnqPKqvxJwz3Kpp8cLHYjV
IP2cY8pHMFVOi8TP9H8WpJISdKEJwsRunIwz76Xl9+ArrU9cEaoahDdb/Xrqw818
sMjkH1Rjtcev3/QJp/zHJfxc6ZHXksWYHlbTsSMCgYBRr+mSn5QLSoRlPpSzO5IQ
tXS4jMnvyQ4BMvovaBKhAyauz1FoFEwmmyikAjMIX+GncJgBNHleUo7Ezza8H0tV
rOvBU4TH4WGoStSi/0ANgB8SqVDAKhh1lAwGmxZQqEvsQc177/dLyXUCaMSYuIaI
GFpD5wIzlyJkk4MMRSp87QKBgGlmN8ZA3SHFBPOwuD5HlHx2/C3rPzk8lcNDAVHE
Qpfz6Bakxu7s1EkQUDgE7jvN19DMzDJpkAegG1qf/jHNHjp+cR4ZlBpOTwzfX1LV
0Rdu7NectlWd244hX7wkiLb8r6vw76QssNyfhrADEriL4t0PwO4jPUpQ/i+4KUZY
v7YnAoGAZhb5IDTQVCW8YTGsgvvvnDUefkpVAmiVDQqTvh6/4UD6kKdUcDHpePzg
Zrcid5rr3dXSMEbK4tdeQZvPtUg1Uaol3N7bNClIIdvWdPx+5S9T95wJcLnkoHam
rXp0IjScTxfLP+Cq5V6lJ94/pX8Ppoj1FdZfNxeS4NYFSRA7kvY=
-----END RSA PRIVATE KEY-----'''
# Some versions of OpenSSH generate these (slightly different keys): the PKCS#1
# structure is wrapped in an extra ASN.1 SEQUENCE and there's an empty SEQUENCE
# following it. It is not any standard key format and was probably a bug in
# OpenSSH at some point.
privateRSA_openssh_alternate = b"""-----BEGIN RSA PRIVATE KEY-----
MIIEqTCCBKMCAQACggEBANVqrHgj1tYb7CWhUMR3Y1CERQFVQhQqKuDQYO7U6aOtSvo5Bl6EVXVf
ADa/b6oqP4MmN8FpLlv98PPSfdaYzTpAeNXKqBjAEZMkCQyBTI/3nO0TFmqkBOlJd8PkVWSzeWie
LAjrrOgELSF3BaeO71MwDaXluz1q4gk2b/00031vRv+H2qkpJ6r/rfWF5j4auHodSrHqwFr3MN8f
wqTk7z+RSZZA1Rl3LTfDXuydpjpEpcKkKd3Vupw9RbPGLBhk1bo936t/zUKsp/EYC6BYFWILpCpu
Q8PkBJ81o0eORu0zpWW9vDspbgILV9906BO0NzV+g18gJmCm3K2Lxmx5mPcCAwEAAQKCAQAhTAhm
oijVtPuOD3IbhQkAufJON/AcV0vjUX+eI6fkOphVG+qLepgevNi6sfmJEhhgrOjMC04JWkBqui+Z
+LMkYIS5zmmVmvni/B9RTScV2ysnre+0aay+fRDrhkdwc7QAh5UVOzf55xTngLtoHhvm3btzY7ln
5rInf8/PMJvCmP3ZGDYvNi7xPYF6n+EDLUfbNFFiOd1P6ayoi9nW84TEF7lxnQYIQnhNu8Uq9MNY
zVUr7b4zXwTqe+YEJGPyLdc9G2zVnGNDL5KIjT5u2hg32A8lZ4kduUY0XsnOxIvtklozBw/fhgj5
kunb6zgINsnNzQoBSFs5PnrKxoCp3NQ5AoGBANlwBtjivNR4kVCU1MEbiThsRmRaUaCaBz1IjwNR
zGsSjn0asWXncXU54DIFdY0YTK+TsUmxZl94YnrRDMrmTUOznPRrfeYMmNzPIWKO1S4S3gSu1yRu
gzGiFaJEPSKpYiYiubLtVAqdCIOnBw3/GRiO2Ksd2kicMWgRoWZt49gdAoGBAPtEF4ukNr4eNx2n
9mFsBMSq3Xg+B4weMwKuAxSHg3rlnn0IZ6jyqr8ScM9yqafHCx2I1SD9nGPKRzBVTovEz/R/FqSS
EnShCcLEbpyMM++l5ffgK61PXBGqGoQ3W/166sPNfLDI5B9UY7XHr9/0Caf8xyX8XOmR15LFmB5W
07EjAoGAUa/pkp+UC0qEZT6UszuSELV0uIzJ78kOATL6L2gSoQMmrs9RaBRMJpsopAIzCF/hp3CY
ATR5XlKOxM82vB9LVazrwVOEx+FhqErUov9ADYAfEqlQwCoYdZQMBpsWUKhL7EHNe+/3S8l1AmjE
mLiGiBhaQ+cCM5ciZJODDEUqfO0CgYBpZjfGQN0hxQTzsLg+R5R8dvwt6z85PJXDQwFRxEKX8+gW
pMbu7NRJEFA4BO47zdfQzMwyaZAHoBtan/4xzR46fnEeGZQaTk8M319S1dEXbuzXnLZVnduOIV+8
JIi2/K+r8O+kLLDcn4awAxK4i+LdD8DuIz1KUP4vuClGWL+2JwKBgQCFSxt6mxIQN54frV7a/saW
/t81a7k04haXkiYJvb1wIAOnNb0tG6DSB0cr1N6oqAcHG7gEIKcnQTxsOTnpQc7nFx3RTFy8PdIm
Jv5q1v1Icq5G+nvD0xlgRB2lE6eA9WMp1HpdBgcWXfaLPctkOuKEWk2MBi0tnRzrg0x4PXlUzjAA
-----END RSA PRIVATE KEY-----"""
# New format introduced in OpenSSH 6.5
privateRSA_openssh_new = b'''-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABFwAAAAdzc2gtcn
NhAAAAAwEAAQAAAQEA1WqseCPW1hvsJaFQxHdjUIRFAVVCFCoq4NBg7tTpo61K+jkGXoRV
dV8ANr9vqio/gyY3wWkuW/3w89J91pjNOkB41cqoGMARkyQJDIFMj/ec7RMWaqQE6Ul3w+
RVZLN5aJ4sCOus6AQtIXcFp47vUzANpeW7PWriCTZv/TTTfW9G/4faqSknqv+t9YXmPhq4
eh1KserAWvcw3x/CpOTvP5FJlkDVGXctN8Ne7J2mOkSlwqQp3dW6nD1Fs8YsGGTVuj3fq3
/NQqyn8RgLoFgVYgukKm5Dw+QEnzWjR45G7TOlZb28OyluAgtX33ToE7Q3NX6DXyAmYKbc
rYvGbHmY9wAAA7gXkBoMF5AaDAAAAAdzc2gtcnNhAAABAQDVaqx4I9bWG+wloVDEd2NQhE
UBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHjVyqgY
wBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auIJNm
/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY6
RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvb
w7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAAAAwEAAQAAAQAhTAhmoijVtPuOD3Ib
hQkAufJON/AcV0vjUX+eI6fkOphVG+qLepgevNi6sfmJEhhgrOjMC04JWkBqui+Z+LMkYI
S5zmmVmvni/B9RTScV2ysnre+0aay+fRDrhkdwc7QAh5UVOzf55xTngLtoHhvm3btzY7ln
5rInf8/PMJvCmP3ZGDYvNi7xPYF6n+EDLUfbNFFiOd1P6ayoi9nW84TEF7lxnQYIQnhNu8
Uq9MNYzVUr7b4zXwTqe+YEJGPyLdc9G2zVnGNDL5KIjT5u2hg32A8lZ4kduUY0XsnOxIvt
klozBw/fhgj5kunb6zgINsnNzQoBSFs5PnrKxoCp3NQ5AAAAgQCFSxt6mxIQN54frV7a/s
aW/t81a7k04haXkiYJvb1wIAOnNb0tG6DSB0cr1N6oqAcHG7gEIKcnQTxsOTnpQc7nFx3R
TFy8PdImJv5q1v1Icq5G+nvD0xlgRB2lE6eA9WMp1HpdBgcWXfaLPctkOuKEWk2MBi0tnR
zrg0x4PXlUzgAAAIEA2XAG2OK81HiRUJTUwRuJOGxGZFpRoJoHPUiPA1HMaxKOfRqxZedx
dTngMgV1jRhMr5OxSbFmX3hietEMyuZNQ7Oc9Gt95gyY3M8hYo7VLhLeBK7XJG6DMaIVok
Q9IqliJiK5su1UCp0Ig6cHDf8ZGI7Yqx3aSJwxaBGhZm3j2B0AAACBAPtEF4ukNr4eNx2n
9mFsBMSq3Xg+B4weMwKuAxSHg3rlnn0IZ6jyqr8ScM9yqafHCx2I1SD9nGPKRzBVTovEz/
R/FqSSEnShCcLEbpyMM++l5ffgK61PXBGqGoQ3W/166sPNfLDI5B9UY7XHr9/0Caf8xyX8
XOmR15LFmB5W07EjAAAAAAEC
-----END OPENSSH PRIVATE KEY-----
'''
# Encrypted with the passphrase 'encrypted'
privateRSA_openssh_encrypted = b"""-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: DES-EDE3-CBC,FFFFFFFFFFFFFFFF
p2A1YsHLXkpMVcsEqhh/nCYb5AqL0uMzfEIqc8hpZ/Ub8PtLsypilMkqzYTnZIGS
ouyPjU/WgtR4VaDnutPWdgYaKdixSEmGhKghCtXFySZqCTJ4O8NCczsktYjUK3D4
Jtl90zL6O81WBY6xP76PBQo9lrI/heAetATeyqutc18bwQIGU+gKk32qvfo15DfS
VYiY0Ds4D7F7fd9pz+f5+UbFUCgU+tfDvBrqodYrUgmH7jKoW/CRDCHHyeEIZDbF
mcMwdcKOyw1sRLaPdihRSVx3kOMvIotHKVTkIDMp+0RTNeXzQnp5U2qzsxzTcG/M
UyJN38XXkuvq5VMj2zmmjHzx34w3NK3ZxpZcoaFUqUBlNp2C8hkCLrAa/DWobKqN
5xA1ElrQvli9XXkT/RIuy4Gc10bbGEoJjuxNRibtSxxWd5Bd1E40ocOd4l1ebI8+
w69XvMTnsmHvkBEADGF2zfRszKnMelg+W5NER1UDuNT03i+1cuhp+2AZg8z7niTO
M17XP3ScGVxrQAEYgtxPrPeIpFJvOx2j5Yt78U9Y2WlaAG6DrubbYv2RsMIibhOG
yk139vMdD8FwCey6yMkkhFAJwnBtC22MAWgjmC5c6AF3SRQSjjQXepPsJcLgpOjy
YwjhnL8w56x9kVDUNPw9A9Cqgxo2sty34ATnKrh4h59PsP83LOL6OC5WjbASgZRd
OIBD8RloQPISo+RUF7X0i4kdaHVNPlR0KyapR+3M5BwhQuvEO99IArDV2LNKGzfc
W4ssugm8iyAJlmwmb2yRXIDHXabInWY7XCdGk8J2qPFbDTvnPbiagJBimjVjgpWw
tV3sVlJYqmOqmCDP78J6he04l0vaHtiOWTDEmNCrK7oFMXIIp3XWjOZGPSOJFdPs
6Go3YB+EGWfOQxqkFM28gcqmYfVPF2sa1FbZLz0ffO11Ma/rliZxZu7WdrAXe/tc
BgIQ8etp2PwAK4jCwwVwjIO8FzqQGpS23Y9NY3rfi97ckgYXKESFtXPsMMA+drZd
ThbXvccfh4EPmaqQXKf4WghHiVJ+/yuY1kUIDEl/O0jRZWT7STgBim/Aha1m6qRs
zl1H7hkDbU4solb1GM5oPzbgGTzyBc+z0XxM9iFRM+fMzPB8+yYHTr4kPbVmKBjy
SCovjQQVsHE4YeUGTq6k/NF5cVIRKTW/RlHvzxsky1Zj31MC736jrxGw4KG7VSLZ
fP6F5jj+mXwS7m0v5to42JBZmRJdKUD88QaGE3ncyQ4yleW5bn9Lf9SuzQg1Dhao
3rSA1RuexsHlIAHvGxx/17X+pyygl8DJbt6TBfbLQk9wc707DJTfh5M/bnk9wwIX
l/Hsa1WtylAMW/2MzgiVy83MbYz4+Ss6GQ5W66okWji+NxrnrYEy6q+WgVQanp7X
D+D7oKykqE1Cdvvulvtfl5fh8wlAs8mrUnKPBBUru348u++2lfacLkxRXyT1ooqY
uSNE5nlwFt08N2Ou/bl7yq6QNRMYrRkn+UEfHWCNYDoGMHln2/i6Z1RapQzNarik
tJf7radBz5nBwBjP08YAEACNSQvpsUgdqiuYjLwX7efFXQva2RzqaQ==
-----END RSA PRIVATE KEY-----"""
# Encrypted with the passphrase 'encrypted', and using the new format
# introduced in OpenSSH 6.5
privateRSA_openssh_encrypted_new = b"""-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAACmFlczI1Ni1jdHIAAAAGYmNyeXB0AAAAGAAAABD0f9WAof
DTbmwztb8pdrSeAAAAEAAAAAEAAAEXAAAAB3NzaC1yc2EAAAADAQABAAABAQDVaqx4I9bW
G+wloVDEd2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n
3WmM06QHjVyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9T
MA2l5bs9auIJNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQN
UZdy03w17snaY6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASf
NaNHjkbtM6Vlvbw7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3AAADwPQaac8s1xX3af
hQTQexj0vEAWDQsLYzDHN9G7W+UP5WHUu7igeu2GqAC/TOnjUXDP73I+EN3n7T3JFeDRfs
U1Z6Zqb0NKHSRVYwDIdIi8qVohFv85g6+xQ01OpaoOzz+vI34OUvCRHQGTgR6L9fQShZyC
McopYMYfbIse6KcqkfxX3KSdG1Pao6Njx/ShFRbgvmALpR/z0EaGCzHCDxpfUyAdnxm621
Jzaf+LverWdN7sfrfMptaS9//9iJb70sL67K+YIB64qhDnA/w9UOQfXGQFL+AEtdM0BPv8
thP1bs7T0yucBl+ZXdrDKVLZfaS3S/w85Jlgfu+a1DG73pOBOuag435iEJ9EnspjXiiydx
GrfSRk2C+/c4fBDZVGFscK5bfQuUUZyU1qOagekxX7WLHFKk9xajnud+nrAN070SeNwlX8
FZ2CI4KGlQfDvVUpKanYn8Kkj3fZ+YBGyx4M+19clF65FKSM0x1Rrh5tAmNT/SNDbSc28m
ASxrBhztzxUFTrIn3tp+uqkJniFLmFsUtiAUmj8fNyE9blykU7dqq+CqpLA872nQ9bOHHA
JsS1oBYmQ0n6AJz8WrYMdcepqWVld6Q8QSD1zdrY/sAWUovuBA1s4oIEXZhpXSS4ZJiMfh
PVktKBwj5bmoG/mmwYLbo0JHntK8N3TGTzTGLq5TpSBBdVvWSWo7tnfEkrFObmhi1uJSrQ
3zfPVP6BguboxBv+oxhaUBK8UOANe6ZwM4vfiu+QN+sZqWymHIfAktz7eWzwlToe4cKpdG
Uv+e3/7Lo2dyMl3nke5HsSUrlsMGPREuGkBih8+o85ii6D+cuCiVtus3f5c78Cir80zLIr
Z0wWvEAjciEvml00DWaA+JIaOrWwvXySaOzFGpCqC9SQjao379bvn9P3b7kVZsy6zBfHqm
bNEJUOuhBZaY8Okz36chh1xqh4sz7m3nsZ3GYGcvM+3mvRY72QnqsQEG0Sp1XYIn2bHa29
tqp7CG9X8J6dqMcPeoPRDWIX9gw7EPl/M0LP6xgewGJ9bgxwle6Mnr9kNITIswjAJqrLec
zx7dfixjAPc42ADqrw/tEdFQcSqxigcfJNKO1LbDBjh+Hk/cSBou2PoxbIcl0qfQfbGcqI
Dbpd695IEuiW9pYR22txNoIi+7cbMsuFHxQ/OqbrX/jCsprGNNJLAjgGsVEI1JnHWDH0db
3UbqbOHAeY3ufoYXNY1utVOIACpW3r9wBw3FjRi04d70VcKr16OXvOAHGN2G++Y+kMya84
Hl/Kt/gA==
-----END OPENSSH PRIVATE KEY-----
"""
# Encrypted with the passphrase 'testxp'. NB: this key was generated by
# OpenSSH, so it doesn't use the same key data as the other keys here.
privateRSA_openssh_encrypted_aes = b"""-----BEGIN RSA PRIVATE KEY-----
Proc-Type: 4,ENCRYPTED
DEK-Info: AES-128-CBC,0673309A6ACCAB4B77DEE1C1E536AC26
4Ed/a9OgJWHJsne7yOGWeWMzHYKsxuP9w1v0aYcp+puS75wvhHLiUnNwxz0KDi6n
T3YkKLBsoCWS68ApR2J9yeQ6R+EyS+UQDrO9nwqo3DB5BT3Ggt8S1wE7vjNLQD0H
g/SJnlqwsECNhh8aAx+Ag0m3ZKOZiRD5mCkcDQsZET7URSmFytDKOjhFn3u6ZFVB
sXrfpYc6TJtOQlHd/52JB6aAbjt6afSv955Z7enIi+5yEJ5y7oYQTaE5zrFMP7N5
9LbfJFlKXxEddy/DErRLxEjmC+t4svHesoJKc2jjjyNPiOoGGF3kJXea62vsjdNV
gMK5Eged3TBVIk2dv8rtJUvyFeCUtjQ1UJZIebScRR47KrbsIpCmU8I4/uHWm5hW
0mOwvdx1L/mqx/BHqVU9Dw2COhOdLbFxlFI92chkovkmNk4P48ziyVnpm7ME22sE
vfCMsyirdqB1mrL4CSM7FXONv+CgfBfeYVkYW8RfJac9U1L/O+JNn7yee414O/rS
hRYw4UdWnH6Gg6niklVKWNY0ZwUZC8zgm2iqy8YCYuneS37jC+OEKP+/s6HSKuqk
2bzcl3/TcZXNSM815hnFRpz0anuyAsvwPNRyvxG2/DacJHL1f6luV4B0o6W410yf
qXQx01DLo7nuyhJqoH3UGCyyXB+/QUs0mbG2PAEn3f5dVs31JMdbt+PrxURXXjKk
4cexpUcIpqqlfpIRe3RD0sDVbH4OXsGhi2kiTfPZu7mgyFxKopRbn1KwU1qKinfY
EU9O4PoTak/tPT+5jFNhaP+HrURoi/pU8EAUNSktl7xAkHYwkN/9Cm7DeBghgf3n
8+tyCGYDsB5utPD0/Xe9yx0Qhc/kMm4xIyQDyA937dk3mUvLC9vulnAP8I+Izim0
fZ182+D1bWwykoD0997mUHG/AUChWR01V1OLwRyPv2wUtiS8VNG76Y2aqKlgqP1P
V+IvIEqR4ERvSBVFzXNF8Y6j/sVxo8+aZw+d0L1Ns/R55deErGg3B8i/2EqGd3r+
0jps9BqFHHWW87n3VyEB3jWCMj8Vi2EJIfa/7pSaViFIQn8LiBLf+zxG5LTOToK5
xkN42fReDcqi3UNfKNGnv4dsplyTR2hyx65lsj4bRKDGLKOuB1y7iB0AGb0LtcAI
dcsVlcCeUquDXtqKvRnwfIMg+ZunyjqHBhj3qgRgbXbT6zjaSdNnih569aTg0Vup
VykzZ7+n/KVcGLmvX0NesdoI7TKbq4TnEIOynuG5Sf+2GpARO5bjcWKSZeN/Ybgk
gccf8Cqf6XWqiwlWd0B7BR3SymeHIaSymC45wmbgdstrbk7Ppa2Tp9AZku8M2Y7c
8mY9b+onK075/ypiwBm4L4GRNTFLnoNQJXx0OSl4FNRWsn6ztbD+jZhu8Seu10Jw
SEJVJ+gmTKdRLYORJKyqhDet6g7kAxs4EoJ25WsOnX5nNr00rit+NkMPA7xbJT+7
CfI51GQLw7pUPeO2WNt6yZO/YkzZrqvTj5FEwybkUyBv7L0gkqu9wjfDdUw0fVHE
xEm4DxjEoaIp8dW/JOzXQ2EF+WaSOgdYsw3Ac+rnnjnNptCdOEDGP6QBkt+oXj4P
-----END RSA PRIVATE KEY-----"""
publicRSA_lsh = (
b'{KDEwOnB1YmxpYy1rZXkoMTQ6cnNhLXBrY3MxLXNoYTEoMTpuMjU3OgDVaqx4I9bWG+wloVD'
b'Ed2NQhEUBVUIUKirg0GDu1OmjrUr6OQZehFV1XwA2v2+qKj+DJjfBaS5b/fDz0n3WmM06QHj'
b'VyqgYwBGTJAkMgUyP95ztExZqpATpSXfD5FVks3loniwI66zoBC0hdwWnju9TMA2l5bs9auI'
b'JNm/9NNN9b0b/h9qpKSeq/631heY+Grh6HUqx6sBa9zDfH8Kk5O8/kUmWQNUZdy03w17snaY'
b'6RKXCpCnd1bqcPUWzxiwYZNW6Pd+rf81CrKfxGAugWBViC6QqbkPD5ASfNaNHjkbtM6Vlvbw'
b'7KW4CC1ffdOgTtDc1foNfICZgptyti8ZseZj3KSgxOmUzOgEAASkpKQ==}'
)
privateRSA_lsh = (
b"(11:private-key(9:rsa-pkcs1(1:n257:\x00\xd5j\xacx#\xd6\xd6\x1b\xec%\xa1P"
b"\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ\xfa9\x06^\x84Uu_"
b"\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98\xcd:@x\xd5\xca"
b"\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j\xa4\x04\xe9Iw"
b"\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e\xefS0\r\xa5\xe5"
b"\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad\xf5\x85\xe6>"
b"\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?\x91I\x96@\xd5"
b"\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E\xb3\xc6,\x18d"
b"\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b\x0b\xa4*nC\xc3"
b"\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdft\xe8\x13\xb475"
b"~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7)(1:e3:\x01\x00\x01)(1:d256:!L"
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49)(1:p129:\x00\xfbD\x17\x8b\xa46\xbe"
b"\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03\x14\x87"
b"\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b\x1d\x88"
b"\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1\t\xc2"
b"\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz\xea"
b"\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\\xe9"
b"\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#)(1:q129:\x00\xd9p\x06\xd8\xe2\xbc\xd4"
b"x\x91P\x94\xd4\xc1\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}"
b"\x1a\xb1e\xe7qu9\xe02\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca"
b"\xe6MC\xb3\x9c\xf4k}\xe6\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$"
b"n\x831\xa2\x15\xa2D=\"\xa9b&\"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff"
b"\x19\x18\x8e\xd8\xab\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d)(1:a128:if7"
b"\xc6@\xdd!\xc5\x04\xf3\xb0\xb8>G\x94|v\xfc-\xeb?9<\x95\xc3C\x01Q\xc4B"
b"\x97\xf3\xe8\x16\xa4\xc6\xee\xec\xd4I\x10P8\x04\xee;\xcd\xd7\xd0\xcc\xcc"
b"2i\x90\x07\xa0\x1bZ\x9f\xfe1\xcd\x1e:~q\x1e\x19\x94\x1aNO\x0c\xdf_R\xd5"
b"\xd1\x17n\xec\xd7\x9c\xb6U\x9d\xdb\x8e!_\xbc$\x88\xb6\xfc\xaf\xab\xf0"
b"\xef\xa4,\xb0\xdc\x9f\x86\xb0\x03\x12\xb8\x8b\xe2\xdd\x0f\xc0\xee#=JP"
b"\xfe/\xb8)FX\xbf\xb6')(1:b128:Q\xaf\xe9\x92\x9f\x94\x0bJ\x84e>\x94\xb3;"
b"\x92\x10\xb5t\xb8\x8c\xc9\xef\xc9\x0e\x012\xfa/h\x12\xa1\x03&\xae\xcfQh"
b"\x14L&\x9b(\xa4\x023\x08_\xe1\xa7p\x98\x014y^R\x8e\xc4\xcf6\xbc\x1fKU"
b"\xac\xeb\xc1S\x84\xc7\xe1a\xa8J\xd4\xa2\xff@\r\x80\x1f\x12\xa9P\xc0*\x18"
b"u\x94\x0c\x06\x9b\x16P\xa8K\xecA\xcd{\xef\xf7K\xc9u\x02h\xc4\x98\xb8\x86"
b"\x88\x18ZC\xe7\x023\x97\"d\x93\x83\x0cE*|\xed)(1:c128:f\x16\xf9 4\xd0T%"
b"\xbca1\xac\x82\xfb\xef\x9c5\x1e~JU\x02h\x95\r\n\x93\xbe\x1e\xbf\xe1@\xfa"
b"\x90\xa7Tp1\xe9x\xfc\xe0f\xb7\"w\x9a\xeb\xdd\xd5\xd20F\xca\xe2\xd7^A\x9b"
b"\xcf\xb5H5Q\xaa%\xdc\xde\xdb4)H!\xdb\xd6t\xfc~\xe5/S\xf7\x9c\tp\xb9\xe4"
b"\xa0v\xa6\xadzt\"4\x9cO\x17\xcb?\xe0\xaa\xe5^\xa5\'\xde?\xa5\x7f\x0f\xa6"
b"\x88\xf5\x15\xd6_7\x17\x92\xe0\xd6\x05I\x10;\x92\xf6)))"
)
privateRSA_agentv3 = (
b"\x00\x00\x00\x07ssh-rsa\x00\x00\x00\x03\x01\x00\x01\x00\x00\x01\x00!L"
b"\x08f\xa2(\xd5\xb4\xfb\x8e\x0fr\x1b\x85\t\x00\xb9\xf2N7\xf0\x1cWK\xe3Q"
b"\x7f\x9e#\xa7\xe4:\x98U\x1b\xea\x8bz\x98\x1e\xbc\xd8\xba\xb1\xf9\x89\x12"
b"\x18`\xac\xe8\xcc\x0bN\tZ@j\xba/\x99\xf8\xb3$`\x84\xb9\xcei\x95\x9a\xf9"
b"\xe2\xfc\x1fQM'\x15\xdb+'\xad\xef\xb4i\xac\xbe}\x10\xeb\x86Gps\xb4\x00"
b"\x87\x95\x15;7\xf9\xe7\x14\xe7\x80\xbbh\x1e\x1b\xe6\xdd\xbbsc\xb9g\xe6"
b"\xb2'\x7f\xcf\xcf0\x9b\xc2\x98\xfd\xd9\x186/6.\xf1=\x81z\x9f\xe1\x03-G"
b"\xdb4Qb9\xddO\xe9\xac\xa8\x8b\xd9\xd6\xf3\x84\xc4\x17\xb9q\x9d\x06\x08Bx"
b"M\xbb\xc5*\xf4\xc3X\xcdU+\xed\xbe3_\x04\xea{\xe6\x04$c\xf2-\xd7=\x1bl"
b"\xd5\x9ccC/\x92\x88\x8d>n\xda\x187\xd8\x0f%g\x89\x1d\xb9F4^\xc9\xce\xc4"
b"\x8b\xed\x92Z3\x07\x0f\xdf\x86\x08\xf9\x92\xe9\xdb\xeb8\x086\xc9\xcd\xcd"
b"\n\x01H[9>z\xca\xc6\x80\xa9\xdc\xd49\x00\x00\x01\x01\x00\xd5j\xacx#\xd6"
b"\xd6\x1b\xec%\xa1P\xc4wcP\x84E\x01UB\x14**\xe0\xd0`\xee\xd4\xe9\xa3\xadJ"
b"\xfa9\x06^\x84Uu_\x006\xbfo\xaa*?\x83&7\xc1i.[\xfd\xf0\xf3\xd2}\xd6\x98"
b"\xcd:@x\xd5\xca\xa8\x18\xc0\x11\x93$\t\x0c\x81L\x8f\xf7\x9c\xed\x13\x16j"
b"\xa4\x04\xe9Iw\xc3\xe4Ud\xb3yh\x9e,\x08\xeb\xac\xe8\x04-!w\x05\xa7\x8e"
b"\xefS0\r\xa5\xe5\xbb=j\xe2\t6o\xfd4\xd3}oF\xff\x87\xda\xa9)'\xaa\xff\xad"
b"\xf5\x85\xe6>\x1a\xb8z\x1dJ\xb1\xea\xc0Z\xf70\xdf\x1f\xc2\xa4\xe4\xef?"
b"\x91I\x96@\xd5\x19w-7\xc3^\xec\x9d\xa6:D\xa5\xc2\xa4)\xdd\xd5\xba\x9c=E"
b"\xb3\xc6,\x18d\xd5\xba=\xdf\xab\x7f\xcdB\xac\xa7\xf1\x18\x0b\xa0X\x15b"
b"\x0b\xa4*nC\xc3\xe4\x04\x9f5\xa3G\x8eF\xed3\xa5e\xbd\xbc;)n\x02\x0bW\xdf"
b"t\xe8\x13\xb475~\x83_ &`\xa6\xdc\xad\x8b\xc6ly\x98\xf7\x00\x00\x00\x81"
b"\x00\x85K\x1bz\x9b\x12\x107\x9e\x1f\xad^\xda\xfe\xc6\x96\xfe\xdf5k\xb94"
b"\xe2\x16\x97\x92&\t\xbd\xbdp \x03\xa75\xbd-\x1b\xa0\xd2\x07G+\xd4\xde"
b"\xa8\xa8\x07\x07\x1b\xb8\x04 \xa7'A<l99\xe9A\xce\xe7\x17\x1d\xd1L\\\xbc="
b"\xd2&&\xfej\xd6\xfdHr\xaeF\xfa{\xc3\xd3\x19`D\x1d\xa5\x13\xa7\x80\xf5c)"
b"\xd4z]\x06\x07\x16]\xf6\x8b=\xcbd:\xe2\x84ZM\x8c\x06--\x9d\x1c\xeb\x83Lx"
b"=yT\xce\x00\x00\x00\x81\x00\xd9p\x06\xd8\xe2\xbc\xd4x\x91P\x94\xd4\xc1"
b"\x1b\x898lFdZQ\xa0\x9a\x07=H\x8f\x03Q\xcck\x12\x8e}\x1a\xb1e\xe7qu9\xe02"
b"\x05u\x8d\x18L\xaf\x93\xb1I\xb1f_xbz\xd1\x0c\xca\xe6MC\xb3\x9c\xf4k}\xe6"
b"\x0c\x98\xdc\xcf!b\x8e\xd5.\x12\xde\x04\xae\xd7$n\x831\xa2\x15\xa2D=\""
b"\xa9b&\"\xb9\xb2\xedT\n\x9d\x08\x83\xa7\x07\r\xff\x19\x18\x8e\xd8\xab"
b"\x1d\xdaH\x9c1h\x11\xa1fm\xe3\xd8\x1d\x00\x00\x00\x81\x00\xfbD\x17\x8b"
b"\xa46\xbe\x1e7\x1d\xa7\xf6al\x04\xc4\xaa\xddx>\x07\x8c\x1e3\x02\xae\x03"
b"\x14\x87\x83z\xe5\x9e}\x08g\xa8\xf2\xaa\xbf\x12p\xcfr\xa9\xa7\xc7\x0b"
b"\x1d\x88\xd5 \xfd\x9cc\xcaG0UN\x8b\xc4\xcf\xf4\x7f\x16\xa4\x92\x12t\xa1"
b"\t\xc2\xc4n\x9c\x8c3\xef\xa5\xe5\xf7\xe0+\xadO\\\x11\xaa\x1a\x847[\xfdz"
b"\xea\xc3\xcd|\xb0\xc8\xe4\x1fTc\xb5\xc7\xaf\xdf\xf4\t\xa7\xfc\xc7%\xfc\\"
b"\xe9\x91\xd7\x92\xc5\x98\x1eV\xd3\xb1#"
)
publicDSA_openssh = b"""\
ssh-dss AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9\
LvFYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G\
+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0\
EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB\
7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nD\
ioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQT\
NEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2G\
gdgMQWC7S6WFIXePGGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8= \
comment\
"""
privateDSA_openssh = b"""\
-----BEGIN DSA PRIVATE KEY-----
MIIBvAIBAAKBgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyD
JKsvnLLCDTP5Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VV
Ir6LPzJmFSeuqk/fYbXGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQIVAPQR
iZM1oUnwJLT68VIXidhzISdJAoGBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2M
zB3tZXTMRAe8zcOWjp8Y4aGC7Yh3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4
Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDKJ0fDdwBDub/UnwaktkPUejga+pX9OEb8
KR2dFBrvAoGAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNICl
GlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXeP
GGXqNQDdWxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8CFQDV2gbL
czUdxCus0pfEP1bddaXRLQ==
-----END DSA PRIVATE KEY-----\
"""
privateDSA_openssh_new = b"""\
-----BEGIN OPENSSH PRIVATE KEY-----
b3BlbnNzaC1rZXktdjEAAAAABG5vbmUAAAAEbm9uZQAAAAAAAAABAAABsgAAAAdzc2gtZH
NzAAAAgQCSkDrFREVQ0CKQDRx/iQAMpaPS7xWJhRVTCLhxRpWaSC+7IdyDJKsvnLLCDTP5
Zxw935rAMi5VF2bbejw/S4GUWs7DzoKbbh/hruPBvshbhPJdH1VVIr6LPzJmFSeuqk/fYb
XGSmra01mZ6VVu4BQAaKsPoXti2dIJHlNPksfWuQAAABUA9BGJkzWhSfAktPrxUheJ2HMh
J0kAAACBAJIC48RyQBO6hVnU6n+xGZ4Lu4wBpnQfZN2MzB3tZXTMRAe8zcOWjp8Y4aGC7Y
h3qBHDpNlzsr5xPlX4SFsFgHToikUxGZZos5L4Rm6tAwnthDzF4q46lgAiOZw4qEF9aBDK
J0fDdwBDub/UnwaktkPUejga+pX9OEb8KR2dFBrvAAAAgAIUacRjCFhMmhIfGJ44ms0EzR
KZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPxh2pFuWFh
OHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQhJ
lz6CzfAAAB2MVcBjzFXAY8AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9Lv
FYmFFVMIuHFGlZpIL7sh3IMkqy+cssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+
Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99htcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+S
x9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAAAIEAkgLjxHJAE7qFWdTqf7EZngu7jA
GmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOyvnE+VfhIWwWAdOiKRTEZlmiz
kvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2Q9R6OBr6lf04RvwpHZ
0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQNIClGlLyCi35
9efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDdWxlX8u
mhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
pdEtAAAAAAE=
-----END OPENSSH PRIVATE KEY-----
"""
publicDSA_lsh = decodebytes(b"""\
e0tERXdPbkIxWW14cFl5MXJaWGtvTXpwa2MyRW9NVHB3TVRJNU9nQ1NrRHJGUkVWUTBDS1FEUngv
aVFBTXBhUFM3eFdKaFJWVENMaHhScFdhU0MrN0lkeURKS3N2bkxMQ0RUUDVaeHc5MzVyQU1pNVZG
MmJiZWp3L1M0R1VXczdEem9LYmJoL2hydVBCdnNoYmhQSmRIMVZWSXI2TFB6Sm1GU2V1cWsvZlli
WEdTbXJhMDFtWjZWVnU0QlFBYUtzUG9YdGkyZElKSGxOUGtzZld1U2tvTVRweE1qRTZBUFFSaVpN
MW9VbndKTFQ2OFZJWGlkaHpJU2RKS1NneE9tY3hNams2QUpJQzQ4UnlRQk82aFZuVTZuK3hHWjRM
dTR3QnBuUWZaTjJNekIzdFpYVE1SQWU4emNPV2pwOFk0YUdDN1loM3FCSERwTmx6c3I1eFBsWDRT
RnNGZ0hUb2lrVXhHWlpvczVMNFJtNnRBd250aER6RjRxNDZsZ0FpT1p3NHFFRjlhQkRLSjBmRGR3
QkR1Yi9Vbndha3RrUFVlamdhK3BYOU9FYjhLUjJkRkJydktTZ3hPbmt4TWpnNkFoUnB4R01JV0V5
YUVoOFluamlhelFUTkVwa2xSWnFlQkdvMWdvdEpnZ05tVmFJUU5JQ2xHbEx5Q2kzNTllZkVVdVFj
WjlTWHhNNTlQK2hlY2MvR1UvR0hha1c1WVdFNGRQMkdnZGdNUVdDN1M2V0ZJWGVQR0dYcU5RRGRX
eGxYOHVtaGVudlFxYTFQbktyRlJoRHJKdzhaN0dqZEh4ZmxzeENFbVhQb0xOOHBLU2s9fQ==
""")
privateDSA_lsh = decodebytes(b"""\
KDExOnByaXZhdGUta2V5KDM6ZHNhKDE6cDEyOToAkpA6xURFUNAikA0cf4kADKWj0u8ViYUVUwi4
cUaVmkgvuyHcgySrL5yywg0z+WccPd+awDIuVRdm23o8P0uBlFrOw86Cm24f4a7jwb7IW4TyXR9V
VSK+iz8yZhUnrqpP32G1xkpq2tNZmelVbuAUAGirD6F7YtnSCR5TT5LH1rkpKDE6cTIxOgD0EYmT
NaFJ8CS0+vFSF4nYcyEnSSkoMTpnMTI5OgCSAuPEckATuoVZ1Op/sRmeC7uMAaZ0H2TdjMwd7WV0
zEQHvM3Dlo6fGOGhgu2Id6gRw6TZc7K+cT5V+EhbBYB06IpFMRmWaLOS+EZurQMJ7YQ8xeKuOpYA
IjmcOKhBfWgQyidHw3cAQ7m/1J8GpLZD1Ho4GvqV/ThG/CkdnRQa7ykoMTp5MTI4OgIUacRjCFhM
mhIfGJ44ms0EzRKZJUWangRqNYKLSYIDZlWiEDSApRpS8got+fXnxFLkHGfUl8TOfT/oXnHPxlPx
h2pFuWFhOHT9hoHYDEFgu0ulhSF3jxhl6jUA3VsZV/LpoXp70KmtT5yqxUYQ6ycPGexo3R8X5bMQ
hJlz6CzfKSgxOngyMToA1doGy3M1HcQrrNKXxD9W3XWl0S0pKSk=
""")
privateDSA_agentv3 = decodebytes(b"""\
AAAAB3NzaC1kc3MAAACBAJKQOsVERVDQIpANHH+JAAylo9LvFYmFFVMIuHFGlZpIL7sh3IMkqy+c
ssINM/lnHD3fmsAyLlUXZtt6PD9LgZRazsPOgptuH+Gu48G+yFuE8l0fVVUivos/MmYVJ66qT99h
tcZKatrTWZnpVW7gFABoqw+he2LZ0gkeU0+Sx9a5AAAAFQD0EYmTNaFJ8CS0+vFSF4nYcyEnSQAA
AIEAkgLjxHJAE7qFWdTqf7EZngu7jAGmdB9k3YzMHe1ldMxEB7zNw5aOnxjhoYLtiHeoEcOk2XOy
vnE+VfhIWwWAdOiKRTEZlmizkvhGbq0DCe2EPMXirjqWACI5nDioQX1oEMonR8N3AEO5v9SfBqS2
Q9R6OBr6lf04RvwpHZ0UGu8AAACAAhRpxGMIWEyaEh8YnjiazQTNEpklRZqeBGo1gotJggNmVaIQ
NIClGlLyCi359efEUuQcZ9SXxM59P+hecc/GU/GHakW5YWE4dP2GgdgMQWC7S6WFIXePGGXqNQDd
WxlX8umhenvQqa1PnKrFRhDrJw8Z7GjdHxflsxCEmXPoLN8AAAAVANXaBstzNR3EK6zSl8Q/Vt11
pdEt
""")
__all__ = ['DSAData', 'RSAData', 'privateDSA_agentv3', 'privateDSA_lsh',
'privateDSA_openssh', 'privateRSA_agentv3', 'privateRSA_lsh',
'privateRSA_openssh', 'publicDSA_lsh', 'publicDSA_openssh',
'publicRSA_lsh', 'publicRSA_openssh', 'privateRSA_openssh_alternate']

View file

@ -0,0 +1,28 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Loopback helper used in test_ssh and test_recvline
"""
from __future__ import division, absolute_import
from twisted.protocols import loopback
class LoopbackRelay(loopback.LoopbackRelay):
clearCall = None
def logPrefix(self):
return "LoopbackRelay(%r)" % (self.target.__class__.__name__,)
def write(self, data):
loopback.LoopbackRelay.write(self, data)
if self.clearCall is not None:
self.clearCall.cancel()
from twisted.internet import reactor
self.clearCall = reactor.callLater(0, self._clearBuffer)
def _clearBuffer(self):
self.clearCall = None
loopback.LoopbackRelay.clearBuffer(self)

View file

@ -0,0 +1,50 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{SSHTransportAddrress} in ssh/address.py
"""
from __future__ import division, absolute_import
from twisted.trial import unittest
from twisted.internet.address import IPv4Address
from twisted.internet.test.test_address import AddressTestCaseMixin
from twisted.conch.ssh.address import SSHTransportAddress
class SSHTransportAddressTests(unittest.TestCase, AddressTestCaseMixin):
"""
L{twisted.conch.ssh.address.SSHTransportAddress} is what Conch transports
use to represent the other side of the SSH connection. This tests the
basic functionality of that class (string representation, comparison, &c).
"""
def _stringRepresentation(self, stringFunction):
"""
The string representation of C{SSHTransportAddress} should be
"SSHTransportAddress(<stringFunction on address>)".
"""
addr = self.buildAddress()
stringValue = stringFunction(addr)
addressValue = stringFunction(addr.address)
self.assertEqual(stringValue,
"SSHTransportAddress(%s)" % addressValue)
def buildAddress(self):
"""
Create an arbitrary new C{SSHTransportAddress}. A new instance is
created for each call, but always for the same address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.1", 22))
def buildDifferentAddress(self):
"""
Like C{buildAddress}, but with a different fixed address.
"""
return SSHTransportAddress(IPv4Address("TCP", "127.0.0.2", 22))

View file

@ -0,0 +1,394 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.ssh.agent}.
"""
from __future__ import absolute_import, division
import struct
from twisted.trial import unittest
from twisted.test import iosim
try:
import cryptography
except ImportError:
cryptography = None
try:
import pyasn1
except ImportError:
pyasn1 = None
if cryptography and pyasn1:
from twisted.conch.ssh import keys, agent
else:
keys = agent = None
from twisted.conch.test import keydata
from twisted.conch.error import ConchError, MissingKeyStoreError
class StubFactory(object):
"""
Mock factory that provides the keys attribute required by the
SSHAgentServerProtocol
"""
def __init__(self):
self.keys = {}
class AgentTestBase(unittest.TestCase):
"""
Tests for SSHAgentServer/Client.
"""
if iosim is None:
skip = "iosim requires SSL, but SSL is not available"
elif agent is None or keys is None:
skip = "Cannot run without cryptography or PyASN1"
def setUp(self):
# wire up our client <-> server
self.client, self.server, self.pump = iosim.connectedServerAndClient(
agent.SSHAgentServer, agent.SSHAgentClient)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
# pub/priv keys of each kind
self.rsaPrivate = keys.Key.fromString(keydata.privateRSA_openssh)
self.dsaPrivate = keys.Key.fromString(keydata.privateDSA_openssh)
self.rsaPublic = keys.Key.fromString(keydata.publicRSA_openssh)
self.dsaPublic = keys.Key.fromString(keydata.publicDSA_openssh)
class ServerProtocolContractWithFactoryTests(AgentTestBase):
"""
The server protocol is stateful and so uses its factory to track state
across requests. This test asserts that the protocol raises if its factory
doesn't provide the necessary storage for that state.
"""
def test_factorySuppliesKeyStorageForServerProtocol(self):
# need a message to send into the server
msg = struct.pack('!LB',1, agent.AGENTC_REQUEST_IDENTITIES)
del self.server.factory.__dict__['keys']
self.assertRaises(MissingKeyStoreError,
self.server.dataReceived, msg)
class UnimplementedVersionOneServerTests(AgentTestBase):
"""
Tests for methods with no-op implementations on the server. We need these
for clients, such as openssh, that try v1 methods before going to v2.
Because the client doesn't expose these operations with nice method names,
we invoke sendRequest directly with an op code.
"""
def test_agentc_REQUEST_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA identities request
"""
d = self.client.sendRequest(agent.AGENTC_REQUEST_RSA_IDENTITIES, b'')
self.pump.flush()
def _cb(packet):
self.assertEqual(
agent.AGENT_RSA_IDENTITIES_ANSWER, ord(packet[0:1]))
return d.addCallback(_cb)
def test_agentc_REMOVE_RSA_IDENTITY(self):
"""
assert that we get the correct op code for an RSA remove identity request
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_RSA_IDENTITY, b'')
self.pump.flush()
return d.addCallback(self.assertEqual, b'')
def test_agentc_REMOVE_ALL_RSA_IDENTITIES(self):
"""
assert that we get the correct op code for an RSA remove all identities
request.
"""
d = self.client.sendRequest(agent.AGENTC_REMOVE_ALL_RSA_IDENTITIES, b'')
self.pump.flush()
return d.addCallback(self.assertEqual, b'')
if agent is not None:
class CorruptServer(agent.SSHAgentServer):
"""
A misbehaving server that returns bogus response op codes so that we can
verify that our callbacks that deal with these op codes handle such
miscreants.
"""
def agentc_REQUEST_IDENTITIES(self, data):
self.sendResponse(254, b'')
def agentc_SIGN_REQUEST(self, data):
self.sendResponse(254, b'')
class ClientWithBrokenServerTests(AgentTestBase):
"""
verify error handling code in the client using a misbehaving server
"""
def setUp(self):
AgentTestBase.setUp(self)
self.client, self.server, self.pump = iosim.connectedServerAndClient(
CorruptServer, agent.SSHAgentClient)
# the server's end of the protocol is stateful and we store it on the
# factory, for which we only need a mock
self.server.factory = StubFactory()
def test_signDataCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.signData} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for data signing requests.
"""
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentitiesCallbackErrorHandling(self):
"""
Assert that L{SSHAgentClient.requestIdentities} raises a ConchError
if we get a response from the server whose opcode doesn't match
the protocol for identity requests.
"""
d = self.client.requestIdentities()
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentKeyAdditionTests(AgentTestBase):
"""
Test adding different flavors of keys to an agent.
"""
def test_addRSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.rsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual(b'', serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityNoComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that omitting the comment produces an
empty string for the comment on the server.
"""
d = self.client.addIdentity(self.dsaPrivate.privateBlob())
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual(b'', serverKey[1])
return d.addCallback(_check)
def test_addRSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.rsaPrivate.privateBlob(), comment=b'My special key')
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.rsaPrivate.blob()]
self.assertEqual(self.rsaPrivate, serverKey[0])
self.assertEqual(b'My special key', serverKey[1])
return d.addCallback(_check)
def test_addDSAIdentityWithComment(self):
"""
L{SSHAgentClient.addIdentity} adds the private key it is called
with to the SSH agent server to which it is connected, associating
it with the comment it is called with.
This test asserts that the server receives/stores the comment
as sent by the client.
"""
d = self.client.addIdentity(
self.dsaPrivate.privateBlob(), comment=b'My special key')
self.pump.flush()
def _check(ignored):
serverKey = self.server.factory.keys[self.dsaPrivate.blob()]
self.assertEqual(self.dsaPrivate, serverKey[0])
self.assertEqual(b'My special key', serverKey[1])
return d.addCallback(_check)
class AgentClientFailureTests(AgentTestBase):
def test_agentFailure(self):
"""
verify that the client raises ConchError on AGENT_FAILURE
"""
d = self.client.sendRequest(254, b'')
self.pump.flush()
return self.assertFailure(d, ConchError)
class AgentIdentityRequestsTests(AgentTestBase):
"""
Test operations against a server with identities already loaded.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate, b'a comment')
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate, b'another comment')
def test_signDataRSA(self):
"""
Sign data with an RSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
signature = self.successResultOf(d)
expected = self.rsaPrivate.sign(b"John Hancock")
self.assertEqual(expected, signature)
self.assertTrue(self.rsaPublic.verify(signature, b"John Hancock"))
def test_signDataDSA(self):
"""
Sign data with a DSA private key and then verify it with the public
key.
"""
d = self.client.signData(self.dsaPublic.blob(), b"John Hancock")
self.pump.flush()
def _check(sig):
# Cannot do this b/c DSA uses random numbers when signing
# expected = self.dsaPrivate.sign("John Hancock")
# self.assertEqual(expected, sig)
self.assertTrue(self.dsaPublic.verify(sig, b"John Hancock"))
return d.addCallback(_check)
def test_signDataRSAErrbackOnUnknownBlob(self):
"""
Assert that we get an errback if we try to sign data using a key that
wasn't added.
"""
del self.server.factory.keys[self.rsaPublic.blob()]
d = self.client.signData(self.rsaPublic.blob(), b"John Hancock")
self.pump.flush()
return self.assertFailure(d, ConchError)
def test_requestIdentities(self):
"""
Assert that we get all of the keys/comments that we add when we issue a
request for all identities.
"""
d = self.client.requestIdentities()
self.pump.flush()
def _check(keyt):
expected = {}
expected[self.dsaPublic.blob()] = b'a comment'
expected[self.rsaPublic.blob()] = b'another comment'
received = {}
for k in keyt:
received[keys.Key.fromString(k[0], type='blob').blob()] = k[1]
self.assertEqual(expected, received)
return d.addCallback(_check)
class AgentKeyRemovalTests(AgentTestBase):
"""
Test support for removing keys in a remote server.
"""
def setUp(self):
AgentTestBase.setUp(self)
self.server.factory.keys[self.dsaPrivate.blob()] = (
self.dsaPrivate, b'a comment')
self.server.factory.keys[self.rsaPrivate.blob()] = (
self.rsaPrivate, b'another comment')
def test_removeRSAIdentity(self):
"""
Assert that we can remove an RSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.rsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.dsaPrivate.blob(), self.server.factory.keys)
self.assertNotIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeDSAIdentity(self):
"""
Assert that we can remove a DSA identity.
"""
# only need public key for this
d = self.client.removeIdentity(self.dsaPrivate.blob())
self.pump.flush()
def _check(ignored):
self.assertEqual(1, len(self.server.factory.keys))
self.assertIn(self.rsaPrivate.blob(), self.server.factory.keys)
return d.addCallback(_check)
def test_removeAllIdentities(self):
"""
Assert that we can remove all identities.
"""
d = self.client.removeAllIdentities()
self.pump.flush()
def _check(ignored):
self.assertEqual(0, len(self.server.factory.keys))
return d.addCallback(_check)

File diff suppressed because it is too large Load diff

View file

@ -0,0 +1,355 @@
# Copyright Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Test ssh/channel.py.
"""
from __future__ import division, absolute_import
from zope.interface.verify import verifyObject
try:
from twisted.conch.ssh import channel
from twisted.conch.ssh.address import SSHTransportAddress
from twisted.conch.ssh.transport import SSHServerTransport
from twisted.conch.ssh.service import SSHService
from twisted.internet import interfaces
from twisted.internet.address import IPv4Address
from twisted.test.proto_helpers import StringTransport
skipTest = None
except ImportError:
skipTest = 'Conch SSH not supported.'
SSHService = object
from twisted.trial import unittest
from twisted.python.compat import intToBytes
class MockConnection(SSHService):
"""
A mock for twisted.conch.ssh.connection.SSHConnection. Record the data
that channels send, and when they try to close the connection.
@ivar data: a L{dict} mapping channel id #s to lists of data sent by that
channel.
@ivar extData: a L{dict} mapping channel id #s to lists of 2-tuples
(extended data type, data) sent by that channel.
@ivar closes: a L{dict} mapping channel id #s to True if that channel sent
a close message.
"""
def __init__(self):
self.data = {}
self.extData = {}
self.closes = {}
def logPrefix(self):
"""
Return our logging prefix.
"""
return "MockConnection"
def sendData(self, channel, data):
"""
Record the sent data.
"""
self.data.setdefault(channel, []).append(data)
def sendExtendedData(self, channel, type, data):
"""
Record the sent extended data.
"""
self.extData.setdefault(channel, []).append((type, data))
def sendClose(self, channel):
"""
Record that the channel sent a close message.
"""
self.closes[channel] = True
def connectSSHTransport(service, hostAddress=None, peerAddress=None):
"""
Connect a SSHTransport which is already connected to a remote peer to
the channel under test.
@param service: Service used over the connected transport.
@type service: L{SSHService}
@param hostAddress: Local address of the connected transport.
@type hostAddress: L{interfaces.IAddress}
@param peerAddress: Remote address of the connected transport.
@type peerAddress: L{interfaces.IAddress}
"""
transport = SSHServerTransport()
transport.makeConnection(StringTransport(
hostAddress=hostAddress, peerAddress=peerAddress))
transport.setService(service)
class ChannelTests(unittest.TestCase):
"""
Tests for L{SSHChannel}.
"""
skip = skipTest
def setUp(self):
"""
Initialize the channel. remoteMaxPacket is 10 so that data is able
to be sent (the default of 0 means no data is sent because no packets
are made).
"""
self.conn = MockConnection()
self.channel = channel.SSHChannel(conn=self.conn,
remoteMaxPacket=10)
self.channel.name = b'channel'
def test_interface(self):
"""
L{SSHChannel} instances provide L{interfaces.ITransport}.
"""
self.assertTrue(verifyObject(interfaces.ITransport, self.channel))
def test_init(self):
"""
Test that SSHChannel initializes correctly. localWindowSize defaults
to 131072 (2**17) and localMaxPacket to 32768 (2**15) as reasonable
defaults (what OpenSSH uses for those variables).
The values in the second set of assertions are meaningless; they serve
only to verify that the instance variables are assigned in the correct
order.
"""
c = channel.SSHChannel(conn=self.conn)
self.assertEqual(c.localWindowSize, 131072)
self.assertEqual(c.localWindowLeft, 131072)
self.assertEqual(c.localMaxPacket, 32768)
self.assertEqual(c.remoteWindowLeft, 0)
self.assertEqual(c.remoteMaxPacket, 0)
self.assertEqual(c.conn, self.conn)
self.assertIsNone(c.data)
self.assertIsNone(c.avatar)
c2 = channel.SSHChannel(1, 2, 3, 4, 5, 6, 7)
self.assertEqual(c2.localWindowSize, 1)
self.assertEqual(c2.localWindowLeft, 1)
self.assertEqual(c2.localMaxPacket, 2)
self.assertEqual(c2.remoteWindowLeft, 3)
self.assertEqual(c2.remoteMaxPacket, 4)
self.assertEqual(c2.conn, 5)
self.assertEqual(c2.data, 6)
self.assertEqual(c2.avatar, 7)
def test_str(self):
"""
Test that str(SSHChannel) works gives the channel name and local and
remote windows at a glance..
"""
self.assertEqual(
str(self.channel), '<SSHChannel channel (lw 131072 rw 0)>')
self.assertEqual(
str(channel.SSHChannel(localWindow=1)),
'<SSHChannel None (lw 1 rw 0)>')
def test_bytes(self):
"""
Test that bytes(SSHChannel) works, gives the channel name and
local and remote windows at a glance..
"""
self.assertEqual(
self.channel.__bytes__(),
b'<SSHChannel channel (lw 131072 rw 0)>')
self.assertEqual(
channel.SSHChannel(localWindow=1).__bytes__(),
b'<SSHChannel None (lw 1 rw 0)>')
def test_logPrefix(self):
"""
Test that SSHChannel.logPrefix gives the name of the channel, the
local channel ID and the underlying connection.
"""
self.assertEqual(self.channel.logPrefix(), 'SSHChannel channel '
'(unknown) on MockConnection')
def test_addWindowBytes(self):
"""
Test that addWindowBytes adds bytes to the window and resumes writing
if it was paused.
"""
cb = [False]
def stubStartWriting():
cb[0] = True
self.channel.startWriting = stubStartWriting
self.channel.write(b'test')
self.channel.writeExtended(1, b'test')
self.channel.addWindowBytes(50)
self.assertEqual(self.channel.remoteWindowLeft, 50 - 4 - 4)
self.assertTrue(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(self.channel.buf, b'')
self.assertEqual(self.conn.data[self.channel], [b'test'])
self.assertEqual(self.channel.extBuf, [])
self.assertEqual(self.conn.extData[self.channel], [(1, b'test')])
cb[0] = False
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
self.channel.write(b'a'*80)
self.channel.loseConnection()
self.channel.addWindowBytes(20)
self.assertFalse(cb[0])
def test_requestReceived(self):
"""
Test that requestReceived handles requests by dispatching them to
request_* methods.
"""
self.channel.request_test_method = lambda data: data == b''
self.assertTrue(self.channel.requestReceived(b'test-method', b''))
self.assertFalse(self.channel.requestReceived(b'test-method', b'a'))
self.assertFalse(self.channel.requestReceived(b'bad-method', b''))
def test_closeReceieved(self):
"""
Test that the default closeReceieved closes the connection.
"""
self.assertFalse(self.channel.closing)
self.channel.closeReceived()
self.assertTrue(self.channel.closing)
def test_write(self):
"""
Test that write handles data correctly. Send data up to the size
of the remote window, splitting the data into packets of length
remoteMaxPacket.
"""
cb = [False]
def stubStopWriting():
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting
self.channel.write(b'd')
self.channel.write(b'a')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.write(b'ta')
data = self.conn.data[self.channel]
self.assertEqual(data, [b'da', b'ta'])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.write(b'12345678901')
self.assertEqual(data, [b'da', b'ta', b'1234567890', b'1'])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.write(b'123456')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, [b'da', b'ta', b'1234567890', b'1', b'12345'])
self.assertEqual(self.channel.buf, b'6')
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeExtended(self):
"""
Test that writeExtended handles data correctly. Send extended data
up to the size of the window, splitting the extended data into packets
of length remoteMaxPacket.
"""
cb = [False]
def stubStopWriting():
cb[0] = True
# no window to start with
self.channel.stopWriting = stubStopWriting
self.channel.writeExtended(1, b'd')
self.channel.writeExtended(1, b'a')
self.channel.writeExtended(2, b't')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
# regular write
self.channel.addWindowBytes(20)
self.channel.writeExtended(2, b'a')
data = self.conn.extData[self.channel]
self.assertEqual(data, [(1, b'da'), (2, b't'), (2, b'a')])
self.assertEqual(self.channel.remoteWindowLeft, 16)
# larger than max packet
self.channel.writeExtended(3, b'12345678901')
self.assertEqual(data, [(1, b'da'), (2, b't'), (2, b'a'),
(3, b'1234567890'), (3, b'1')])
self.assertEqual(self.channel.remoteWindowLeft, 5)
# running out of window
cb[0] = False
self.channel.writeExtended(4, b'123456')
self.assertFalse(self.channel.areWriting)
self.assertTrue(cb[0])
self.assertEqual(data, [(1, b'da'), (2, b't'), (2, b'a'),
(3, b'1234567890'), (3, b'1'), (4, b'12345')])
self.assertEqual(self.channel.extBuf, [[4, b'6']])
self.assertEqual(self.channel.remoteWindowLeft, 0)
def test_writeSequence(self):
"""
Test that writeSequence is equivalent to write(''.join(sequece)).
"""
self.channel.addWindowBytes(20)
self.channel.writeSequence(map(intToBytes, range(10)))
self.assertEqual(self.conn.data[self.channel], [b'0123456789'])
def test_loseConnection(self):
"""
Tesyt that loseConnection() doesn't close the channel until all
the data is sent.
"""
self.channel.write(b'data')
self.channel.writeExtended(1, b'datadata')
self.channel.loseConnection()
self.assertIsNone(self.conn.closes.get(self.channel))
self.channel.addWindowBytes(4) # send regular data
self.assertIsNone(self.conn.closes.get(self.channel))
self.channel.addWindowBytes(8) # send extended data
self.assertTrue(self.conn.closes.get(self.channel))
def test_getPeer(self):
"""
L{SSHChannel.getPeer} returns the same object as the underlying
transport's C{getPeer} method returns.
"""
peer = IPv4Address('TCP', '192.168.0.1', 54321)
connectSSHTransport(service=self.channel.conn, peerAddress=peer)
self.assertEqual(SSHTransportAddress(peer), self.channel.getPeer())
def test_getHost(self):
"""
L{SSHChannel.getHost} returns the same object as the underlying
transport's C{getHost} method returns.
"""
host = IPv4Address('TCP', '127.0.0.1', 12345)
connectSSHTransport(service=self.channel.conn, hostAddress=host)
self.assertEqual(SSHTransportAddress(host), self.channel.getHost())

View file

@ -0,0 +1,875 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.checkers}.
"""
from __future__ import absolute_import, division
try:
import crypt
except ImportError:
cryptSkip = 'cannot run without crypt module'
else:
cryptSkip = None
import os
from collections import namedtuple
from io import BytesIO
from zope.interface.verify import verifyObject
from twisted.python import util
from twisted.python.compat import _b64encodebytes
from twisted.python.failure import Failure
from twisted.python.reflect import requireModule
from twisted.trial.unittest import TestCase
from twisted.python.filepath import FilePath
from twisted.cred.checkers import InMemoryUsernamePasswordDatabaseDontUse
from twisted.cred.credentials import UsernamePassword, IUsernamePassword, \
SSHPrivateKey, ISSHPrivateKey
from twisted.cred.error import UnhandledCredentials, UnauthorizedLogin
from twisted.python.fakepwd import UserDatabase, ShadowDatabase
from twisted.test.test_process import MockOS
if requireModule('cryptography') and requireModule('pyasn1'):
dependencySkip = None
from twisted.conch.ssh import keys
from twisted.conch import checkers
from twisted.conch.error import NotEnoughAuthentication, ValidPublicKey
from twisted.conch.test import keydata
else:
dependencySkip = "can't run without cryptography and PyASN1"
if getattr(os, 'geteuid', None) is None:
euidSkip = "Cannot run without effective UIDs (questionable)"
else:
euidSkip = None
class HelperTests(TestCase):
"""
Tests for helper functions L{verifyCryptedPassword}, L{_pwdGetByName} and
L{_shadowGetByName}.
"""
skip = cryptSkip or dependencySkip
def setUp(self):
self.mockos = MockOS()
def test_verifyCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{True} if the plaintext password
passed to it matches the encrypted password passed to it.
"""
password = 'secret string'
salt = 'salty'
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
'%r supposed to be valid encrypted password for %r' % (
crypted, password))
def test_verifyCryptedPasswordMD5(self):
"""
L{verifyCryptedPassword} returns True if the provided cleartext password
matches the provided MD5 password hash.
"""
password = 'password'
salt = '$1$salt'
crypted = crypt.crypt(password, salt)
self.assertTrue(
checkers.verifyCryptedPassword(crypted, password),
'%r supposed to be valid encrypted password for %s' % (
crypted, password))
def test_refuteCryptedPassword(self):
"""
L{verifyCryptedPassword} returns C{False} if the plaintext password
passed to it does not match the encrypted password passed to it.
"""
password = 'string secret'
wrong = 'secret string'
crypted = crypt.crypt(password, password)
self.assertFalse(
checkers.verifyCryptedPassword(crypted, wrong),
'%r not supposed to be valid encrypted password for %s' % (
crypted, wrong))
def test_pwdGetByName(self):
"""
L{_pwdGetByName} returns a tuple of items from the UNIX /etc/passwd
database if the L{pwd} module is present.
"""
userdb = UserDatabase()
userdb.addUser(
'alice', 'secrit', 1, 2, 'first last', '/foo', '/bin/sh')
self.patch(checkers, 'pwd', userdb)
self.assertEqual(
checkers._pwdGetByName('alice'), userdb.getpwnam('alice'))
def test_pwdGetByNameWithoutPwd(self):
"""
If the C{pwd} module isn't present, L{_pwdGetByName} returns L{None}.
"""
self.patch(checkers, 'pwd', None)
self.assertIsNone(checkers._pwdGetByName('alice'))
def test_shadowGetByName(self):
"""
L{_shadowGetByName} returns a tuple of items from the UNIX /etc/shadow
database if the L{spwd} is present.
"""
userdb = ShadowDatabase()
userdb.addUser('bob', 'passphrase', 1, 2, 3, 4, 5, 6, 7)
self.patch(checkers, 'spwd', userdb)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(util, 'os', self.mockos)
self.assertEqual(
checkers._shadowGetByName('bob'), userdb.getspnam('bob'))
self.assertEqual(self.mockos.seteuidCalls, [0, 2345])
self.assertEqual(self.mockos.setegidCalls, [0, 1234])
def test_shadowGetByNameWithoutSpwd(self):
"""
L{_shadowGetByName} returns L{None} if C{spwd} is not present.
"""
self.patch(checkers, 'spwd', None)
self.assertIsNone(checkers._shadowGetByName('bob'))
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
class SSHPublicKeyDatabaseTests(TestCase):
"""
Tests for L{SSHPublicKeyDatabase}.
"""
skip = euidSkip or dependencySkip
def setUp(self):
self.checker = checkers.SSHPublicKeyDatabase()
self.key1 = _b64encodebytes(b"foobar")
self.key2 = _b64encodebytes(b"eggspam")
self.content = (b"t1 " + self.key1 + b" foo\nt2 " + self.key2 +
b" egg\n")
self.mockos = MockOS()
self.mockos.path = FilePath(self.mktemp())
self.mockos.path.makedirs()
self.patch(util, 'os', self.mockos)
self.sshDir = self.mockos.path.child('.ssh')
self.sshDir.makedirs()
userdb = UserDatabase()
userdb.addUser(
b'user', b'password', 1, 2, b'first last',
self.mockos.path.path, b'/bin/shell')
self.checker._userdb = userdb
def test_deprecated(self):
"""
L{SSHPublicKeyDatabase} is deprecated as of version 15.0
"""
warningsShown = self.flushWarnings(
offendingFunctions=[self.setUp])
self.assertEqual(warningsShown[0]['category'], DeprecationWarning)
self.assertEqual(
warningsShown[0]['message'],
"twisted.conch.checkers.SSHPublicKeyDatabase "
"was deprecated in Twisted 15.0.0: Please use "
"twisted.conch.checkers.SSHPublicKeyChecker, "
"initialized with an instance of "
"twisted.conch.checkers.UNIXAuthorizedKeysFiles instead.")
self.assertEqual(len(warningsShown), 1)
def _testCheckKey(self, filename):
self.sshDir.child(filename).setContent(self.content)
user = UsernamePassword(b"user", b"password")
user.blob = b"foobar"
self.assertTrue(self.checker.checkKey(user))
user.blob = b"eggspam"
self.assertTrue(self.checker.checkKey(user))
user.blob = b"notallowed"
self.assertFalse(self.checker.checkKey(user))
def test_checkKey(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys file and check the keys against that file.
"""
self._testCheckKey("authorized_keys")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKey2(self):
"""
L{SSHPublicKeyDatabase.checkKey} should retrieve the content of the
authorized_keys2 file and check the keys against that file.
"""
self._testCheckKey("authorized_keys2")
self.assertEqual(self.mockos.seteuidCalls, [])
self.assertEqual(self.mockos.setegidCalls, [])
def test_checkKeyAsRoot(self):
"""
If the key file is readable, L{SSHPublicKeyDatabase.checkKey} should
switch its uid/gid to the ones of the authenticated user.
"""
keyFile = self.sshDir.child("authorized_keys")
keyFile.setContent(self.content)
# Fake permission error by changing the mode
keyFile.chmod(0o000)
self.addCleanup(keyFile.chmod, 0o777)
# And restore the right mode when seteuid is called
savedSeteuid = self.mockos.seteuid
def seteuid(euid):
keyFile.chmod(0o777)
return savedSeteuid(euid)
self.mockos.euid = 2345
self.mockos.egid = 1234
self.patch(self.mockos, "seteuid", seteuid)
self.patch(util, 'os', self.mockos)
user = UsernamePassword(b"user", b"password")
user.blob = b"foobar"
self.assertTrue(self.checker.checkKey(user))
self.assertEqual(self.mockos.seteuidCalls, [0, 1, 0, 2345])
self.assertEqual(self.mockos.setegidCalls, [2, 1234])
def test_requestAvatarId(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should return the avatar id
passed in if its C{_checkKey} method returns True.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
b'test', b'ssh-rsa', keydata.publicRSA_openssh, b'foo',
keys.Key.fromString(keydata.privateRSA_openssh).sign(b'foo'))
d = self.checker.requestAvatarId(credentials)
def _verify(avatarId):
self.assertEqual(avatarId, b'test')
return d.addCallback(_verify)
def test_requestAvatarIdWithoutSignature(self):
"""
L{SSHPublicKeyDatabase.requestAvatarId} should raise L{ValidPublicKey}
if the credentials represent a valid key without a signature. This
tells the user that the key is valid for login, but does not actually
allow that user to do so without a signature.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
b'test', b'ssh-rsa', keydata.publicRSA_openssh, None, None)
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, ValidPublicKey)
def test_requestAvatarIdInvalidKey(self):
"""
If L{SSHPublicKeyDatabase.checkKey} returns False,
C{_cbRequestAvatarId} should raise L{UnauthorizedLogin}.
"""
def _checkKey(ignored):
return False
self.patch(self.checker, 'checkKey', _checkKey)
d = self.checker.requestAvatarId(None);
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdInvalidSignature(self):
"""
Valid keys with invalid signatures should cause
L{SSHPublicKeyDatabase.requestAvatarId} to return a {UnauthorizedLogin}
failure
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(
b'test', b'ssh-rsa', keydata.publicRSA_openssh, b'foo',
keys.Key.fromString(keydata.privateDSA_openssh).sign(b'foo'))
d = self.checker.requestAvatarId(credentials)
return self.assertFailure(d, UnauthorizedLogin)
def test_requestAvatarIdNormalizeException(self):
"""
Exceptions raised while verifying the key should be normalized into an
C{UnauthorizedLogin} failure.
"""
def _checkKey(ignored):
return True
self.patch(self.checker, 'checkKey', _checkKey)
credentials = SSHPrivateKey(b'test', None, b'blob', b'sigData', b'sig')
d = self.checker.requestAvatarId(credentials)
def _verifyLoggedException(failure):
errors = self.flushLoggedErrors(keys.BadKeyError)
self.assertEqual(len(errors), 1)
return failure
d.addErrback(_verifyLoggedException)
return self.assertFailure(d, UnauthorizedLogin)
class SSHProtocolCheckerTests(TestCase):
"""
Tests for L{SSHProtocolChecker}.
"""
skip = dependencySkip
def test_registerChecker(self):
"""
L{SSHProcotolChecker.registerChecker} should add the given checker to
the list of registered checkers.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(checkers.SSHPublicKeyDatabase(), )
self.assertEqual(checker.credentialInterfaces, [ISSHPrivateKey])
self.assertIsInstance(checker.checkers[ISSHPrivateKey],
checkers.SSHPublicKeyDatabase)
def test_registerCheckerWithInterface(self):
"""
If a specific interface is passed into
L{SSHProtocolChecker.registerChecker}, that interface should be
registered instead of what the checker specifies in
credentialIntefaces.
"""
checker = checkers.SSHProtocolChecker()
self.assertEqual(checker.credentialInterfaces, [])
checker.registerChecker(checkers.SSHPublicKeyDatabase(),
IUsernamePassword)
self.assertEqual(checker.credentialInterfaces, [IUsernamePassword])
self.assertIsInstance(checker.checkers[IUsernamePassword],
checkers.SSHPublicKeyDatabase)
def test_requestAvatarId(self):
"""
L{SSHProtocolChecker.requestAvatarId} should defer to one if its
registered checkers to authenticate a user.
"""
checker = checkers.SSHProtocolChecker()
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser(b'test', b'test')
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword(b'test', b'test'))
def _callback(avatarId):
self.assertEqual(avatarId, b'test')
return d.addCallback(_callback)
def test_requestAvatarIdWithNotEnoughAuthentication(self):
"""
If the client indicates that it is never satisfied, by always returning
False from _areDone, then L{SSHProtocolChecker} should raise
L{NotEnoughAuthentication}.
"""
checker = checkers.SSHProtocolChecker()
def _areDone(avatarId):
return False
self.patch(checker, 'areDone', _areDone)
passwordDatabase = InMemoryUsernamePasswordDatabaseDontUse()
passwordDatabase.addUser(b'test', b'test')
checker.registerChecker(passwordDatabase)
d = checker.requestAvatarId(UsernamePassword(b'test', b'test'))
return self.assertFailure(d, NotEnoughAuthentication)
def test_requestAvatarIdInvalidCredential(self):
"""
If the passed credentials aren't handled by any registered checker,
L{SSHProtocolChecker} should raise L{UnhandledCredentials}.
"""
checker = checkers.SSHProtocolChecker()
d = checker.requestAvatarId(UsernamePassword(b'test', b'test'))
return self.assertFailure(d, UnhandledCredentials)
def test_areDone(self):
"""
The default L{SSHProcotolChecker.areDone} should simply return True.
"""
self.assertTrue(checkers.SSHProtocolChecker().areDone(None))
class UNIXPasswordDatabaseTests(TestCase):
"""
Tests for L{UNIXPasswordDatabase}.
"""
skip = cryptSkip or dependencySkip
def assertLoggedIn(self, d, username):
"""
Assert that the L{Deferred} passed in is called back with the value
'username'. This represents a valid login for this TestCase.
NOTE: To work, this method's return value must be returned from the
test method, or otherwise hooked up to the test machinery.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
@type d: L{Deferred}
@rtype: L{Deferred}
"""
result = []
d.addBoth(result.append)
self.assertEqual(len(result), 1, "login incomplete")
if isinstance(result[0], Failure):
result[0].raiseException()
self.assertEqual(result[0], username)
def test_defaultCheckers(self):
"""
L{UNIXPasswordDatabase} with no arguments has checks the C{pwd} database
and then the C{spwd} database.
"""
checker = checkers.UNIXPasswordDatabase()
def crypted(username, password):
salt = crypt.crypt(password, username)
crypted = crypt.crypt(password, '$1$' + salt)
return crypted
pwd = UserDatabase()
pwd.addUser('alice', crypted('alice', 'password'),
1, 2, 'foo', '/foo', '/bin/sh')
# x and * are convention for "look elsewhere for the password"
pwd.addUser('bob', 'x', 1, 2, 'bar', '/bar', '/bin/sh')
spwd = ShadowDatabase()
spwd.addUser('alice', 'wrong', 1, 2, 3, 4, 5, 6, 7)
spwd.addUser('bob', crypted('bob', 'password'),
8, 9, 10, 11, 12, 13, 14)
self.patch(checkers, 'pwd', pwd)
self.patch(checkers, 'spwd', spwd)
mockos = MockOS()
self.patch(util, 'os', mockos)
mockos.euid = 2345
mockos.egid = 1234
cred = UsernamePassword(b"alice", b"password")
self.assertLoggedIn(checker.requestAvatarId(cred), b'alice')
self.assertEqual(mockos.seteuidCalls, [])
self.assertEqual(mockos.setegidCalls, [])
cred.username = b"bob"
self.assertLoggedIn(checker.requestAvatarId(cred), b'bob')
self.assertEqual(mockos.seteuidCalls, [0, 2345])
self.assertEqual(mockos.setegidCalls, [0, 1234])
def assertUnauthorizedLogin(self, d):
"""
Asserts that the L{Deferred} passed in is erred back with an
L{UnauthorizedLogin} L{Failure}. This reprsents an invalid login for
this TestCase.
NOTE: To work, this method's return value must be returned from the
test method, or otherwise hooked up to the test machinery.
@param d: a L{Deferred} from an L{IChecker.requestAvatarId} method.
@type d: L{Deferred}
@rtype: L{None}
"""
self.assertRaises(
checkers.UnauthorizedLogin, self.assertLoggedIn, d, 'bogus value')
def test_passInCheckers(self):
"""
L{UNIXPasswordDatabase} takes a list of functions to check for UNIX
user information.
"""
password = crypt.crypt('secret', 'secret')
userdb = UserDatabase()
userdb.addUser('anybody', password, 1, 2, 'foo', '/bar', '/bin/sh')
checker = checkers.UNIXPasswordDatabase([userdb.getpwnam])
self.assertLoggedIn(
checker.requestAvatarId(UsernamePassword(b'anybody', b'secret')),
b'anybody')
def test_verifyPassword(self):
"""
If the encrypted password provided by the getpwnam function is valid
(verified by the L{verifyCryptedPassword} function), we callback the
C{requestAvatarId} L{Deferred} with the username.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword(b'username', b'username')
self.assertLoggedIn(checker.requestAvatarId(credential), b'username')
def test_failOnKeyError(self):
"""
If the getpwnam function raises a KeyError, the login fails with an
L{UnauthorizedLogin} exception.
"""
def getpwnam(username):
raise KeyError(username)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword(b'username', b'username')
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_failOnBadPassword(self):
"""
If the verifyCryptedPassword function doesn't verify the password, the
login fails with an L{UnauthorizedLogin} exception.
"""
def verifyCryptedPassword(crypted, pw):
return False
def getpwnam(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam])
credential = UsernamePassword(b'username', b'username')
self.assertUnauthorizedLogin(checker.requestAvatarId(credential))
def test_loopThroughFunctions(self):
"""
UNIXPasswordDatabase.requestAvatarId loops through each getpwnam
function associated with it and returns a L{Deferred} which fires with
the result of the first one which returns a value other than None.
ones do not verify the password.
"""
def verifyCryptedPassword(crypted, pw):
return crypted == pw
def getpwnam1(username):
return [username, 'not the password']
def getpwnam2(username):
return [username, username]
self.patch(checkers, 'verifyCryptedPassword', verifyCryptedPassword)
checker = checkers.UNIXPasswordDatabase([getpwnam1, getpwnam2])
credential = UsernamePassword(b'username', b'username')
self.assertLoggedIn(checker.requestAvatarId(credential), b'username')
def test_failOnSpecial(self):
"""
If the password returned by any function is C{""}, C{"x"}, or C{"*"} it
is not compared against the supplied password. Instead it is skipped.
"""
pwd = UserDatabase()
pwd.addUser('alice', '', 1, 2, '', 'foo', 'bar')
pwd.addUser('bob', 'x', 1, 2, '', 'foo', 'bar')
pwd.addUser('carol', '*', 1, 2, '', 'foo', 'bar')
self.patch(checkers, 'pwd', pwd)
checker = checkers.UNIXPasswordDatabase([checkers._pwdGetByName])
cred = UsernamePassword(b'alice', b'')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword(b'bob', b'x')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
cred = UsernamePassword(b'carol', b'*')
self.assertUnauthorizedLogin(checker.requestAvatarId(cred))
class AuthorizedKeyFileReaderTests(TestCase):
"""
Tests for L{checkers.readAuthorizedKeyFile}
"""
skip = dependencySkip
def test_ignoresComments(self):
"""
L{checkers.readAuthorizedKeyFile} does not attempt to turn comments
into keys
"""
fileobj = BytesIO(b'# this comment is ignored\n'
b'this is not\n'
b'# this is again\n'
b'and this is not')
result = checkers.readAuthorizedKeyFile(fileobj, lambda x: x)
self.assertEqual([b'this is not', b'and this is not'], list(result))
def test_ignoresLeadingWhitespaceAndEmptyLines(self):
"""
L{checkers.readAuthorizedKeyFile} ignores leading whitespace in
lines, as well as empty lines
"""
fileobj = BytesIO(b"""
# ignore
not ignored
""")
result = checkers.readAuthorizedKeyFile(fileobj, parseKey=lambda x: x)
self.assertEqual([b'not ignored'], list(result))
def test_ignoresUnparsableKeys(self):
"""
L{checkers.readAuthorizedKeyFile} does not raise an exception
when a key fails to parse (raises a
L{twisted.conch.ssh.keys.BadKeyError}), but rather just keeps going
"""
def failOnSome(line):
if line.startswith(b'f'):
raise keys.BadKeyError('failed to parse')
return line
fileobj = BytesIO(b'failed key\ngood key')
result = checkers.readAuthorizedKeyFile(fileobj,
parseKey=failOnSome)
self.assertEqual([b'good key'], list(result))
class InMemorySSHKeyDBTests(TestCase):
"""
Tests for L{checkers.InMemorySSHKeyDB}
"""
skip = dependencySkip
def test_implementsInterface(self):
"""
L{checkers.InMemorySSHKeyDB} implements
L{checkers.IAuthorizedKeysDB}
"""
keydb = checkers.InMemorySSHKeyDB({b'alice': [b'key']})
verifyObject(checkers.IAuthorizedKeysDB, keydb)
def test_noKeysForUnauthorizedUser(self):
"""
If the user is not in the mapping provided to
L{checkers.InMemorySSHKeyDB}, an empty iterator is returned
by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
"""
keydb = checkers.InMemorySSHKeyDB({b'alice': [b'keys']})
self.assertEqual([], list(keydb.getAuthorizedKeys(b'bob')))
def test_allKeysForAuthorizedUser(self):
"""
If the user is in the mapping provided to
L{checkers.InMemorySSHKeyDB}, an iterator with all the keys
is returned by L{checkers.InMemorySSHKeyDB.getAuthorizedKeys}
"""
keydb = checkers.InMemorySSHKeyDB({b'alice': [b'a', b'b']})
self.assertEqual([b'a', b'b'], list(keydb.getAuthorizedKeys(b'alice')))
class UNIXAuthorizedKeysFilesTests(TestCase):
"""
Tests for L{checkers.UNIXAuthorizedKeysFiles}.
"""
skip = dependencySkip
def setUp(self):
mockos = MockOS()
mockos.path = FilePath(self.mktemp())
mockos.path.makedirs()
self.userdb = UserDatabase()
self.userdb.addUser(b'alice', b'password', 1, 2, b'alice lastname',
mockos.path.path, b'/bin/shell')
self.sshDir = mockos.path.child('.ssh')
self.sshDir.makedirs()
authorizedKeys = self.sshDir.child('authorized_keys')
authorizedKeys.setContent(b'key 1\nkey 2')
self.expectedKeys = [b'key 1', b'key 2']
def test_implementsInterface(self):
"""
L{checkers.UNIXAuthorizedKeysFiles} implements
L{checkers.IAuthorizedKeysDB}.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb)
verifyObject(checkers.IAuthorizedKeysDB, keydb)
def test_noKeysForUnauthorizedUser(self):
"""
If the user is not in the user database provided to
L{checkers.UNIXAuthorizedKeysFiles}, an empty iterator is returned
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual([], list(keydb.getAuthorizedKeys('bob')))
def test_allKeysInAllAuthorizedFilesForAuthorizedUser(self):
"""
If the user is in the user database provided to
L{checkers.UNIXAuthorizedKeysFiles}, an iterator with all the keys in
C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2} is returned
by L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys}.
"""
self.sshDir.child('authorized_keys2').setContent(b'key 3')
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual(self.expectedKeys + [b'key 3'],
list(keydb.getAuthorizedKeys(b'alice')))
def test_ignoresNonexistantFile(self):
"""
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
if they exist.
"""
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual(self.expectedKeys,
list(keydb.getAuthorizedKeys(b'alice')))
def test_ignoresUnreadableFile(self):
"""
L{checkers.UNIXAuthorizedKeysFiles.getAuthorizedKeys} returns only
the keys in C{~/.ssh/authorized_keys} and C{~/.ssh/authorized_keys2}
if they are readable.
"""
self.sshDir.child('authorized_keys2').makedirs()
keydb = checkers.UNIXAuthorizedKeysFiles(self.userdb,
parseKey=lambda x: x)
self.assertEqual(self.expectedKeys,
list(keydb.getAuthorizedKeys(b'alice')))
_KeyDB = namedtuple('KeyDB', ['getAuthorizedKeys'])
class _DummyException(Exception):
"""
Fake exception to be used for testing.
"""
pass
class SSHPublicKeyCheckerTests(TestCase):
"""
Tests for L{checkers.SSHPublicKeyChecker}.
"""
skip = dependencySkip
def setUp(self):
self.credentials = SSHPrivateKey(
b'alice', b'ssh-rsa', keydata.publicRSA_openssh, b'foo',
keys.Key.fromString(keydata.privateRSA_openssh).sign(b'foo'))
self.keydb = _KeyDB(lambda _: [
keys.Key.fromString(keydata.publicRSA_openssh)])
self.checker = checkers.SSHPublicKeyChecker(self.keydb)
def test_credentialsWithoutSignature(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that do not have a signature fails with L{ValidPublicKey}.
"""
self.credentials.signature = None
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
ValidPublicKey)
def test_credentialsWithBadKey(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that have a bad key fails with L{keys.BadKeyError}.
"""
self.credentials.blob = b''
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
keys.BadKeyError)
def test_credentialsNoMatchingKey(self):
"""
If L{checkers.IAuthorizedKeysDB.getAuthorizedKeys} returns no keys
that match the credentials,
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
L{UnauthorizedLogin}.
"""
self.credentials.blob = keydata.publicDSA_openssh
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
UnauthorizedLogin)
def test_credentialsInvalidSignature(self):
"""
Calling L{checkers.SSHPublicKeyChecker.requestAvatarId} with
credentials that are incorrectly signed fails with
L{UnauthorizedLogin}.
"""
self.credentials.signature = (
keys.Key.fromString(keydata.privateDSA_openssh).sign(b'foo'))
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
UnauthorizedLogin)
def test_failureVerifyingKey(self):
"""
If L{keys.Key.verify} raises an exception,
L{checkers.SSHPublicKeyChecker.requestAvatarId} fails with
L{UnauthorizedLogin}.
"""
def fail(*args, **kwargs):
raise _DummyException()
self.patch(keys.Key, 'verify', fail)
self.failureResultOf(self.checker.requestAvatarId(self.credentials),
UnauthorizedLogin)
self.flushLoggedErrors(_DummyException)
def test_usernameReturnedOnSuccess(self):
"""
L{checker.SSHPublicKeyChecker.requestAvatarId}, if successful,
callbacks with the username.
"""
d = self.checker.requestAvatarId(self.credentials)
self.assertEqual(b'alice', self.successResultOf(d))

View file

@ -0,0 +1,625 @@
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
"""
Tests for L{twisted.conch.scripts.ckeygen}.
"""
import getpass
import sys
import subprocess
from io import BytesIO, StringIO
from twisted.python.compat import unicode, _PY3
from twisted.python.reflect import requireModule
if requireModule('cryptography') and requireModule('pyasn1'):
from twisted.conch.ssh.keys import (Key, BadKeyError,
BadFingerPrintFormat, FingerprintFormats)
from twisted.conch.scripts.ckeygen import (
changePassPhrase, displayPublicKey, printFingerprint,
_saveKey, enumrepresentation)
else:
skip = "cryptography and pyasn1 required for twisted.conch.scripts.ckeygen"
from twisted.python.filepath import FilePath
from twisted.trial.unittest import TestCase
from twisted.conch.test.keydata import (
publicRSA_openssh, privateRSA_openssh, privateRSA_openssh_encrypted, privateECDSA_openssh)
def makeGetpass(*passphrases):
"""
Return a callable to patch C{getpass.getpass}. Yields a passphrase each
time called. Use case is to provide an old, then new passphrase(s) as if
requested interactively.
@param passphrases: The list of passphrases returned, one per each call.
@return: A callable to patch C{getpass.getpass}.
"""
passphrases = iter(passphrases)
def fakeGetpass(_):
return next(passphrases)
return fakeGetpass
class KeyGenTests(TestCase):
"""
Tests for various functions used to implement the I{ckeygen} script.
"""
def setUp(self):
"""
Patch C{sys.stdout} so tests can make assertions about what's printed.
"""
if _PY3:
self.stdout = StringIO()
else:
self.stdout = BytesIO()
self.patch(sys, 'stdout', self.stdout)
def _testrun(self, keyType, keySize=None, privateKeySubtype=None):
filename = self.mktemp()
args = ['ckeygen', '-t', keyType, '-f', filename, '--no-passphrase']
if keySize is not None:
args.extend(['-b', keySize])
if privateKeySubtype is not None:
args.extend(['--private-key-subtype', privateKeySubtype])
subprocess.call(args)
privKey = Key.fromFile(filename)
pubKey = Key.fromFile(filename + '.pub')
if keyType == 'ecdsa':
self.assertEqual(privKey.type(), 'EC')
else:
self.assertEqual(privKey.type(), keyType.upper())
self.assertTrue(pubKey.isPublic())
def test_keygeneration(self):
self._testrun('ecdsa', '384')
self._testrun('ecdsa', '384', privateKeySubtype='v1')
self._testrun('ecdsa')
self._testrun('ecdsa', privateKeySubtype='v1')
self._testrun('dsa', '2048')
self._testrun('dsa', '2048', privateKeySubtype='v1')
self._testrun('dsa')
self._testrun('dsa', privateKeySubtype='v1')
self._testrun('rsa', '2048')
self._testrun('rsa', '2048', privateKeySubtype='v1')
self._testrun('rsa')
self._testrun('rsa', privateKeySubtype='v1')
def test_runBadKeytype(self):
filename = self.mktemp()
with self.assertRaises(subprocess.CalledProcessError):
subprocess.check_call(['ckeygen', '-t', 'foo', '-f', filename])
def test_enumrepresentation(self):
"""
L{enumrepresentation} takes a dictionary as input and returns a
dictionary with its attributes changed to enum representation.
"""
options = enumrepresentation({'format': 'md5-hex'})
self.assertIs(options['format'],
FingerprintFormats.MD5_HEX)
def test_enumrepresentationsha256(self):
"""
Test for format L{FingerprintFormats.SHA256-BASE64}.
"""
options = enumrepresentation({'format': 'sha256-base64'})
self.assertIs(options['format'],
FingerprintFormats.SHA256_BASE64)
def test_enumrepresentationBadFormat(self):
"""
Test for unsupported fingerprint format
"""
with self.assertRaises(BadFingerPrintFormat) as em:
enumrepresentation({'format': 'sha-base64'})
self.assertEqual('Unsupported fingerprint format: sha-base64',
em.exception.args[0])
def test_printFingerprint(self):
"""
L{printFingerprint} writes a line to standard out giving the number of
bits of the key, its fingerprint, and the basename of the file from it
was read.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
printFingerprint({'filename': filename,
'format': 'md5-hex'})
self.assertEqual(
self.stdout.getvalue(),
'2048 85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da temp\n')
def test_printFingerprintsha256(self):
"""
L{printFigerprint} will print key fingerprint in
L{FingerprintFormats.SHA256-BASE64} format if explicitly specified.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
printFingerprint({'filename': filename,
'format': 'sha256-base64'})
self.assertEqual(
self.stdout.getvalue(),
'2048 FBTCOoknq0mHy+kpfnY9tDdcAJuWtCpuQMaV3EsvbUI= temp\n')
def test_printFingerprintBadFingerPrintFormat(self):
"""
L{printFigerprint} raises C{keys.BadFingerprintFormat} when unsupported
formats are requested.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
with self.assertRaises(BadFingerPrintFormat) as em:
printFingerprint({'filename': filename, 'format':'sha-base64'})
self.assertEqual('Unsupported fingerprint format: sha-base64',
em.exception.args[0])
def test_saveKey(self):
"""
L{_saveKey} writes the private and public parts of a key to two
different files and writes a report of this to standard out.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(key, {'filename': filename, 'pass': 'passphrase',
'format': 'md5-hex'})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
"85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da\n" % (
filename,
filename))
self.assertEqual(
key.fromString(
base.child('id_rsa').getContent(), None, 'passphrase'),
key)
self.assertEqual(
Key.fromString(base.child('id_rsa.pub').getContent()),
key.public())
def test_saveKeyECDSA(self):
"""
L{_saveKey} writes the private and public parts of a key to two
different files and writes a report of this to standard out.
Test with ECDSA key.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_ecdsa').path
key = Key.fromString(privateECDSA_openssh)
_saveKey(key, {'filename': filename, 'pass': 'passphrase',
'format': 'md5-hex'})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
"1e:ab:83:a6:f2:04:22:99:7c:64:14:d2:ab:fa:f5:16\n" % (
filename,
filename))
self.assertEqual(
key.fromString(
base.child('id_ecdsa').getContent(), None, 'passphrase'),
key)
self.assertEqual(
Key.fromString(base.child('id_ecdsa.pub').getContent()),
key.public())
def test_saveKeysha256(self):
"""
L{_saveKey} will generate key fingerprint in
L{FingerprintFormats.SHA256-BASE64} format if explicitly specified.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(key, {'filename': filename, 'pass': 'passphrase',
'format': 'sha256-base64'})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=SHA256_BASE64> is:\n"
"FBTCOoknq0mHy+kpfnY9tDdcAJuWtCpuQMaV3EsvbUI=\n" % (
filename,
filename))
self.assertEqual(
key.fromString(
base.child('id_rsa').getContent(), None, 'passphrase'),
key)
self.assertEqual(
Key.fromString(base.child('id_rsa.pub').getContent()),
key.public())
def test_saveKeyBadFingerPrintformat(self):
"""
L{_saveKey} raises C{keys.BadFingerprintFormat} when unsupported
formats are requested.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
with self.assertRaises(BadFingerPrintFormat) as em:
_saveKey(key, {'filename': filename, 'pass': 'passphrase',
'format': 'sha-base64'})
self.assertEqual('Unsupported fingerprint format: sha-base64',
em.exception.args[0])
def test_saveKeyEmptyPassphrase(self):
"""
L{_saveKey} will choose an empty string for the passphrase if
no-passphrase is C{True}.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(key, {'filename': filename, 'no-passphrase': True,
'format': 'md5-hex'})
self.assertEqual(
key.fromString(
base.child('id_rsa').getContent(), None, b''),
key)
def test_saveKeyECDSAEmptyPassphrase(self):
"""
L{_saveKey} will choose an empty string for the passphrase if
no-passphrase is C{True}.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_ecdsa').path
key = Key.fromString(privateECDSA_openssh)
_saveKey(key, {'filename': filename, 'no-passphrase': True,
'format': 'md5-hex'})
self.assertEqual(
key.fromString(
base.child('id_ecdsa').getContent(), None),
key)
def test_saveKeyNoFilename(self):
"""
When no path is specified, it will ask for the path used to store the
key.
"""
base = FilePath(self.mktemp())
base.makedirs()
keyPath = base.child('custom_key').path
import twisted.conch.scripts.ckeygen
self.patch(twisted.conch.scripts.ckeygen, 'raw_input', lambda _: keyPath)
key = Key.fromString(privateRSA_openssh)
_saveKey(key, {'filename': None, 'no-passphrase': True,
'format': 'md5-hex'})
persistedKeyContent = base.child('custom_key').getContent()
persistedKey = key.fromString(persistedKeyContent, None, b'')
self.assertEqual(key, persistedKey)
def test_saveKeySubtypeV1(self):
"""
L{_saveKey} can be told to write the new private key file in OpenSSH
v1 format.
"""
base = FilePath(self.mktemp())
base.makedirs()
filename = base.child('id_rsa').path
key = Key.fromString(privateRSA_openssh)
_saveKey(key, {
'filename': filename, 'pass': 'passphrase',
'format': 'md5-hex', 'private-key-subtype': 'v1',
})
self.assertEqual(
self.stdout.getvalue(),
"Your identification has been saved in %s\n"
"Your public key has been saved in %s.pub\n"
"The key fingerprint in <FingerprintFormats=MD5_HEX> is:\n"
"85:25:04:32:58:55:96:9f:57:ee:fb:a8:1a:ea:69:da\n" % (
filename,
filename))
privateKeyContent = base.child('id_rsa').getContent()
self.assertEqual(
key.fromString(privateKeyContent, None, 'passphrase'), key)
self.assertTrue(privateKeyContent.startswith(
b'-----BEGIN OPENSSH PRIVATE KEY-----\n'))
self.assertEqual(
Key.fromString(base.child('id_rsa.pub').getContent()),
key.public())
def test_displayPublicKey(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh)
displayPublicKey({'filename': filename})
displayed = self.stdout.getvalue().strip('\n')
if isinstance(displayed, unicode):
displayed = displayed.encode("ascii")
self.assertEqual(
displayed,
pubKey.toString('openssh'))
def test_displayPublicKeyEncrypted(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key using the given passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
displayPublicKey({'filename': filename, 'pass': 'encrypted'})
displayed = self.stdout.getvalue().strip('\n')
if isinstance(displayed, unicode):
displayed = displayed.encode("ascii")
self.assertEqual(
displayed,
pubKey.toString('openssh'))
def test_displayPublicKeyEncryptedPassphrasePrompt(self):
"""
L{displayPublicKey} prints out the public key associated with a given
private key, asking for the passphrase when it's encrypted.
"""
filename = self.mktemp()
pubKey = Key.fromString(publicRSA_openssh)
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.patch(getpass, 'getpass', lambda x: 'encrypted')
displayPublicKey({'filename': filename})
displayed = self.stdout.getvalue().strip('\n')
if isinstance(displayed, unicode):
displayed = displayed.encode("ascii")
self.assertEqual(
displayed,
pubKey.toString('openssh'))
def test_displayPublicKeyWrongPassphrase(self):
"""
L{displayPublicKey} fails with a L{BadKeyError} when trying to decrypt
an encrypted key with the wrong password.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
self.assertRaises(
BadKeyError, displayPublicKey,
{'filename': filename, 'pass': 'wrong'})
def test_changePassphrase(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key interactively.
"""
oldNewConfirm = makeGetpass('encrypted', 'newpass', 'newpass')
self.patch(getpass, 'getpass', oldNewConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({'filename': filename})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWithOld(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a
private key, providing the old passphrase and prompting for new one.
"""
newConfirm = makeGetpass('newpass', 'newpass')
self.patch(getpass, 'getpass', newConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({'filename': filename, 'pass': 'encrypted'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWithBoth(self):
"""
L{changePassPhrase} allows a user to change the passphrase of a private
key by providing both old and new passphrases without prompting.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase(
{'filename': filename, 'pass': 'encrypted',
'newpass': 'newencrypt'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
self.assertNotEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseWrongPassphrase(self):
"""
L{changePassPhrase} exits if passed an invalid old passphrase when
trying to change the passphrase of a private key.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'pass': 'wrong'})
self.assertEqual('Could not change passphrase: old passphrase error',
str(error))
self.assertEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseEmptyGetPass(self):
"""
L{changePassPhrase} exits if no passphrase is specified for the
C{getpass} call and the key is encrypted.
"""
self.patch(getpass, 'getpass', makeGetpass(''))
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
error = self.assertRaises(
SystemExit, changePassPhrase, {'filename': filename})
self.assertEqual(
'Could not change passphrase: Passphrase must be provided '
'for an encrypted key',
str(error))
self.assertEqual(privateRSA_openssh_encrypted,
FilePath(filename).getContent())
def test_changePassphraseBadKey(self):
"""
L{changePassPhrase} exits if the file specified points to an invalid
key.
"""
filename = self.mktemp()
FilePath(filename).setContent(b'foobar')
error = self.assertRaises(
SystemExit, changePassPhrase, {'filename': filename})
if _PY3:
expected = "Could not change passphrase: cannot guess the type of b'foobar'"
else:
expected = "Could not change passphrase: cannot guess the type of 'foobar'"
self.assertEqual(expected, str(error))
self.assertEqual(b'foobar', FilePath(filename).getContent())
def test_changePassphraseCreateError(self):
"""
L{changePassPhrase} doesn't modify the key file if an unexpected error
happens when trying to create the key with the new passphrase.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args, **kwargs):
raise RuntimeError('oops')
self.patch(Key, 'toString', toString)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename,
'newpass': 'newencrypt'})
self.assertEqual(
'Could not change passphrase: oops', str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphraseEmptyStringError(self):
"""
L{changePassPhrase} doesn't modify the key file if C{toString} returns
an empty string.
"""
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh)
def toString(*args, **kwargs):
return ''
self.patch(Key, 'toString', toString)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'newpass': 'newencrypt'})
if _PY3:
expected = (
"Could not change passphrase: cannot guess the type of b''")
else:
expected = (
"Could not change passphrase: cannot guess the type of ''")
self.assertEqual(expected, str(error))
self.assertEqual(privateRSA_openssh, FilePath(filename).getContent())
def test_changePassphrasePublicKey(self):
"""
L{changePassPhrase} exits when trying to change the passphrase on a
public key, and doesn't change the file.
"""
filename = self.mktemp()
FilePath(filename).setContent(publicRSA_openssh)
error = self.assertRaises(
SystemExit, changePassPhrase,
{'filename': filename, 'newpass': 'pass'})
self.assertEqual(
'Could not change passphrase: key not encrypted', str(error))
self.assertEqual(publicRSA_openssh, FilePath(filename).getContent())
def test_changePassphraseSubtypeV1(self):
"""
L{changePassPhrase} can be told to write the new private key file in
OpenSSH v1 format.
"""
oldNewConfirm = makeGetpass('encrypted', 'newpass', 'newpass')
self.patch(getpass, 'getpass', oldNewConfirm)
filename = self.mktemp()
FilePath(filename).setContent(privateRSA_openssh_encrypted)
changePassPhrase({'filename': filename, 'private-key-subtype': 'v1'})
self.assertEqual(
self.stdout.getvalue().strip('\n'),
'Your identification has been saved with the new passphrase.')
privateKeyContent = FilePath(filename).getContent()
self.assertNotEqual(privateRSA_openssh_encrypted, privateKeyContent)
self.assertTrue(privateKeyContent.startswith(
b'-----BEGIN OPENSSH PRIVATE KEY-----\n'))

View file

@ -0,0 +1,832 @@
# -*- test-case-name: twisted.conch.test.test_conch -*-
# Copyright (c) Twisted Matrix Laboratories.
# See LICENSE for details.
import os, sys, socket
import subprocess
from itertools import count
from zope.interface import implementer
from twisted.python.reflect import requireModule
from twisted.conch.error import ConchError
from twisted.cred import portal
from twisted.internet import reactor, defer, protocol
from twisted.internet.error import ProcessExitedAlready
from twisted.internet.task import LoopingCall
from twisted.internet.utils import getProcessValue
from twisted.python import filepath, log, runtime
from twisted.python.compat import unicode, _PYPY
from twisted.trial import unittest
from twisted.conch.test.test_ssh import ConchTestRealm
from twisted.python.procutils import which
from twisted.conch.test.keydata import publicRSA_openssh, privateRSA_openssh
from twisted.conch.test.keydata import publicDSA_openssh, privateDSA_openssh
from twisted.python.filepath import FilePath
from twisted.trial.unittest import SkipTest
try:
from twisted.conch.test.test_ssh import ConchTestServerFactory, \
conchTestPublicKeyChecker
except ImportError:
pass
try:
import pyasn1
except ImportError:
pyasn1 = None
cryptography = requireModule("cryptography")
if cryptography:
from twisted.conch.avatar import ConchUser
from twisted.conch.ssh.session import ISession, SSHSession, wrapProtocol
else:
from twisted.conch.interfaces import ISession
class ConchUser:
pass
try:
from twisted.conch.scripts.conch import (
SSHSession as StdioInteractingSession
)
except ImportError as e:
StdioInteractingSession = None
_reason = str(e)
del e
def _has_ipv6():
""" Returns True if the system can bind an IPv6 address."""
sock = None
has_ipv6 = False
try:
sock = socket.socket(socket.AF_INET6)
sock.bind(('::1', 0))
has_ipv6 = True
except socket.error:
pass
if sock:
sock.close()
return has_ipv6
HAS_IPV6 = _has_ipv6()
class FakeStdio(object):
"""
A fake for testing L{twisted.conch.scripts.conch.SSHSession.eofReceived} and
L{twisted.conch.scripts.cftp.SSHSession.eofReceived}.
@ivar writeConnLost: A flag which records whether L{loserWriteConnection}
has been called.
"""
writeConnLost = False
def loseWriteConnection(self):
"""
Record the call to loseWriteConnection.
"""
self.writeConnLost = True
class StdioInteractingSessionTests(unittest.TestCase):
"""
Tests for L{twisted.conch.scripts.conch.SSHSession}.
"""
if StdioInteractingSession is None:
skip = _reason
def test_eofReceived(self):
"""
L{twisted.conch.scripts.conch.SSHSession.eofReceived} loses the
write half of its stdio connection.
"""
stdio = FakeStdio()
channel = StdioInteractingSession()
channel.stdio = stdio
channel.eofReceived()
self.assertTrue(stdio.writeConnLost)
class Echo(protocol.Protocol):
def connectionMade(self):
log.msg('ECHO CONNECTION MADE')
def connectionLost(self, reason):
log.msg('ECHO CONNECTION DONE')
def dataReceived(self, data):
self.transport.write(data)
if b'\n' in data:
self.transport.loseConnection()
class EchoFactory(protocol.Factory):
protocol = Echo
class ConchTestOpenSSHProcess(protocol.ProcessProtocol):
"""
Test protocol for launching an OpenSSH client process.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
buf = b''
problems = b''
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def outReceived(self, data):
self.buf += data
def errReceived(self, data):
self.problems += data
def processEnded(self, reason):
"""
Called when the process has ended.
@param reason: a Failure giving the reason for the process' end.
"""
if reason.value.exitCode != 0:
self._getDeferred().errback(
ConchError(
"exit code was not 0: {} ({})".format(
reason.value.exitCode,
self.problems.decode("charmap"),
)
)
)
else:
buf = self.buf.replace(b'\r\n', b'\n')
self._getDeferred().callback(buf)
class ConchTestForwardingProcess(protocol.ProcessProtocol):
"""
Manages a third-party process which launches a server.
Uses L{ConchTestForwardingPort} to connect to the third-party server.
Once L{ConchTestForwardingPort} has disconnected, kill the process and fire
a Deferred with the data received by the L{ConchTestForwardingPort}.
@ivar deferred: Set by whatever uses this object. Accessed using
L{_getDeferred}, which destroys the value so the Deferred is not
fired twice. Fires when the process is terminated.
"""
deferred = None
def __init__(self, port, data):
"""
@type port: L{int}
@param port: The port on which the third-party server is listening.
(it is assumed that the server is running on localhost).
@type data: L{str}
@param data: This is sent to the third-party server. Must end with '\n'
in order to trigger a disconnect.
"""
self.port = port
self.buffer = None
self.data = data
def _getDeferred(self):
d, self.deferred = self.deferred, None
return d
def connectionMade(self):
self._connect()
def _connect(self):
"""
Connect to the server, which is often a third-party process.
Tries to reconnect if it fails because we have no way of determining
exactly when the port becomes available for listening -- we can only
know when the process starts.
"""
cc = protocol.ClientCreator(reactor, ConchTestForwardingPort, self,
self.data)
d = cc.connectTCP('127.0.0.1', self.port)
d.addErrback(self._ebConnect)
return d
def _ebConnect(self, f):
reactor.callLater(.1, self._connect)
def forwardingPortDisconnected(self, buffer):
"""
The network connection has died; save the buffer of output
from the network and attempt to quit the process gracefully,
and then (after the reactor has spun) send it a KILL signal.
"""
self.buffer = buffer
self.transport.write(b'\x03')
self.transport.loseConnection()
reactor.callLater(0, self._reallyDie)
def _reallyDie(self):
try:
self.transport.signalProcess('KILL')
except ProcessExitedAlready:
pass
def processEnded(self, reason):
"""
Fire the Deferred at self.deferred with the data collected
from the L{ConchTestForwardingPort} connection, if any.
"""
self._getDeferred().callback(self.buffer)
class ConchTestForwardingPort(protocol.Protocol):
"""
Connects to server launched by a third-party process (managed by
L{ConchTestForwardingProcess}) sends data, then reports whatever it
received back to the L{ConchTestForwardingProcess} once the connection
is ended.
"""
def __init__(self, protocol, data):
"""
@type protocol: L{ConchTestForwardingProcess}
@param protocol: The L{ProcessProtocol} which made this connection.
@type data: str
@param data: The data to be sent to the third-party server.
"""
self.protocol = protocol
self.data = data
def connectionMade(self):
self.buffer = b''
self.transport.write(self.data)
def dataReceived(self, data):
self.buffer += data
def connectionLost(self, reason):
self.protocol.forwardingPortDisconnected(self.buffer)
def _makeArgs(args, mod="conch"):
start = [sys.executable, '-c'
"""
### Twisted Preamble
import sys, os
path = os.path.abspath(sys.argv[0])
while os.path.dirname(path) != path:
if os.path.basename(path).startswith('Twisted'):
sys.path.insert(0, path)
break
path = os.path.dirname(path)
from twisted.conch.scripts.%s import run
run()""" % mod]
madeArgs = []
for arg in start + list(args):
if isinstance(arg, unicode):
arg = arg.encode("utf-8")
madeArgs.append(arg)
return madeArgs
class ConchServerSetupMixin:
if not cryptography:
skip = "can't run without cryptography"
if not pyasn1:
skip = "Cannot run without PyASN1"
# FIXME: https://twistedmatrix.com/trac/ticket/8506
# This should be un-skipped on Travis after the ticket is fixed. For now
# is enabled so that we can continue with fixing other stuff using Travis.
if _PYPY:
skip = 'PyPy known_host not working yet on Travis.'
realmFactory = staticmethod(lambda: ConchTestRealm(b'testuser'))
def _createFiles(self):
for f in ['rsa_test','rsa_test.pub','dsa_test','dsa_test.pub',
'kh_test']:
if os.path.exists(f):
os.remove(f)
with open('rsa_test', 'wb') as f:
f.write(privateRSA_openssh)
with open('rsa_test.pub', 'wb') as f:
f.write(publicRSA_openssh)
with open('dsa_test.pub', 'wb') as f:
f.write(publicDSA_openssh)
with open('dsa_test', 'wb') as f:
f.write(privateDSA_openssh)
os.chmod('dsa_test', 0o600)
os.chmod('rsa_test', 0o600)
permissions = FilePath('dsa_test').getPermissions()
if permissions.group.read or permissions.other.read:
raise SkipTest(
"private key readable by others despite chmod;"
" possible windows permission issue?"
" see https://tm.tl/9767"
)
with open('kh_test', 'wb') as f:
f.write(b'127.0.0.1 '+publicRSA_openssh)
def _getFreePort(self):
s = socket.socket()
s.bind(('', 0))
port = s.getsockname()[1]
s.close()
return port
def _makeConchFactory(self):
"""
Make a L{ConchTestServerFactory}, which allows us to start a
L{ConchTestServer} -- i.e. an actually listening conch.
"""
realm = self.realmFactory()
p = portal.Portal(realm)
p.registerChecker(conchTestPublicKeyChecker())
factory = ConchTestServerFactory()
factory.portal = p
return factory
def setUp(self):
self._createFiles()
self.conchFactory = self._makeConchFactory()
self.conchFactory.expectedLoseConnection = 1
self.conchServer = reactor.listenTCP(0, self.conchFactory,
interface="127.0.0.1")
self.echoServer = reactor.listenTCP(0, EchoFactory())
self.echoPort = self.echoServer.getHost().port
if HAS_IPV6:
self.echoServerV6 = reactor.listenTCP(0, EchoFactory(), interface="::1")
self.echoPortV6 = self.echoServerV6.getHost().port
def tearDown(self):
try:
self.conchFactory.proto.done = 1
except AttributeError:
pass
else:
self.conchFactory.proto.transport.loseConnection()
deferreds = [
defer.maybeDeferred(self.conchServer.stopListening),
defer.maybeDeferred(self.echoServer.stopListening),
]
if HAS_IPV6:
deferreds.append(defer.maybeDeferred(self.echoServerV6.stopListening))
return defer.gatherResults(deferreds)
class ForwardingMixin(ConchServerSetupMixin):
"""
Template class for tests of the Conch server's ability to forward arbitrary
protocols over SSH.
These tests are integration tests, not unit tests. They launch a Conch
server, a custom TCP server (just an L{EchoProtocol}) and then call
L{execute}.
L{execute} is implemented by subclasses of L{ForwardingMixin}. It should
cause an SSH client to connect to the Conch server, asking it to forward
data to the custom TCP server.
"""
def test_exec(self):
"""
Test that we can use whatever client to send the command "echo goodbye"
to the Conch server. Make sure we receive "goodbye" back from the
server.
"""
d = self.execute('echo goodbye', ConchTestOpenSSHProcess())
return d.addCallback(self.assertEqual, b'goodbye\n')
def test_localToRemoteForwarding(self):
"""
Test that we can use whatever client to forward a local port to a
specified port on the server.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, b'test\n')
d = self.execute('', process,
sshArgs='-N -L%i:127.0.0.1:%i'
% (localPort, self.echoPort))
d.addCallback(self.assertEqual, b'test\n')
return d
def test_remoteToLocalForwarding(self):
"""
Test that we can use whatever client to forward a port from the server
to a port locally.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, b'test\n')
d = self.execute('', process,
sshArgs='-N -R %i:127.0.0.1:%i'
% (localPort, self.echoPort))
d.addCallback(self.assertEqual, b'test\n')
return d
# Conventionally there is a separate adapter object which provides ISession for
# the user, but making the user provide ISession directly works too. This isn't
# a full implementation of ISession though, just enough to make these tests
# pass.
@implementer(ISession)
class RekeyAvatar(ConchUser):
"""
This avatar implements a shell which sends 60 numbered lines to whatever
connects to it, then closes the session with a 0 exit status.
60 lines is selected as being enough to send more than 2kB of traffic, the
amount the client is configured to initiate a rekey after.
"""
def __init__(self):
ConchUser.__init__(self)
self.channelLookup[b'session'] = SSHSession
def openShell(self, transport):
"""
Write 60 lines of data to the transport, then exit.
"""
proto = protocol.Protocol()
proto.makeConnection(transport)
transport.makeConnection(wrapProtocol(proto))
# Send enough bytes to the connection so that a rekey is triggered in
# the client.
def write(counter):
i = next(counter)
if i == 60:
call.stop()
transport.session.conn.sendRequest(
transport.session, b'exit-status', b'\x00\x00\x00\x00')
transport.loseConnection()
else:
line = "line #%02d\n" % (i,)
line = line.encode("utf-8")
transport.write(line)
# The timing for this loop is an educated guess (and/or the result of
# experimentation) to exercise the case where a packet is generated
# mid-rekey. Since the other side of the connection is (so far) the
# OpenSSH command line client, there's no easy way to determine when the
# rekey has been initiated. If there were, then generating a packet
# immediately at that time would be a better way to test the
# functionality being tested here.
call = LoopingCall(write, count())
call.start(0.01)
def closed(self):
"""
Ignore the close of the session.
"""
class RekeyRealm:
"""
This realm gives out new L{RekeyAvatar} instances for any avatar request.
"""
def requestAvatar(self, avatarID, mind, *interfaces):
return interfaces[0], RekeyAvatar(), lambda: None
class RekeyTestsMixin(ConchServerSetupMixin):
"""
TestCase mixin which defines tests exercising L{SSHTransportBase}'s handling
of rekeying messages.
"""
realmFactory = RekeyRealm
def test_clientRekey(self):
"""
After a client-initiated rekey is completed, application data continues
to be passed over the SSH connection.
"""
process = ConchTestOpenSSHProcess()
d = self.execute("", process, '-o RekeyLimit=2K')
def finished(result):
expectedResult = '\n'.join(['line #%02d' % (i,) for i in range(60)]) + '\n'
expectedResult = expectedResult.encode("utf-8")
self.assertEqual(result, expectedResult)
d.addCallback(finished)
return d
class OpenSSHClientMixin:
if not which('ssh'):
skip = "no ssh command-line client available"
def execute(self, remoteCommand, process, sshArgs=''):
"""
Connects to the SSH server started in L{ConchServerSetupMixin.setUp} by
running the 'ssh' command line tool.
@type remoteCommand: str
@param remoteCommand: The command (with arguments) to run on the
remote end.
@type process: L{ConchTestOpenSSHProcess}
@type sshArgs: str
@param sshArgs: Arguments to pass to the 'ssh' process.
@return: L{defer.Deferred}
"""
# PubkeyAcceptedKeyTypes does not exist prior to OpenSSH 7.0 so we
# first need to check if we can set it. If we can, -V will just print
# the version without doing anything else; if we can't, we will get a
# configuration error.
d = getProcessValue(
which('ssh')[0], ('-o', 'PubkeyAcceptedKeyTypes=ssh-dss', '-V'))
def hasPAKT(status):
if status == 0:
opts = '-oPubkeyAcceptedKeyTypes=ssh-dss '
else:
opts = ''
process.deferred = defer.Deferred()
# Pass -F /dev/null to avoid the user's configuration file from
# being loaded, as it may contain settings that cause our tests to
# fail or hang.
cmdline = ('ssh -2 -l testuser -p %i '
'-F /dev/null '
'-oUserKnownHostsFile=kh_test '
'-oPasswordAuthentication=no '
# Always use the RSA key, since that's the one in kh_test.
'-oHostKeyAlgorithms=ssh-rsa '
'-a '
'-i dsa_test ') + opts + sshArgs + \
' 127.0.0.1 ' + remoteCommand
port = self.conchServer.getHost().port
cmds = (cmdline % port).split()
encodedCmds = []
for cmd in cmds:
if isinstance(cmd, unicode):
cmd = cmd.encode("utf-8")
encodedCmds.append(cmd)
reactor.spawnProcess(process, which('ssh')[0], encodedCmds)
return process.deferred
return d.addCallback(hasPAKT)
class OpenSSHKeyExchangeTests(ConchServerSetupMixin, OpenSSHClientMixin,
unittest.TestCase):
"""
Tests L{SSHTransportBase}'s key exchange algorithm compatibility with
OpenSSH.
"""
def assertExecuteWithKexAlgorithm(self, keyExchangeAlgo):
"""
Call execute() method of L{OpenSSHClientMixin} with an ssh option that
forces the exclusive use of the key exchange algorithm specified by
keyExchangeAlgo
@type keyExchangeAlgo: L{str}
@param keyExchangeAlgo: The key exchange algorithm to use
@return: L{defer.Deferred}
"""
kexAlgorithms = []
try:
output = subprocess.check_output([which('ssh')[0], '-Q', 'kex'],
stderr=subprocess.STDOUT)
if not isinstance(output, str):
output = output.decode("utf-8")
kexAlgorithms = output.split()
except:
pass
if keyExchangeAlgo not in kexAlgorithms:
raise unittest.SkipTest(
"{} not supported by ssh client".format(
keyExchangeAlgo))
d = self.execute('echo hello', ConchTestOpenSSHProcess(),
'-oKexAlgorithms=' + keyExchangeAlgo)
return d.addCallback(self.assertEqual, b'hello\n')
def test_ECDHSHA256(self):
"""
The ecdh-sha2-nistp256 key exchange algorithm is compatible with
OpenSSH
"""
return self.assertExecuteWithKexAlgorithm(
'ecdh-sha2-nistp256')
def test_ECDHSHA384(self):
"""
The ecdh-sha2-nistp384 key exchange algorithm is compatible with
OpenSSH
"""
return self.assertExecuteWithKexAlgorithm(
'ecdh-sha2-nistp384')
def test_ECDHSHA521(self):
"""
The ecdh-sha2-nistp521 key exchange algorithm is compatible with
OpenSSH
"""
return self.assertExecuteWithKexAlgorithm(
'ecdh-sha2-nistp521')
def test_DH_GROUP14(self):
"""
The diffie-hellman-group14-sha1 key exchange algorithm is compatible
with OpenSSH.
"""
return self.assertExecuteWithKexAlgorithm(
'diffie-hellman-group14-sha1')
def test_DH_GROUP_EXCHANGE_SHA1(self):
"""
The diffie-hellman-group-exchange-sha1 key exchange algorithm is
compatible with OpenSSH.
"""
return self.assertExecuteWithKexAlgorithm(
'diffie-hellman-group-exchange-sha1')
def test_DH_GROUP_EXCHANGE_SHA256(self):
"""
The diffie-hellman-group-exchange-sha256 key exchange algorithm is
compatible with OpenSSH.
"""
return self.assertExecuteWithKexAlgorithm(
'diffie-hellman-group-exchange-sha256')
def test_unsupported_algorithm(self):
"""
The list of key exchange algorithms supported
by OpenSSH client is obtained with C{ssh -Q kex}.
"""
self.assertRaises(unittest.SkipTest,
self.assertExecuteWithKexAlgorithm,
'unsupported-algorithm')
class OpenSSHClientForwardingTests(ForwardingMixin, OpenSSHClientMixin,
unittest.TestCase):
"""
Connection forwarding tests run against the OpenSSL command line client.
"""
def test_localToRemoteForwardingV6(self):
"""
Forwarding of arbitrary IPv6 TCP connections via SSH.
"""
localPort = self._getFreePort()
process = ConchTestForwardingProcess(localPort, b'test\n')
d = self.execute('', process,
sshArgs='-N -L%i:[::1]:%i'
% (localPort, self.echoPortV6))
d.addCallback(self.assertEqual, b'test\n')
return d
if not HAS_IPV6:
test_localToRemoteForwardingV6.skip = "Requires IPv6 support"
class OpenSSHClientRekeyTests(RekeyTestsMixin, OpenSSHClientMixin,
unittest.TestCase):
"""
Rekeying tests run against the OpenSSL command line client.
"""
class CmdLineClientTests(ForwardingMixin, unittest.TestCase):
"""
Connection forwarding tests run against the Conch command line client.
"""
if runtime.platformType == 'win32':
skip = "can't run cmdline client on win32"
def execute(self, remoteCommand, process, sshArgs='', conchArgs=None):
"""
As for L{OpenSSHClientTestCase.execute}, except it runs the 'conch'
command line tool, not 'ssh'.
"""
if conchArgs is None:
conchArgs = []
process.deferred = defer.Deferred()
port = self.conchServer.getHost().port
cmd = ('-p {} -l testuser '
'--known-hosts kh_test '
'--user-authentications publickey '
'-a '
'-i dsa_test '
'-v '.format(port) + sshArgs +
' 127.0.0.1 ' + remoteCommand)
cmds = _makeArgs(conchArgs + cmd.split())
env = os.environ.copy()
env['PYTHONPATH'] = os.pathsep.join(sys.path)
encodedCmds = []
encodedEnv = {}
for cmd in cmds:
if isinstance(cmd, unicode):
cmd = cmd.encode("utf-8")
encodedCmds.append(cmd)
for var in env:
val = env[var]
if isinstance(var, unicode):
var = var.encode("utf-8")
if isinstance(val, unicode):
val = val.encode("utf-8")
encodedEnv[var] = val
reactor.spawnProcess(process, sys.executable, encodedCmds, env=encodedEnv)
return process.deferred
def test_runWithLogFile(self):
"""
It can store logs to a local file.
"""
def cb_check_log(result):
logContent = logPath.getContent()
self.assertIn(b'Log opened.', logContent)
logPath = filepath.FilePath(self.mktemp())
d = self.execute(
remoteCommand='echo goodbye',
process=ConchTestOpenSSHProcess(),
conchArgs=['--log', '--logfile', logPath.path,
'--host-key-algorithms', 'ssh-rsa']
)
d.addCallback(self.assertEqual, b'goodbye\n')
d.addCallback(cb_check_log)
return d
def test_runWithNoHostAlgorithmsSpecified(self):
"""
Do not use --host-key-algorithms flag on command line.
"""
d = self.execute(
remoteCommand='echo goodbye',
process=ConchTestOpenSSHProcess()
)
d.addCallback(self.assertEqual, b'goodbye\n')
return d

View file

@ -0,0 +1,761 @@
# Copyright (c) 2007-2010 Twisted Matrix Laboratories.
# See LICENSE for details
"""
This module tests twisted.conch.ssh.connection.
"""
from __future__ import division, absolute_import
import struct
from twisted.python.reflect import requireModule
cryptography = requireModule("cryptography")
from twisted.conch import error
if cryptography:
from twisted.conch.ssh import common, connection
else:
class connection:
class SSHConnection: pass
from twisted.conch.ssh import channel
from twisted.python.compat import long
from twisted.trial import unittest
from twisted.conch.test import test_userauth
class TestChannel(channel.SSHChannel):
"""
A mocked-up version of twisted.conch.ssh.channel.SSHChannel.
@ivar gotOpen: True if channelOpen has been called.
@type gotOpen: L{bool}
@ivar specificData: the specific channel open data passed to channelOpen.
@type specificData: L{bytes}
@ivar openFailureReason: the reason passed to openFailed.
@type openFailed: C{error.ConchError}
@ivar inBuffer: a C{list} of strings received by the channel.
@type inBuffer: C{list}
@ivar extBuffer: a C{list} of 2-tuples (type, extended data) of received by
the channel.
@type extBuffer: C{list}
@ivar numberRequests: the number of requests that have been made to this
channel.
@type numberRequests: L{int}
@ivar gotEOF: True if the other side sent EOF.
@type gotEOF: L{bool}
@ivar gotOneClose: True if the other side closed the connection.
@type gotOneClose: L{bool}
@ivar gotClosed: True if the channel is closed.
@type gotClosed: L{bool}
"""
name = b"TestChannel"
gotOpen = False
gotClosed = False
def logPrefix(self):
return "TestChannel %i" % self.id
def channelOpen(self, specificData):
"""
The channel is open. Set up the instance variables.
"""
self.gotOpen = True
self.specificData = specificData
self.inBuffer = []
self.extBuffer = []
self.numberRequests = 0
self.gotEOF = False
self.gotOneClose = False
self.gotClosed = False
def openFailed(self, reason):
"""
Opening the channel failed. Store the reason why.
"""
self.openFailureReason = reason
def request_test(self, data):
"""
A test request. Return True if data is 'data'.
@type data: L{bytes}
"""
self.numberRequests += 1
return data == b'data'
def dataReceived(self, data):
"""
Data was received. Store it in the buffer.
"""
self.inBuffer.append(data)
def extReceived(self, code, data):
"""
Extended data was received. Store it in the buffer.
"""
self.extBuffer.append((code, data))
def eofReceived(self):
"""
EOF was received. Remember it.
"""
self.gotEOF = True
def closeReceived(self):
"""
Close was received. Remember it.
"""
self.gotOneClose = True
def closed(self):
"""
The channel is closed. Rembember it.
"""
self.gotClosed = True
class TestAvatar:
"""
A mocked-up version of twisted.conch.avatar.ConchUser
"""
_ARGS_ERROR_CODE = 123
def lookupChannel(self, channelType, windowSize, maxPacket, data):
"""
The server wants us to return a channel. If the requested channel is
our TestChannel, return it, otherwise return None.
"""
if channelType == TestChannel.name:
return TestChannel(remoteWindow=windowSize,
remoteMaxPacket=maxPacket,
data=data, avatar=self)
elif channelType == b"conch-error-args":
# Raise a ConchError with backwards arguments to make sure the
# connection fixes it for us. This case should be deprecated and
# deleted eventually, but only after all of Conch gets the argument
# order right.
raise error.ConchError(
self._ARGS_ERROR_CODE, "error args in wrong order")
def gotGlobalRequest(self, requestType, data):
"""
The client has made a global request. If the global request is
'TestGlobal', return True. If the global request is 'TestData',
return True and the request-specific data we received. Otherwise,
return False.
"""
if requestType == b'TestGlobal':
return True
elif requestType == b'TestData':
return True, data
else:
return False
class TestConnection(connection.SSHConnection):
"""
A subclass of SSHConnection for testing.
@ivar channel: the current channel.
@type channel. C{TestChannel}
"""
if not cryptography:
skip = "Cannot run without cryptography"
def logPrefix(self):
return "TestConnection"
def global_TestGlobal(self, data):
"""
The other side made the 'TestGlobal' global request. Return True.
"""
return True
def global_Test_Data(self, data):
"""
The other side made the 'Test-Data' global request. Return True and
the data we received.
"""
return True, data
def channel_TestChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the TestChannel. Create a C{TestChannel}
instance, store it, and return it.
"""
self.channel = TestChannel(remoteWindow=windowSize,
remoteMaxPacket=maxPacket, data=data)
return self.channel
def channel_ErrorChannel(self, windowSize, maxPacket, data):
"""
The other side is requesting the ErrorChannel. Raise an exception.
"""
raise AssertionError('no such thing')
class ConnectionTests(unittest.TestCase):
if not cryptography:
skip = "Cannot run without cryptography"
if test_userauth.transport is None:
skip = "Cannot run without both cryptography and pyasn1"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
self.conn.serviceStarted()
def _openChannel(self, channel):
"""
Open the channel with the default connection.
"""
self.conn.openChannel(channel)
self.transport.packets = self.transport.packets[:-1]
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(struct.pack('>2L',
channel.id, 255) + b'\x00\x02\x00\x00\x00\x00\x80\x00')
def tearDown(self):
self.conn.serviceStopped()
def test_linkAvatar(self):
"""
Test that the connection links itself to the avatar in the
transport.
"""
self.assertIs(self.transport.avatar.conn, self.conn)
def test_serviceStopped(self):
"""
Test that serviceStopped() closes any open channels.
"""
channel1 = TestChannel()
channel2 = TestChannel()
self.conn.openChannel(channel1)
self.conn.openChannel(channel2)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b'\x00\x00\x00\x00' * 4)
self.assertTrue(channel1.gotOpen)
self.assertFalse(channel1.gotClosed)
self.assertFalse(channel2.gotOpen)
self.assertFalse(channel2.gotClosed)
self.conn.serviceStopped()
self.assertTrue(channel1.gotClosed)
self.assertFalse(channel2.gotOpen)
self.assertFalse(channel2.gotClosed)
from twisted.internet.error import ConnectionLost
self.assertIsInstance(channel2.openFailureReason,
ConnectionLost)
def test_GLOBAL_REQUEST(self):
"""
Test that global request packets are dispatched to the global_*
methods and the return values are translated into success or failure
messages.
"""
self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestGlobal') + b'\xff')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_SUCCESS, b'')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestData') + b'\xff' +
b'test data')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_SUCCESS, b'test data')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestBad') + b'\xff')
self.assertEqual(self.transport.packets,
[(connection.MSG_REQUEST_FAILURE, b'')])
self.transport.packets = []
self.conn.ssh_GLOBAL_REQUEST(common.NS(b'TestGlobal') + b'\x00')
self.assertEqual(self.transport.packets, [])
def test_REQUEST_SUCCESS(self):
"""
Test that global request success packets cause the Deferred to be
called back.
"""
d = self.conn.sendGlobalRequest(b'request', b'data', True)
self.conn.ssh_REQUEST_SUCCESS(b'data')
def check(data):
self.assertEqual(data, b'data')
d.addCallback(check)
d.addErrback(self.fail)
return d
def test_REQUEST_FAILURE(self):
"""
Test that global request failure packets cause the Deferred to be
erred back.
"""
d = self.conn.sendGlobalRequest(b'request', b'data', True)
self.conn.ssh_REQUEST_FAILURE(b'data')
def check(f):
self.assertEqual(f.value.data, b'data')
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_OPEN(self):
"""
Test that open channel packets cause a channel to be created and
opened or a failure message to be returned.
"""
del self.transport.avatar
self.conn.ssh_CHANNEL_OPEN(common.NS(b'TestChannel') +
b'\x00\x00\x00\x01' * 4)
self.assertTrue(self.conn.channel.gotOpen)
self.assertEqual(self.conn.channel.conn, self.conn)
self.assertEqual(self.conn.channel.data, b'\x00\x00\x00\x01')
self.assertEqual(self.conn.channel.specificData, b'\x00\x00\x00\x01')
self.assertEqual(self.conn.channel.remoteWindowLeft, 1)
self.assertEqual(self.conn.channel.remoteMaxPacket, 1)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_CONFIRMATION,
b'\x00\x00\x00\x01\x00\x00\x00\x00\x00\x02\x00\x00'
b'\x00\x00\x80\x00')])
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS(b'BadChannel') +
b'\x00\x00\x00\x02' * 4)
self.flushLoggedErrors()
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
b'\x00\x00\x00\x02\x00\x00\x00\x03' + common.NS(
b'unknown channel') + common.NS(b''))])
self.transport.packets = []
self.conn.ssh_CHANNEL_OPEN(common.NS(b'ErrorChannel') +
b'\x00\x00\x00\x02' * 4)
self.flushLoggedErrors()
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
b'\x00\x00\x00\x02\x00\x00\x00\x02' + common.NS(
b'unknown failure') + common.NS(b''))])
def _lookupChannelErrorTest(self, code):
"""
Deliver a request for a channel open which will result in an exception
being raised during channel lookup. Assert that an error response is
delivered as a result.
"""
self.transport.avatar._ARGS_ERROR_CODE = code
self.conn.ssh_CHANNEL_OPEN(
common.NS(b'conch-error-args') + b'\x00\x00\x00\x01' * 4)
errors = self.flushLoggedErrors(error.ConchError)
self.assertEqual(
len(errors), 1, "Expected one error, got: %r" % (errors,))
self.assertEqual(errors[0].value.args, (long(123), "error args in wrong order"))
self.assertEqual(
self.transport.packets,
[(connection.MSG_CHANNEL_OPEN_FAILURE,
# The response includes some bytes which identifying the
# associated request, as well as the error code (7b in hex) and
# the error message.
b'\x00\x00\x00\x01\x00\x00\x00\x7b' + common.NS(
b'error args in wrong order') + common.NS(b''))])
def test_lookupChannelError(self):
"""
If a C{lookupChannel} implementation raises L{error.ConchError} with the
arguments in the wrong order, a C{MSG_CHANNEL_OPEN} failure is still
sent in response to the message.
This is a temporary work-around until L{error.ConchError} is given
better attributes and all of the Conch code starts constructing
instances of it properly. Eventually this functionality should be
deprecated and then removed.
"""
self._lookupChannelErrorTest(123)
def test_lookupChannelErrorLongCode(self):
"""
Like L{test_lookupChannelError}, but for the case where the failure code
is represented as a L{long} instead of a L{int}.
"""
self._lookupChannelErrorTest(long(123))
def test_CHANNEL_OPEN_CONFIRMATION(self):
"""
Test that channel open confirmation packets cause the channel to be
notified that it's open.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_CONFIRMATION(b'\x00\x00\x00\x00'*5)
self.assertEqual(channel.remoteWindowLeft, 0)
self.assertEqual(channel.remoteMaxPacket, 0)
self.assertEqual(channel.specificData, b'\x00\x00\x00\x00')
self.assertEqual(self.conn.channelsToRemoteChannel[channel],
0)
self.assertEqual(self.conn.localToRemoteChannel[0], 0)
def test_CHANNEL_OPEN_FAILURE(self):
"""
Test that channel open failure packets cause the channel to be
notified that its opening failed.
"""
channel = TestChannel()
self.conn.openChannel(channel)
self.conn.ssh_CHANNEL_OPEN_FAILURE(b'\x00\x00\x00\x00\x00\x00\x00'
b'\x01' + common.NS(b'failure!'))
self.assertEqual(channel.openFailureReason.args, (b'failure!', 1))
self.assertIsNone(self.conn.channels.get(channel))
def test_CHANNEL_WINDOW_ADJUST(self):
"""
Test that channel window adjust messages add bytes to the channel
window.
"""
channel = TestChannel()
self._openChannel(channel)
oldWindowSize = channel.remoteWindowLeft
self.conn.ssh_CHANNEL_WINDOW_ADJUST(b'\x00\x00\x00\x00\x00\x00\x00'
b'\x01')
self.assertEqual(channel.remoteWindowLeft, oldWindowSize + 1)
def test_CHANNEL_DATA(self):
"""
Test that channel data messages are passed up to the channel, or
cause the channel to be closed if the data is too large.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_DATA(b'\x00\x00\x00\x00' + common.NS(b'data'))
self.assertEqual(channel.inBuffer, [b'data'])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, b'\x00\x00\x00\xff'
b'\x00\x00\x00\x04')])
self.transport.packets = []
longData = b'a' * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_DATA(b'\x00\x00\x00\x00' + common.NS(longData))
self.assertEqual(channel.inBuffer, [b'data'])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
channel = TestChannel()
self._openChannel(channel)
bigData = b'a' * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_DATA(b'\x00\x00\x00\x01' + common.NS(bigData))
self.assertEqual(channel.inBuffer, [])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
def test_CHANNEL_EXTENDED_DATA(self):
"""
Test that channel extended data messages are passed up to the channel,
or cause the channel to be closed if they're too big.
"""
channel = TestChannel(localWindow=6, localMaxPacket=5)
self._openChannel(channel)
self.conn.ssh_CHANNEL_EXTENDED_DATA(b'\x00\x00\x00\x00\x00\x00\x00'
b'\x00' + common.NS(b'data'))
self.assertEqual(channel.extBuffer, [(0, b'data')])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, b'\x00\x00\x00\xff'
b'\x00\x00\x00\x04')])
self.transport.packets = []
longData = b'a' * (channel.localWindowLeft + 1)
self.conn.ssh_CHANNEL_EXTENDED_DATA(b'\x00\x00\x00\x00\x00\x00\x00'
b'\x00' + common.NS(longData))
self.assertEqual(channel.extBuffer, [(0, b'data')])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
channel = TestChannel()
self._openChannel(channel)
bigData = b'a' * (channel.localMaxPacket + 1)
self.transport.packets = []
self.conn.ssh_CHANNEL_EXTENDED_DATA(b'\x00\x00\x00\x01\x00\x00\x00'
b'\x00' + common.NS(bigData))
self.assertEqual(channel.extBuffer, [])
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
def test_CHANNEL_EOF(self):
"""
Test that channel eof messages are passed up to the channel.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_EOF(b'\x00\x00\x00\x00')
self.assertTrue(channel.gotEOF)
def test_CHANNEL_CLOSE(self):
"""
Test that channel close messages are passed up to the channel. Also,
test that channel.close() is called if both sides are closed when this
message is received.
"""
channel = TestChannel()
self._openChannel(channel)
self.assertTrue(channel.gotOpen)
self.assertFalse(channel.gotOneClose)
self.assertFalse(channel.gotClosed)
self.conn.sendClose(channel)
self.conn.ssh_CHANNEL_CLOSE(b'\x00\x00\x00\x00')
self.assertTrue(channel.gotOneClose)
self.assertTrue(channel.gotClosed)
def test_CHANNEL_REQUEST_success(self):
"""
Test that channel requests that succeed send MSG_CHANNEL_SUCCESS.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.ssh_CHANNEL_REQUEST(b'\x00\x00\x00\x00' + common.NS(b'test')
+ b'\x00')
self.assertEqual(channel.numberRequests, 1)
d = self.conn.ssh_CHANNEL_REQUEST(b'\x00\x00\x00\x00' + common.NS(
b'test') + b'\xff' + b'data')
def check(result):
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_SUCCESS, b'\x00\x00\x00\xff')])
d.addCallback(check)
return d
def test_CHANNEL_REQUEST_failure(self):
"""
Test that channel requests that fail send MSG_CHANNEL_FAILURE.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.ssh_CHANNEL_REQUEST(b'\x00\x00\x00\x00' + common.NS(
b'test') + b'\xff')
def check(result):
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_FAILURE, b'\x00\x00\x00\xff'
)])
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_CHANNEL_REQUEST_SUCCESS(self):
"""
Test that channel request success messages cause the Deferred to be
called back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, b'test', b'data', True)
self.conn.ssh_CHANNEL_SUCCESS(b'\x00\x00\x00\x00')
def check(result):
self.assertTrue(result)
return d
def test_CHANNEL_REQUEST_FAILURE(self):
"""
Test that channel request failure messages cause the Deferred to be
erred back.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, b'test', b'', True)
self.conn.ssh_CHANNEL_FAILURE(b'\x00\x00\x00\x00')
def check(result):
self.assertEqual(result.value.value, 'channel request failed')
d.addCallback(self.fail)
d.addErrback(check)
return d
def test_sendGlobalRequest(self):
"""
Test that global request messages are sent in the right format.
"""
d = self.conn.sendGlobalRequest(b'wantReply', b'data', True)
# must be added to prevent errbacking during teardown
d.addErrback(lambda failure: None)
self.conn.sendGlobalRequest(b'noReply', b'', False)
self.assertEqual(self.transport.packets,
[(connection.MSG_GLOBAL_REQUEST, common.NS(b'wantReply') +
b'\xffdata'),
(connection.MSG_GLOBAL_REQUEST, common.NS(b'noReply') +
b'\x00')])
self.assertEqual(self.conn.deferreds, {'global':[d]})
def test_openChannel(self):
"""
Test that open channel messages are sent in the right format.
"""
channel = TestChannel()
self.conn.openChannel(channel, b'aaaa')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_OPEN, common.NS(b'TestChannel') +
b'\x00\x00\x00\x00\x00\x02\x00\x00\x00\x00\x80\x00aaaa')])
self.assertEqual(channel.id, 0)
self.assertEqual(self.conn.localChannelID, 1)
def test_sendRequest(self):
"""
Test that channel request messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(channel, b'test', b'test', True)
# needed to prevent errbacks during teardown.
d.addErrback(lambda failure: None)
self.conn.sendRequest(channel, b'test2', b'', False)
channel.localClosed = True # emulate sending a close message
self.conn.sendRequest(channel, b'test3', b'', True)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_REQUEST, b'\x00\x00\x00\xff' +
common.NS(b'test') + b'\x01test'),
(connection.MSG_CHANNEL_REQUEST, b'\x00\x00\x00\xff' +
common.NS(b'test2') + b'\x00')])
self.assertEqual(self.conn.deferreds[0], [d])
def test_adjustWindow(self):
"""
Test that channel window adjust messages cause bytes to be added
to the window.
"""
channel = TestChannel(localWindow=5)
self._openChannel(channel)
channel.localWindowLeft = 0
self.conn.adjustWindow(channel, 1)
self.assertEqual(channel.localWindowLeft, 1)
channel.localClosed = True
self.conn.adjustWindow(channel, 2)
self.assertEqual(channel.localWindowLeft, 1)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_WINDOW_ADJUST, b'\x00\x00\x00\xff'
b'\x00\x00\x00\x01')])
def test_sendData(self):
"""
Test that channel data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendData(channel, b'a')
channel.localClosed = True
self.conn.sendData(channel, b'b')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_DATA, b'\x00\x00\x00\xff' +
common.NS(b'a'))])
def test_sendExtendedData(self):
"""
Test that channel extended data messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendExtendedData(channel, 1, b'test')
channel.localClosed = True
self.conn.sendExtendedData(channel, 2, b'test2')
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EXTENDED_DATA, b'\x00\x00\x00\xff' +
b'\x00\x00\x00\x01' + common.NS(b'test'))])
def test_sendEOF(self):
"""
Test that channel EOF messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendEOF(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EOF, b'\x00\x00\x00\xff')])
channel.localClosed = True
self.conn.sendEOF(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_EOF, b'\x00\x00\x00\xff')])
def test_sendClose(self):
"""
Test that channel close messages are sent in the right format.
"""
channel = TestChannel()
self._openChannel(channel)
self.conn.sendClose(channel)
self.assertTrue(channel.localClosed)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
self.conn.sendClose(channel)
self.assertEqual(self.transport.packets,
[(connection.MSG_CHANNEL_CLOSE, b'\x00\x00\x00\xff')])
channel2 = TestChannel()
self._openChannel(channel2)
self.assertTrue(channel2.gotOpen)
self.assertFalse(channel2.gotClosed)
channel2.remoteClosed = True
self.conn.sendClose(channel2)
self.assertTrue(channel2.gotClosed)
def test_getChannelWithAvatar(self):
"""
Test that getChannel dispatches to the avatar when an avatar is
present. Correct functioning without the avatar is verified in
test_CHANNEL_OPEN.
"""
channel = self.conn.getChannel(b'TestChannel', 50, 30, b'data')
self.assertEqual(channel.data, b'data')
self.assertEqual(channel.remoteWindowLeft, 50)
self.assertEqual(channel.remoteMaxPacket, 30)
self.assertRaises(error.ConchError, self.conn.getChannel,
b'BadChannel', 50, 30, b'data')
def test_gotGlobalRequestWithoutAvatar(self):
"""
Test that gotGlobalRequests dispatches to global_* without an avatar.
"""
del self.transport.avatar
self.assertTrue(self.conn.gotGlobalRequest(b'TestGlobal', b'data'))
self.assertEqual(self.conn.gotGlobalRequest(b'Test-Data', b'data'),
(True, b'data'))
self.assertFalse(self.conn.gotGlobalRequest(b'BadGlobal', b'data'))
def test_channelClosedCausesLeftoverChannelDeferredsToErrback(self):
"""
Whenever an SSH channel gets closed any Deferred that was returned by a
sendRequest() on its parent connection must be errbacked.
"""
channel = TestChannel()
self._openChannel(channel)
d = self.conn.sendRequest(
channel, b"dummyrequest", b"dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.channelClosed(channel)
return d
class CleanConnectionShutdownTests(unittest.TestCase):
"""
Check whether correct cleanup is performed on connection shutdown.
"""
if not cryptography:
skip = "Cannot run without cryptography"
if test_userauth.transport is None:
skip = "Cannot run without both cryptography and pyasn1"
def setUp(self):
self.transport = test_userauth.FakeTransport(None)
self.transport.avatar = TestAvatar()
self.conn = TestConnection()
self.conn.transport = self.transport
def test_serviceStoppedCausesLeftoverGlobalDeferredsToErrback(self):
"""
Once the service is stopped any leftover global deferred returned by
a sendGlobalRequest() call must be errbacked.
"""
self.conn.serviceStarted()
d = self.conn.sendGlobalRequest(
b"dummyrequest", b"dummydata", wantReply=1)
d = self.assertFailure(d, error.ConchError)
self.conn.serviceStopped()
return d

Some files were not shown because too many files have changed in this diff Show more