dnsrewriteproxy.py — forked from uktrade/dns-rewrite-proxy
(220 lines, 180 loc, 7.51 KB)
from asyncio import (
CancelledError,
Queue,
create_task,
get_running_loop,
)
from enum import (
IntEnum,
)
import logging
import re
from random import (
choices,
)
import string
import socket
from aiodnsresolver import (
RESPONSE,
TYPES,
DnsRecordDoesNotExist,
DnsResponseCode,
Message,
Resolver,
ResourceRecord,
ResolverLoggerAdapter,
pack,
parse,
recvfrom,
)
def get_socket_default():
    """Create the default listening socket: a non-blocking IPv4 UDP socket
    bound to port 53 on all interfaces.
    """
    server_socket = socket.socket(family=socket.AF_INET, type=socket.SOCK_DGRAM)
    server_socket.setblocking(False)
    server_socket.bind(('', 53))
    return server_socket
def get_resolver_default():
    """Build the default upstream resolver via aiodnsresolver's Resolver().

    The return value is unpacked by the caller as a (resolve, clear_cache)
    pair.
    """
    resolver = Resolver()
    return resolver
class DnsProxyLoggerAdapter(logging.LoggerAdapter):
    """Prefixes every log message with a '[dnsproxy]' tag.

    When the adapter carries extra context (e.g. a request id), the context
    values are joined with commas into the tag: '[dnsproxy:abc,1] message'.
    """

    def process(self, msg, kwargs):
        if self.extra:
            context = ','.join(str(value) for value in self.extra.values())
            prefix = '[dnsproxy:%s]' % (context,)
        else:
            prefix = '[dnsproxy]'
        return ('%s %s' % (prefix, msg), kwargs)
def get_logger_adapter_default(extra):
    """Default logger-adapter factory: wrap the 'dnsrewriteproxy' logger in a
    DnsProxyLoggerAdapter carrying the given extra context values.
    """
    base_logger = logging.getLogger('dnsrewriteproxy')
    return DnsProxyLoggerAdapter(base_logger, extra)
def get_resolver_logger_adapter_default(parent_adapter):
    """Return a factory that builds ResolverLoggerAdapters chained to
    parent_adapter, so resolver log lines keep the proxy's request context.
    """
    def _make_adapter(dns_extra):
        return ResolverLoggerAdapter(parent_adapter, dns_extra)
    return _make_adapter
def DnsProxy(
        get_resolver=get_resolver_default,
        get_logger_adapter=get_logger_adapter_default,
        get_resolver_logger_adapter=get_resolver_logger_adapter_default,
        get_socket=get_socket_default, num_workers=1000,
        rules=(),
):
    """Build a rewriting DNS proxy and return its ``start`` coroutine.

    Awaiting ``start()`` creates the listening socket and the upstream
    resolver synchronously (so construction errors raise immediately) and
    returns the asyncio task running the server. Cancelling that task
    triggers a graceful shutdown: already-queued requests are drained,
    worker tasks are cancelled, the socket is closed, and the resolver
    cache is cleared.

    :param get_resolver: zero-argument callable returning a
        ``(resolve, clear_cache)`` pair (unpacked in ``start`` below)
    :param get_logger_adapter: callable mapping an ``extra`` dict to a
        LoggerAdapter used for proxy-level logging
    :param get_resolver_logger_adapter: callable mapping a parent adapter to
        a per-request factory of resolver logger adapters
    :param get_socket: zero-argument callable returning the bound,
        non-blocking UDP socket to serve on
    :param num_workers: number of concurrent upstream workers; also the
        maximum number of requests buffered in the internal queue
    :param rules: iterable of ``(pattern, replace)`` regex pairs; the first
        pattern matching an incoming A-query name is applied via ``re.subn``
        and the rewritten name is resolved upstream. Queries matching no
        rule are answered with REFUSED.
    """
    # Subset of DNS response codes (RCODEs) used when building error replies
    class ERRORS(IntEnum):
        FORMERR = 1
        SERVFAIL = 2
        NXDOMAIN = 3
        REFUSED = 5

    loop = get_running_loop()
    logger = get_logger_adapter({})
    # Alphabet for the short random ids that correlate log lines per request
    request_id_alphabet = string.ascii_letters + string.digits

    # The "main" task of the server: it receives incoming requests and puts
    # them in a queue that is then fetched from and processed by the proxy
    # workers
    async def server_worker(sock, resolve, stop):
        upstream_queue = Queue(maxsize=num_workers)
        # We have multiple upstream workers to be able to send multiple
        # requests upstream concurrently
        upstream_worker_tasks = [
            create_task(upstream_worker(sock, resolve, upstream_queue))
            for _ in range(0, num_workers)]
        try:
            while True:
                logger.info('Waiting for next request')
                # NOTE(review): reads at most 512 bytes per datagram — the
                # classic plain-UDP DNS message limit; confirm EDNS0-sized
                # requests are not expected here
                request_data, addr = await recvfrom(loop, [sock], 512)
                request_logger = get_logger_adapter(
                    {'dnsrewriteproxy_requestid': ''.join(choices(request_id_alphabet, k=8))})
                request_logger.info('Received request from %s', addr)
                # Blocks when the queue is full, applying backpressure on
                # how fast requests are accepted
                await upstream_queue.put((request_logger, request_data, addr))
        finally:
            # Reached on cancellation of the server task: drain the queue so
            # already-accepted requests still get responses, then stop the
            # workers and run the final cleanup callback
            logger.info('Stopping: waiting for requests to finish')
            await upstream_queue.join()
            logger.info('Stopping: cancelling workers...')
            for upstream_task in upstream_worker_tasks:
                upstream_task.cancel()
            for upstream_task in upstream_worker_tasks:
                try:
                    await upstream_task
                except CancelledError:
                    pass
            logger.info('Stopping: cancelling workers... (done)')
            logger.info('Stopping: final cleanup')
            await stop()
            logger.info('Stopping: done')

    async def upstream_worker(sock, resolve, upstream_queue):
        # Worker loop: take a request off the queue, build the response, and
        # send it back to the client. Runs until cancelled by server_worker.
        while True:
            request_logger, request_data, addr = await upstream_queue.get()
            try:
                request_logger.info('Processing request')
                response_data = await get_response_data(request_logger, resolve, request_data)
                # Sendto for non-blocking UDP sockets cannot raise a BlockingIOError
                # https://stackoverflow.com/a/59794872/1319998
                sock.sendto(response_data, addr)
            except Exception:
                # Deliberately broad: a failing request must never kill the
                # worker; the error is logged and the worker keeps serving
                request_logger.exception('Error processing request')
            finally:
                request_logger.info('Finished processing request')
                # task_done pairs with upstream_queue.join() in server_worker
                upstream_queue.task_done()

    async def get_response_data(request_logger, resolve, request_data):
        """Parse the raw request and return the packed response bytes.

        Proxy failures are converted into a SERVFAIL response so the client
        still gets an answer it can match to its query.
        """
        # This may raise an exception, which is handled at a higher level.
        # We can't [and I suspect shouldn't try to] return an error to the
        # client, since we're not able to extract the QID, so the client won't
        # be able to match it with an outgoing request
        query = parse(request_data)
        try:
            return pack(await proxy(request_logger, resolve, query))
        except Exception:
            request_logger.exception('Failed to proxy %s', query)
            return pack(error(query, ERRORS.SERVFAIL))

    async def proxy(request_logger, resolve, query):
        """Resolve a single parsed query, applying the rewrite rules.

        Returns a response Message: REFUSED for non-A queries and for names
        matching no rule; NXDOMAIN / the upstream rcode on resolver errors;
        otherwise an answer built from the resolved A records.
        """
        name_bytes = query.qd[0].name
        request_logger.info('Name: %s', name_bytes)
        # Lowercase first so IDNA decoding sees a canonical form
        name_str_lower = query.qd[0].name.lower().decode('idna')
        request_logger.info('Decoded: %s', name_str_lower)
        # Only A (IPv4) queries are proxied; everything else is refused
        if query.qd[0].qtype != TYPES.A:
            request_logger.info('Unhandled query type: %s', query.qd[0].qtype)
            return error(query, ERRORS.REFUSED)
        # First matching rule wins; its substitution gives the upstream name
        for pattern, replace in rules:
            rewritten_name_str, num_matches = re.subn(pattern, replace, name_str_lower)
            if num_matches:
                request_logger.info('Matches rule (%s, %s)', pattern, replace)
                break
        else:
            # No break was triggered, i.e. no match
            request_logger.info('Does not match a rule')
            return error(query, ERRORS.REFUSED)
        try:
            ip_addresses = await resolve(
                rewritten_name_str, TYPES.A,
                get_logger_adapter=get_resolver_logger_adapter(request_logger))
        except DnsRecordDoesNotExist:
            request_logger.info('Does not exist')
            return error(query, ERRORS.NXDOMAIN)
        except DnsResponseCode as dns_response_code_error:
            # Pass the upstream error code straight through to the client
            request_logger.info('Received error from upstream: %s',
                                dns_response_code_error.args[0])
            return error(query, dns_response_code_error.args[0])
        request_logger.info('Resolved to %s', ip_addresses)
        now = loop.time()

        def ttl(ip_address):
            # Remaining cache lifetime of the resolved address, floored at 0
            return int(max(0.0, ip_address.expires_at - now))

        # Answer records echo the *original* (pre-rewrite) name, so the
        # client sees a response to the name it asked for
        reponse_records = tuple(
            ResourceRecord(name=name_bytes, qtype=TYPES.A,
                           qclass=1, ttl=ttl(ip_address), rdata=ip_address.packed)
            for ip_address in ip_addresses
        )
        return Message(
            qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0, rd=0, ra=1, z=0, rcode=0,
            qd=query.qd, an=reponse_records, ns=(), ar=(),
        )

    async def start():
        """Create the socket and resolver, and start the server task."""
        # The socket is created synchronously and passed to the server worker,
        # so if there is an error creating it, this function will raise an
        # exception. If no exception is raised, we are indeed listening
        sock = get_socket()
        # The resolver is also created synchronously, since it can parse
        # /etc/hosts or /etc/resolv.conf, and can raise an exception if
        # something goes wrong with that
        resolve, clear_cache = get_resolver()

        async def stop():
            # Final cleanup, run by server_worker after its workers stop
            sock.close()
            await clear_cache()

        return create_task(server_worker(sock, resolve, stop))

    return start
def error(query, rcode):
    """Build a DNS response for *query* carrying the given error ``rcode``
    and no answer, authority, or additional records.
    """
    no_records = ()
    return Message(
        qid=query.qid, qr=RESPONSE, opcode=0, aa=0, tc=0,
        rd=0, ra=1, z=0, rcode=rcode,
        qd=query.qd, an=no_records, ns=no_records, ar=no_records,
    )