~launchpad-pqm/launchpad/devel

« back to all changes in this revision

Viewing changes to lib/lp/codehosting/sshserver/tests/test_daemon.py

  • Committer: Launchpad Patch Queue Manager
  • Date: 2010-04-16 22:37:59 UTC
  • mfrom: (10693.3.32 extract-ssh-server-auth)
  • Revision ID: launchpad@pqm.canonical.com-20100416223759-nimb88k917yijz02
[r=intellectronica][ui=none] Move codehosting-specific stuff into new
        codehosting-specific module

Show diffs side-by-side

added added

removed removed

Lines of Context:
 
1
# Copyright 2010 Canonical Ltd.  This software is licensed under the
 
2
# GNU Affero General Public License version 3 (see the file LICENSE).
 
3
 
 
4
"""Tests for the codehosting SSH server glue."""
 
5
 
 
6
# Python 2 idiom: every class in this module that names no base class
# becomes a new-style class.
__metaclass__ = type
 
7
 
 
8
import unittest
 
9
 
 
10
from twisted.conch.ssh.common import NS
 
11
from twisted.conch.ssh.keys import Key
 
12
from twisted.test.proto_helpers import StringTransport
 
13
from twisted.trial.unittest import TestCase as TrialTestCase
 
14
 
 
15
from canonical.testing.layers import TwistedLayer
 
16
 
 
17
from lp.codehosting.sshserver.auth import SSHUserAuthServer
 
18
from lp.codehosting.sshserver.daemon import (
 
19
    get_key_path, get_portal, PRIVATE_KEY_FILE, PUBLIC_KEY_FILE)
 
20
from lp.codehosting.sshserver.service import Factory
 
21
 
 
22
 
 
23
class StringTransportWith_setTcpKeepAlive(StringTransport):
    """A `StringTransport` that also accepts `setTcpKeepAlive` calls.

    The most recent flag handed to `setTcpKeepAlive` is kept on the
    `_keepAlive` attribute (False until the first call). Tests can then
    check whether the code under test asked for TCP keep-alives.
    """

    def __init__(self, hostAddress=None, peerAddress=None):
        # Call the base initializer explicitly rather than via super(): the
        # base class may be an old-style class under Python 2.
        StringTransport.__init__(self, hostAddress, peerAddress)
        # Keep-alives start out switched off.
        self._keepAlive = False

    def setTcpKeepAlive(self, flag):
        """Remember `flag` as the current keep-alive setting."""
        self._keepAlive = flag
 
30
 
 
31
 
 
32
class TestFactory(TrialTestCase):
    """Exercise the `Factory` used by the codehosting SSH server."""

    layer = TwistedLayer

    def makeFactory(self):
        """Build and start the SSH server factory.

        The portal comes from `get_portal(None, None)`. The host key pair is
        loaded from the paths that `get_key_path` returns for
        `PRIVATE_KEY_FILE` and `PUBLIC_KEY_FILE`.

        :return: The `Factory`, after `startFactory` has been called on it.
        """
        portal = get_portal(None, None)
        private_key = Key.fromFile(get_key_path(PRIVATE_KEY_FILE))
        public_key = Key.fromFile(get_key_path(PUBLIC_KEY_FILE))
        factory = Factory(
            portal, private_key=private_key, public_key=public_key)
        factory.startFactory()
        return factory

    def startConnecting(self, factory):
        """Open a fake connection to `factory`.

        :return: The protocol built by `factory`, already connected to a
            fresh `StringTransportWith_setTcpKeepAlive`.
        """
        protocol = factory.buildProtocol(None)
        fake_transport = StringTransportWith_setTcpKeepAlive()
        protocol.makeConnection(fake_transport)
        return protocol

    def test_set_keepalive_on_connection(self):
        # Making the connection must switch on TCP keep-alives on the
        # transport that sits underneath the SSH server transport.
        server_transport = self.startConnecting(self.makeFactory())
        underlying = server_transport.transport
        self.assertTrue(underlying._keepAlive)

    def beginAuthentication(self, factory):
        """Connect to `factory` and request the 'ssh-userauth' service.

        A cleanup is registered that calls `serviceStopped` on the service
        that the request installs.

        :return: The `SSHServerTransport` once authentication has begun.
        """
        server_transport = self.startConnecting(factory)
        server_transport.ssh_SERVICE_REQUEST(NS('ssh-userauth'))
        auth_service = server_transport.service
        self.addCleanup(auth_service.serviceStopped)
        return server_transport

    def test_authentication_uses_our_userauth_service(self):
        # After authentication has started, the transport's service is an
        # instance of our own SSHUserAuthServer class.
        server_transport = self.beginAuthentication(self.makeFactory())
        self.assertIsInstance(server_transport.service, SSHUserAuthServer)

    def test_two_connections_two_minds(self):
        # Two authentication attempts against the same factory must not
        # share the user-details cache held on their minds.
        factory = self.makeFactory()
        transports = [self.beginAuthentication(factory) for i in range(2)]
        first_mind, second_mind = [
            transport.service.getMind() for transport in transports]
        self.assertNotIdentical(first_mind.cache, second_mind.cache)
 
90
 
 
91
 
 
92
def test_suite():
    """Return a suite holding every test case defined in this module."""
    loader = unittest.TestLoader()
    return loader.loadTestsFromName(__name__)