dns-probe/dns_probe/__init__.py

231 lines
7.7 KiB
Python

import argparse
import itertools
import logging
import dns.query
import dns.message
import dns.resolver
import dns.rdatatype
import dns.exception
from prometheus_client import CollectorRegistry, generate_latest, CONTENT_TYPE_LATEST
from prometheus_client.core import GaugeMetricFamily
from wsgiref.simple_server import make_server
from pyramid.config import Configurator
from pyramid.response import Response
from pyramid.httpexceptions import HTTPBadRequest
__version__ = '0.3.1'

# Configure the root logger when the module is imported, so INFO-level
# messages (such as the "Starting webserver" line) are emitted.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
class DnsCollector(object):
    """Prometheus collector that probes the nameservers of a single DNS zone.

    Each ``collect()`` call sends an SOA query (with DNSSEC records requested)
    and an NS query for ``zone`` directly to every nameserver, and yields the
    results as gauge metric families:

    * ``dns_probe_resolve_nameservers_success`` -- 1/0 for whether the
      nameserver list could be discovered (has samples only when discovery
      runs, i.e. when no nameservers were given);
    * ``dns_probe_soa_serial`` -- SOA serial reported by each nameserver;
    * ``dns_probe_soa_rrsig_expiration`` -- expiration of each RRSIG covering
      the SOA, labelled by key tag;
    * ``dns_probe_ns_set`` -- NS targets reported by each nameserver;
    * ``dns_probe_query_success`` -- 1/0 for every direct query sent.

    Samples are appended to families created in ``__init__``, so calling
    ``collect()`` repeatedly on one instance repeats samples.
    """

    def __init__(self, zone, nameservers=None, ipv4=True, ipv6=True, query_timeout=2):
        """
        :param zone: zone name to probe (``probe_view`` passes it with a
            trailing dot).
        :param nameservers: addresses to query directly. When empty or
            ``None``, they are discovered from the zone's NS records on the
            first ``collect()``.
        :param ipv4: use A records of NS targets during discovery.
        :param ipv6: use AAAA records of NS targets during discovery.
        :param query_timeout: timeout (seconds) for each direct query.
        """
        self.zone = zone
        # Copy: avoids a shared mutable default and aliasing the caller's list.
        self.nameservers = list(nameservers) if nameservers else []
        self.ipv4 = ipv4
        self.ipv6 = ipv6
        self.query_timeout = query_timeout
        # Attribute name (with its typo) kept for backward compatibility.
        self.ns_resolve_sucess_metrics = GaugeMetricFamily(
            'dns_probe_resolve_nameservers_success',
            'Probe successfully managed to fetch the list of ns',
            labels=['zone'],
        )
        self.soa_serial_metrics = GaugeMetricFamily(
            'dns_probe_soa_serial',
            'Serial of SOA',
            labels=['zone', 'nameserver'],
        )
        self.rrsig_expiration_metrics = GaugeMetricFamily(
            'dns_probe_soa_rrsig_expiration',
            'Expiration date of DNSSEC signature',
            labels=['zone', 'nameserver', 'keytag'],
        )
        self.ns_set_metrics = GaugeMetricFamily(
            'dns_probe_ns_set',
            'List of nameservers',
            labels=['zone', 'target', 'nameserver'],
        )
        self.query_success_metrics = GaugeMetricFamily(
            'dns_probe_query_success',
            'Status of DNS query to nameserver',
            labels=['name', 'type', 'nameserver'],
        )

    def fetch_ns(self):
        """Set ``self.nameservers`` to the addresses of the zone's NS targets.

        Exceptions from the NS lookup (``dns.exception.DNSException``
        subclasses, ``OSError``) propagate to the caller.
        """
        self.nameservers = [
            addr
            for ns in self.resolve(self.zone, dns.rdatatype.NS)
            for addr in self.resolve_addr(ns.target)
        ]

    def resolve_addr(self, qname):
        """Return an iterator over the enabled IPv4/IPv6 addresses of ``qname``.

        A host lacking records of one family (e.g. no AAAA) just contributes
        no addresses of that family instead of failing the whole lookup.
        """
        a_records = aaaa_records = ()
        if self.ipv4:
            a_records = self.resolve(qname, dns.rdatatype.A, raise_on_no_answer=False)
        if self.ipv6:
            aaaa_records = self.resolve(qname, dns.rdatatype.AAAA, raise_on_no_answer=False)
        return (rdata.address for rdata in itertools.chain(a_records, aaaa_records))

    def resolve(self, qname, qtype, raise_on_no_answer=True):
        """Resolve ``qname``/``qtype`` with ``dns.resolver.resolve`` and return its rdatas.

        Only answer RRsets of type ``qtype`` are kept, so e.g. CNAME records
        on the way to the final name are skipped.

        :param raise_on_no_answer: forwarded to ``dns.resolver.resolve``. When
            False, an empty answer yields no rdatas instead of raising
            ``dns.resolver.NoAnswer``.
        """
        answer = dns.resolver.resolve(qname, qtype, raise_on_no_answer=raise_on_no_answer)
        return itertools.chain.from_iterable(
            rrset for rrset in answer.response.answer if rrset.rdtype == qtype
        )

    def query(self, qname, qtype, ns, dnssec=False):
        """Send one query directly to nameserver ``ns`` and return the response.

        Uses ``dns.query.udp_with_fallback`` (UDP, TCP if needed). Records 1 in
        ``dns_probe_query_success`` on success and 0 on failure; failures are
        re-raised after being recorded.

        :raises dns.exception.DNSException, OSError: on query failure.
        :raises ValueError: if dnspython cannot parse ``ns`` as an address.
        """
        try:
            res, _is_tcp = dns.query.udp_with_fallback(
                dns.message.make_query(qname, qtype, want_dnssec=dnssec),
                ns,
                timeout=self.query_timeout,
            )
        except (dns.exception.DNSException, OSError, ValueError):
            self.query_success_metrics.add_metric([qname, qtype, ns], 0)
            raise
        self.query_success_metrics.add_metric([qname, qtype, ns], 1)
        return res

    def check_soa(self, ns):
        """Record the SOA serial and SOA RRSIG expirations reported by ``ns``."""
        res_soa = self.query(self.zone, 'SOA', ns, dnssec=True)
        soa = None
        rrsigs = []
        for rrset in res_soa.answer:
            if rrset.rdtype == dns.rdatatype.SOA:
                soa = next(iter(rrset), None)
            elif rrset.rdtype == dns.rdatatype.RRSIG and rrset.covers == dns.rdatatype.SOA:
                rrsigs = list(rrset)
        if soa is None:
            # E.g. a referral or empty answer: nothing to report, but don't
            # crash the whole scrape with an AttributeError.
            logger.warning(f'NS {ns} returned no SOA record for zone {self.zone}')
        else:
            self.soa_serial_metrics.add_metric([self.zone, ns], soa.serial)
        for rrsig in rrsigs:
            self.rrsig_expiration_metrics.add_metric(
                [self.zone, ns, str(rrsig.key_tag)], rrsig.expiration
            )

    def list_ns(self, ns):
        """Record (with value 1) each NS target that ``ns`` reports for the zone."""
        res_ns = self.query(self.zone, 'NS', ns)
        ns_records = []
        for rrset in res_ns.answer:
            if rrset.rdtype == dns.rdatatype.NS:
                ns_records = list(rrset)
        for ns_record in ns_records:
            self.ns_set_metrics.add_metric([self.zone, ns_record.target.to_text(), ns], 1)

    def collect(self):
        """Run the probe and yield all metric families (prometheus_client hook).

        If no nameservers are set, they are first discovered with ``fetch_ns``
        and the outcome is recorded as 1/0. Then every nameserver gets an SOA
        check and an NS listing; a failing check is logged and does not stop
        the others.
        """
        if not self.nameservers:
            try:
                self.fetch_ns()
            except dns.exception.Timeout:
                logger.error(f'Timeout while querying for NS for zone {self.zone}')
                # A timeout is a failed discovery as well; record it.
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 0)
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to fetch nameservers for zone {self.zone}')
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 0)
            else:
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 1)
        for ns in self.nameservers:
            try:
                self.check_soa(ns)
            except dns.exception.Timeout:
                logger.warning(f'NS {ns} timeout while querying for SOA for zone {self.zone}')
            except (dns.exception.DNSException, OSError, ValueError):
                # ValueError: ``ns`` is not a usable address (the list may come
                # straight from request parameters).
                logger.exception(
                    f'Failed to get SOA metrics from nameserver {ns} for zone {self.zone}'
                )
            try:
                self.list_ns(ns)
            except dns.exception.Timeout:
                logger.warning(f'NS {ns} timeout while querying for NS for zone {self.zone}')
            except (dns.exception.DNSException, OSError, ValueError):
                logger.exception(f'Failed to list NS from nameserver {ns} for zone {self.zone}')
        yield self.ns_resolve_sucess_metrics
        yield self.ns_set_metrics
        yield self.soa_serial_metrics
        yield self.rrsig_expiration_metrics
        yield self.query_success_metrics
def parse_bool(val, name, default):
    """Interpret an optional query-string value as a boolean flag.

    Matching is case-insensitive: ``true``/``1``/``enabled``/``yes`` mean True,
    ``false``/``0``/``disabled``/``no`` mean False.

    :param val: raw parameter value, or ``None`` when the parameter is absent.
    :param name: parameter name, used in the error message.
    :param default: returned unchanged when ``val`` is ``None``.
    :raises HTTPBadRequest: for any other value.
    """
    if val is None:
        return default
    choices = dict.fromkeys(('true', '1', 'enabled', 'yes'), True)
    choices.update(dict.fromkeys(('false', '0', 'disabled', 'no'), False))
    verdict = choices.get(val.lower())
    if verdict is None:
        raise HTTPBadRequest(f'unknown value for param {name}')
    return verdict
def parse_float(val, name, default):
    """Convert an optional query-string value with ``float()``.

    :param val: raw parameter value, or ``None`` when the parameter is absent.
    :param name: parameter name, used in the error message.
    :param default: returned unchanged when ``val`` is ``None``.
    :raises HTTPBadRequest: if ``float()`` rejects ``val``.
    """
    if val is not None:
        try:
            converted = float(val)
        except ValueError:
            raise HTTPBadRequest(f'Could not convert the value of {name} to float')
        return converted
    return default
def probe_view(request):
    """Handle ``GET /probe``: probe one zone and return its metrics.

    Query parameters:
      zone (required)   zone to probe; a trailing ``.`` is appended if missing.
      nameservers[]     repeatable; addresses to query instead of discovering
                        them from the zone's NS records.
      ipv4, ipv6        boolean flags (see ``parse_bool``), default true.
      query_timeout     per-query timeout in seconds, default 2; must be a
                        positive, finite number.

    :returns: a ``Response`` with the Prometheus text exposition format.
    :raises HTTPBadRequest: on a missing zone or an invalid parameter value.
    """
    zone = request.params.get('zone')
    nameservers = request.params.getall('nameservers[]')
    ipv4 = parse_bool(request.params.get('ipv4'), 'ipv4', True)
    ipv6 = parse_bool(request.params.get('ipv6'), 'ipv6', True)
    query_timeout = parse_float(request.params.get('query_timeout'), 'query_timeout', 2)
    # float() happily accepts 'nan', 'inf' and negative numbers, none of which
    # is a usable socket timeout. NaN fails every comparison, so this chained
    # test rejects it along with non-positive and infinite values.
    if not 0 < query_timeout < float('inf'):
        raise HTTPBadRequest('query_timeout must be a positive, finite number of seconds')
    if zone is None:
        raise HTTPBadRequest('zone parameter is required')
    if not zone.endswith('.'):
        zone += '.'
    registry = CollectorRegistry()
    registry.register(DnsCollector(
        zone,
        nameservers=nameservers,
        ipv4=ipv4,
        ipv6=ipv6,
        query_timeout=query_timeout,
    ))
    return Response(generate_latest(registry), content_type=CONTENT_TYPE_LATEST)
def make_app():
    """Build the Pyramid WSGI application that serves ``GET /probe``."""
    with Configurator() as config:
        config.add_route('probe', '/probe')
        config.add_view(probe_view, route_name='probe', request_method='GET')
        # Returning inside the block still runs the configurator's exit
        # handling after make_wsgi_app(), exactly as before.
        return config.make_wsgi_app()
def serve(ip='127.0.0.1', port=8953):
    """Serve the probe application on ``ip:port``; blocks until interrupted.

    The server is used as a context manager so its listening socket is
    closed (``server_close()``) when ``serve_forever`` returns or raises,
    e.g. on KeyboardInterrupt.
    """
    with make_server(ip, port, make_app()) as web_server:
        logger.info(f'Starting webserver on {ip}:{port}')
        web_server.serve_forever()
def parse_listen(listen_str):
    """Parse an ``[ip]:port`` listen address into keyword arguments for ``serve``.

    The text after the last ``:`` is the port and everything before it is the
    address. A missing address (``':8953'`` or just ``'8953'``) means
    ``0.0.0.0``.

    :returns: ``{'ip': str, 'port': int}``
    :raises ValueError: if the port is empty, not an integer, or outside
        0-65535 (argparse turns this into a usage error).
    """
    ip, _sep, port = listen_str.rpartition(':')
    if not ip:
        ip = '0.0.0.0'
    if not port:
        raise ValueError('Port can not be empty')
    port_number = int(port)
    # Reject here rather than letting socket.bind() fail later with an
    # OverflowError traceback.
    if not 0 <= port_number <= 65535:
        raise ValueError(f'Port must be between 0 and 65535, got {port_number}')
    return {'ip': ip, 'port': port_number}
def main():
    """Command-line entry point: parse ``--listen`` and run the web server."""
    parser = argparse.ArgumentParser(
        description='DNS probe that exports Prometheus-like data',
    )
    parser.add_argument(
        '-l', '--listen',
        type=parse_listen,
        default='127.0.0.1:8953',
        help='Address to listen to, default %(default)s',
    )
    listen_kwargs = parser.parse_args().listen
    serve(**listen_kwargs)
# Start the server when this file is executed directly as a script.
if __name__ == '__main__':
    main()