1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
|
# Copyright 2011 Canonical Ltd. This software is licensed under the
# GNU Affero General Public License version 3 (see the file LICENSE).
"""Utilities for graceful shutdown of Twisted services.

Provides a factory wrapper that tracks its open connections, a service
that waits for those connections to close before it finishes stopping, an
HTTP resource reporting whether the service is still available, a
MultiService that stops its children one at a time in reverse order, and a
helper that builds a web service around the status resource.
"""
__metaclass__ = type
__all__ = [
    'ConnTrackingFactoryWrapper',
    'ShutdownCleanlyService',
    'ServerAvailableResource',
    'OrderedMultiService',
    # Public helper used by daemon .tac files; export it with the rest.
    'make_web_status_service',
    ]
from twisted.application import service, strports
from twisted.protocols.policies import WrappingFactory
from twisted.internet.defer import (
Deferred,
gatherResults,
inlineCallbacks,
maybeDeferred,
)
from twisted.web import resource, server
from zope.interface import implements
class ConnTrackingFactoryWrapper(WrappingFactory):
    """A factory decorator that tracks the current connections made by this
    factory.

    Until stopFactory() is called, `allConnectionsGone` is None. After that
    it is a Deferred that fires (with None) once no tracked connections
    remain in `self.protocols`.
    """

    def __init__(self, wrappedFactory):
        """Constructor.

        See WrappingFactory.__init__.
        """
        WrappingFactory.__init__(self, wrappedFactory)
        # None while the factory is running; replaced by a Deferred in
        # stopFactory().
        self.allConnectionsGone = None

    def isAvailable(self):
        """Has this factory been stopped yet?

        :return: True until stopFactory() has been called, False after.
        """
        return self.allConnectionsGone is None

    def stopFactory(self):
        """See WrappingFactory.stopFactory."""
        WrappingFactory.stopFactory(self)
        self.allConnectionsGone = Deferred()
        # If nothing is connected right now, fire immediately.
        self._maybeFireAllConnectionsGone()

    def unregisterProtocol(self, p):
        """See WrappingFactory.unregisterProtocol."""
        WrappingFactory.unregisterProtocol(self, p)
        self._maybeFireAllConnectionsGone()

    def _maybeFireAllConnectionsGone(self):
        """Fire allConnectionsGone if stopped, idle and not yet fired."""
        d = self.allConnectionsGone
        # A Deferred may only be fired once: a connection registered and
        # dropped after the factory already went idle must not trigger a
        # second callback (AlreadyCalledError).
        if d is not None and not d.called and not self.protocols:
            d.callback(None)
class ShutdownCleanlyService(service.MultiService):
    """A MultiService whose shutdown completes only after every tracked
    factory has lost all of its connections.

    For instance, this keeps a twistd process from exiting until all
    clients of a server have disconnected.
    """

    def __init__(self, factories):
        """Constructor.

        :param factories: A collection of ConnTrackingFactoryWrapper
            instances.
        """
        service.MultiService.__init__(self)
        self.factories = factories

    def stopService(self):
        """See service.MultiService.stopService.

        Child services are stopped first; the returned Deferred then
        waits on every factory's allConnectionsGone.
        """
        return maybeDeferred(
            service.MultiService.stopService, self).addCallback(
                self._cbServicesStopped)

    def _cbServicesStopped(self, ignored):
        # NOTE(review): assumes each factory's stopFactory() has run by the
        # time children are stopped; otherwise allConnectionsGone is still
        # None here.
        pending = [factory.allConnectionsGone for factory in self.factories]
        return gatherResults(pending)
class ServerAvailableResource(resource.Resource):
    """A Resource indicating if a service is available for new connections.

    A 200 response code (OK) indicates the service is available, and a 503
    (Service Not Available) indicates the service is shutting down and no new
    connections will be accepted.

    This resource accepts both HEAD and GET requests. If the request is a GET
    this resource also reports the number of connections and their peer
    addresses in a human-friendly text/plain body.
    """

    def __init__(self, tracked_factories):
        """Constructor.

        :param tracked_factories: A collection of ConnTrackingFactoryWrapper
            instances whose availability and connections are reported.
        """
        resource.Resource.__init__(self)
        self.tracked_factories = tracked_factories

    def _render_common(self, request):
        """Set the status code and content type shared by GET and HEAD.

        :return: True if every tracked factory is still available (status
            200), False otherwise (status 503).
        """
        service_available = all(
            tracked.isAvailable() for tracked in self.tracked_factories)
        if service_available:
            request.setResponseCode(200)
        else:
            request.setResponseCode(503)
        request.setHeader('Content-Type', 'text/plain')
        # render_GET uses this to choose its status text; previously nothing
        # was returned, so GET always claimed 'Unavailable'.
        return service_available

    def render_GET(self, request):
        """Handler for GET requests. See resource.Resource.render."""
        service_available = self._render_common(request)
        # Generate a bit of text for humans' benefit.
        tracked_connections = set()
        for tracked in self.tracked_factories:
            tracked_connections.update(tracked.protocols)
        if service_available:
            state_text = 'Available'
        else:
            state_text = 'Unavailable'
        return '%s\n\n%d connections: \n\n%s\n' % (
            state_text, len(tracked_connections),
            '\n'.join(
                [str(c.transport.getPeer()) for c in tracked_connections]))

    def render_HEAD(self, request):
        """Handler for HEAD requests. See resource.Resource.render."""
        self._render_common(request)
        return ''
class OrderedMultiService(service.MultiService):
    """A MultiService that guarantees start and stop order.

    Services are started in the order they are attached, and stopped in
    reverse order (waiting for each to stop before stopping the next).
    """

    implements(service.IServiceCollection)

    @inlineCallbacks
    def stopService(self):
        """See service.MultiService.stopService.

        Children are stopped last-attached first; each child's
        stopService() result is waited on before the next one is stopped.
        """
        # intentionally skip MultiService.stopService
        service.Service.stopService(self)
        # Walk a reversed snapshot instead of popping from self.services:
        # popping emptied the collection while namedServices and each
        # child's parent still referred to the children, so a later
        # startService() would start nothing.
        for svc in reversed(list(self.services)):
            yield maybeDeferred(svc.stopService)
def make_web_status_service(strport, tracking_factories):
    """Build a web service that serves a ServerAvailableResource.

    The status resource is registered as the '' child of an otherwise
    empty root resource. See daemons/sftp.tac for an example use.

    :param strport: a strport describing the port the web service should
        listen on.
    :param tracking_factories: a collection of ConnTrackingFactoryWrapper
        instances.
    :returns: a service.Service
    """
    root = resource.Resource()
    root.putChild('', ServerAvailableResource(tracking_factories))
    return strports.service(strport, server.Site(root))
|