1
# Copyright 2010 Canonical Ltd. This software is licensed under the
2
# GNU Affero General Public License version 3 (see the file LICENSE).
4
"""Tests for the codehosting SSH server glue."""
10
import unittest

from twisted.conch.ssh.common import NS
from twisted.conch.ssh.keys import Key
from twisted.test.proto_helpers import StringTransport
from twisted.trial.unittest import TestCase as TrialTestCase

from canonical.testing.layers import TwistedLayer

from lp.codehosting.sshserver.auth import SSHUserAuthServer
from lp.codehosting.sshserver.daemon import (
    get_key_path, get_portal, PRIVATE_KEY_FILE, PUBLIC_KEY_FILE)
from lp.codehosting.sshserver.service import Factory
23
class StringTransportWith_setTcpKeepAlive(StringTransport):
    """A `StringTransport` that also accepts `setTcpKeepAlive` calls.

    Whatever flag is passed to `setTcpKeepAlive` is stored on the
    instance as `_keepAlive`, so a test can inspect it afterwards.
    """

    def __init__(self, hostAddress=None, peerAddress=None):
        # Delegate address handling to the base class unchanged.
        StringTransport.__init__(self, hostAddress, peerAddress)
        # Nobody has asked for keep-alives yet.
        self._keepAlive = False

    def setTcpKeepAlive(self, flag):
        """Record `flag` as the requested keep-alive setting."""
        self._keepAlive = flag
32
class TestFactory(TrialTestCase):
    """Tests for our SSH factory."""

    # NOTE(review): `TwistedLayer` is imported by this module but was not
    # referenced anywhere else; attaching it as this test case's layer
    # appears to be the intent -- confirm.
    layer = TwistedLayer

    def makeFactory(self):
        """Create and start the factory that our SSH server uses.

        :return: The started `Factory`, built from `get_portal(None, None)`
            and the key pair loaded from `PRIVATE_KEY_FILE` and
            `PUBLIC_KEY_FILE`.
        """
        factory = Factory(
            get_portal(None, None),
            private_key=Key.fromFile(
                get_key_path(PRIVATE_KEY_FILE)),
            public_key=Key.fromFile(
                get_key_path(PUBLIC_KEY_FILE)))
        factory.startFactory()
        # Callers use the factory to build protocols, so hand it back.
        return factory

    def startConnecting(self, factory):
        """Connect to the `factory`.

        Builds a server protocol from `factory` and connects it to an
        in-memory `StringTransportWith_setTcpKeepAlive`.

        :return: The protocol returned by `factory.buildProtocol`.
        """
        server_transport = factory.buildProtocol(None)
        server_transport.makeConnection(StringTransportWith_setTcpKeepAlive())
        return server_transport

    def test_set_keepalive_on_connection(self):
        # The server transport sets TCP keep alives on the underlying
        # transport once the connection has been made.
        factory = self.makeFactory()
        server_transport = self.startConnecting(factory)
        self.assertTrue(server_transport.transport._keepAlive)

    def beginAuthentication(self, factory):
        """Connect to `factory` and begin authentication on this connection.

        Requests the 'ssh-userauth' service and registers a cleanup that
        calls that service's `serviceStopped` when the test finishes.

        :return: The `SSHServerTransport` after the process of authentication
            has begun.
        """
        server_transport = self.startConnecting(factory)
        server_transport.ssh_SERVICE_REQUEST(NS('ssh-userauth'))
        self.addCleanup(server_transport.service.serviceStopped)
        return server_transport

    def test_authentication_uses_our_userauth_service(self):
        # The service of a SSHServerTransport after authentication has started
        # is an instance of our SSHUserAuthServer class.
        factory = self.makeFactory()
        transport = self.beginAuthentication(factory)
        self.assertIsInstance(transport.service, SSHUserAuthServer)

    def test_two_connections_two_minds(self):
        # Two attempts to authenticate do not share the user-details cache.
        factory = self.makeFactory()

        server_transport1 = self.beginAuthentication(factory)
        server_transport2 = self.beginAuthentication(factory)

        mind1 = server_transport1.service.getMind()
        mind2 = server_transport2.service.getMind()

        self.assertNotIdentical(mind1.cache, mind2.cache)
def test_suite():
    """Return a `unittest` suite with every test defined in this module."""
    return unittest.TestLoader().loadTestsFromName(__name__)