231 lines
7.7 KiB
Python
231 lines
7.7 KiB
Python
import argparse
|
|
import itertools
|
|
import logging
|
|
|
|
import dns.query
|
|
import dns.message
|
|
import dns.resolver
|
|
import dns.rdatatype
|
|
import dns.exception
|
|
|
|
from prometheus_client import CollectorRegistry, generate_latest, CONTENT_TYPE_LATEST
|
|
from prometheus_client.core import GaugeMetricFamily
|
|
|
|
from wsgiref.simple_server import make_server
|
|
from pyramid.config import Configurator
|
|
from pyramid.response import Response
|
|
from pyramid.httpexceptions import HTTPBadRequest
|
|
|
|
# Version string of this exporter.
__version__ = '0.3.1'

# Configure root logging at INFO level and create the module-level logger
# used by the collector and the web server helpers.
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)
|
|
|
|
|
|
class DnsCollector:
    """Prometheus collector that probes the nameservers of a single DNS zone.

    For every nameserver address the collector queries the zone's SOA record
    (with DNSSEC records requested) and its NS set, and exposes the results
    as gauge metric families from :meth:`collect`.

    NOTE(review): the metric families are created once in ``__init__`` and
    only ever appended to, so calling ``collect()`` more than once on the
    same instance presumably accumulates samples; ``probe_view`` builds a
    fresh collector per request.
    """

    def __init__(self, zone, nameservers=None, ipv4=True, ipv6=True, query_timeout=2):
        """Store the probe settings and create the (empty) metric families.

        Args:
            zone: Name of the zone to probe.
            nameservers: Optional iterable of nameserver addresses to query.
                When empty or ``None`` the addresses are discovered via
                :meth:`fetch_ns` on the first ``collect()``.
            ipv4: Resolve A records when discovering nameserver addresses.
            ipv6: Resolve AAAA records when discovering nameserver addresses.
            query_timeout: Timeout passed to each direct nameserver query.
        """
        self.zone = zone
        # Copy into a fresh list so instances never share a default list or
        # alias a list owned by the caller.
        self.nameservers = list(nameservers) if nameservers is not None else []
        self.ipv4 = ipv4
        self.ipv6 = ipv6
        self.query_timeout = query_timeout

        self.ns_resolve_sucess_metrics = GaugeMetricFamily(
            'dns_probe_resolve_nameservers_success',
            'Probe sucessfully managed to fetch the list of ns',
            labels=['zone']
        )

        self.soa_serial_metrics = GaugeMetricFamily(
            'dns_probe_soa_serial',
            'Serial of SOA',
            labels=['zone', 'nameserver']
        )

        self.rrsig_expiration_metrics = GaugeMetricFamily(
            'dns_probe_soa_rrsig_expiration',
            'Expiration date of DNSSEC signature',
            labels=['zone', 'nameserver', 'keytag']
        )

        self.ns_set_metrics = GaugeMetricFamily(
            'dns_probe_ns_set',
            'List of nameservers',
            labels=['zone', 'target', 'nameserver']
        )

        self.query_success_metrics = GaugeMetricFamily(
            'dns_probe_query_success',
            'Status of DNS query to nameserver',
            labels=['name', 'type', 'nameserver']
        )

    def fetch_ns(self):
        """Resolve the zone's NS records and store every address they map to."""
        self.nameservers = [
            addr
            for ns in self.resolve(self.zone, dns.rdatatype.NS)
            for addr in self.resolve_addr(ns.target)
        ]

    def resolve_addr(self, qname):
        """Return an iterator over the A and/or AAAA addresses of ``qname``.

        Which record types are looked up is controlled by ``self.ipv4`` and
        ``self.ipv6``.
        """
        # NOTE(review): dns.resolver.resolve raises NoAnswer by default, so a
        # name lacking one of the enabled address families presumably aborts
        # the whole lookup - confirm whether that is intended.
        a_records = aaaa_records = ()
        if self.ipv4:
            a_records = self.resolve(qname, dns.rdatatype.A)
        if self.ipv6:
            aaaa_records = self.resolve(qname, dns.rdatatype.AAAA)

        return (addr.address for addr in itertools.chain(a_records, aaaa_records))

    def resolve(self, qname, qtype):
        """Resolve ``qname`` via ``dns.resolver.resolve`` and return its rdatas.

        Only RRsets in the answer section whose type equals ``qtype`` are
        included in the returned iterator.
        """
        # The lookup itself happens here, before the iterator is returned.
        answer = dns.resolver.resolve(qname, qtype)
        return itertools.chain.from_iterable(
            rrset.items.keys()
            for rrset in answer.response.answer
            if rrset.rdtype == qtype
        )

    def query(self, qname, qtype, ns, dnssec=False):
        """Send one query directly to nameserver ``ns`` and return the response.

        A ``dns_probe_query_success`` sample (1 or 0) is recorded for the
        query either way.

        Raises:
            dns.exception.DNSException, OSError: re-raised after the failure
                sample has been recorded.
        """
        try:
            res, _is_tcp = dns.query.udp_with_fallback(
                dns.message.make_query(qname, qtype, want_dnssec=dnssec),
                ns,
                timeout=self.query_timeout
            )
        except (dns.exception.DNSException, OSError):
            self.query_success_metrics.add_metric([qname, qtype, ns], 0)
            raise
        else:
            self.query_success_metrics.add_metric([qname, qtype, ns], 1)
            return res

    def check_soa(self, ns):
        """Record the SOA serial and RRSIG expirations reported by ``ns``."""
        res_soa = self.query(self.zone, 'SOA', ns, dnssec=True)
        soa = None
        rrsig_set = []

        for rrset in res_soa.answer:
            if rrset.rdtype == dns.rdatatype.SOA:
                soa = list(rrset.items.keys())[0]
            if rrset.rdtype == dns.rdatatype.RRSIG:
                rrsig_set = list(rrset.items.keys())

        if soa is None:
            # An answer without an SOA record must not raise AttributeError:
            # collect() does not catch it, so it would abort the whole scrape.
            logger.warning(f'NS {ns} returned no SOA record for zone {self.zone}')
        else:
            self.soa_serial_metrics.add_metric([self.zone, ns], soa.serial)

        for rrsig in rrsig_set:
            self.rrsig_expiration_metrics.add_metric([self.zone, ns, str(rrsig.key_tag)], rrsig.expiration)

    def list_ns(self, ns):
        """Record one ``dns_probe_ns_set`` sample per NS target reported by ``ns``."""
        res_ns = self.query(self.zone, 'NS', ns)
        ns_set = []
        for rrset in res_ns.answer:
            if rrset.rdtype == dns.rdatatype.NS:
                ns_set = list(rrset.items.keys())
        for ns_record in ns_set:
            target = ns_record.target.to_text()
            self.ns_set_metrics.add_metric([self.zone, target, ns], 1)

    def collect(self):
        """Run the probe and yield the populated metric families.

        When no nameservers were supplied they are discovered first, and the
        outcome is recorded in ``dns_probe_resolve_nameservers_success``.
        Failures of individual nameserver queries are logged and do not stop
        the remaining checks.
        """
        if not self.nameservers:
            try:
                self.fetch_ns()
            except dns.exception.Timeout:
                # Timeout is handled before the generic DNSException clause so
                # it is logged without a traceback; it is still a failed
                # resolution, so the success gauge must report 0.
                logger.error(f'Timeout while querying for NS for zone {self.zone}')
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 0)
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to get fetch nameservers for zone {self.zone}')
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 0)
            else:
                self.ns_resolve_sucess_metrics.add_metric([self.zone], 1)

        for ns in self.nameservers:
            try:
                self.check_soa(ns)
            except dns.exception.Timeout:
                logger.warning(f'NS {ns} timeout while querying for SOA for zone {self.zone}')
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to get SOA metrics from nameserver {ns} for zone {self.zone}')

            try:
                self.list_ns(ns)
            except dns.exception.Timeout:
                logger.warning(f'NS {ns} timeout while querying for NS for zone {self.zone}')
            except (dns.exception.DNSException, OSError):
                logger.exception(f'Failed to list NS from nameserver {ns} for zone {self.zone}')

        yield self.ns_resolve_sucess_metrics
        yield self.ns_set_metrics
        yield self.soa_serial_metrics
        yield self.rrsig_expiration_metrics
        yield self.query_success_metrics
|
|
|
|
|
|
def parse_bool(val, name, default):
    """Interpret a textual flag as a boolean.

    Args:
        val: Raw string value, or ``None`` when the parameter is absent.
        name: Parameter name, used only in the error message.
        default: Value returned when ``val`` is ``None``.

    Raises:
        HTTPBadRequest: if ``val`` is not one of the recognised spellings.
    """
    if val is None:
        return default

    # Accepted spellings, compared case-insensitively.
    meanings = {
        'true': True, '1': True, 'enabled': True, 'yes': True,
        'false': False, '0': False, 'disabled': False, 'no': False,
    }
    parsed = meanings.get(val.lower())
    if parsed is None:
        raise HTTPBadRequest(f'unknown value for param {name}')
    return parsed
|
|
|
|
def parse_float(val, name, default):
    """Convert a textual parameter to ``float``.

    Args:
        val: Raw string value, or ``None`` when the parameter is absent.
        name: Parameter name, used only in the error message.
        default: Value returned unchanged when ``val`` is ``None``.

    Raises:
        HTTPBadRequest: if ``val`` cannot be converted to a float.
    """
    if val is not None:
        try:
            return float(val)
        except ValueError:
            raise HTTPBadRequest(f'Could not convert the value of {name} to float')
    return default
|
|
|
|
def probe_view(request):
    """Handle a probe request: run a DNS probe and return its metrics.

    Reads the ``zone``, ``nameservers[]``, ``ipv4``, ``ipv6`` and
    ``query_timeout`` request parameters, collects the metrics through a
    dedicated registry and returns them as the response body.

    Raises:
        HTTPBadRequest: if ``zone`` is missing or another parameter cannot
            be parsed.
    """
    params = request.params
    zone = params.get('zone')
    nameservers = params.getall('nameservers[]')
    # The optional parameters are parsed before the zone check, so a
    # malformed value is reported even when ``zone`` is missing as well.
    ipv4 = parse_bool(params.get('ipv4'), 'ipv4', True)
    ipv6 = parse_bool(params.get('ipv6'), 'ipv6', True)
    query_timeout = parse_float(params.get('query_timeout'), 'query_timeout', 2)

    if zone is None:
        raise HTTPBadRequest('zone parameter is required')

    # Ensure the zone name carries a trailing dot.
    if not zone.endswith('.'):
        zone += '.'

    # A fresh registry per request holds only this probe's collector.
    registry = CollectorRegistry()
    collector = DnsCollector(
        zone,
        nameservers=nameservers,
        ipv4=ipv4,
        ipv6=ipv6,
        query_timeout=query_timeout,
    )
    registry.register(collector)

    return Response(generate_latest(registry), content_type=CONTENT_TYPE_LATEST)
|
|
|
|
|
|
def make_app():
    """Build the WSGI application exposing ``probe_view`` at ``/probe`` (GET only)."""
    with Configurator() as configurator:
        configurator.add_route('probe', '/probe')
        configurator.add_view(
            probe_view,
            route_name='probe',
            request_method='GET',
        )
        return configurator.make_wsgi_app()
|
|
|
|
|
|
def serve(ip='127.0.0.1', port=8953):
    """Serve the probe application on ``ip``:``port`` until interrupted.

    Args:
        ip: Address to bind the web server to.
        port: TCP port to listen on.
    """
    application = make_app()
    httpd = make_server(ip, port, application)
    logger.info(f'Starting webserver on {ip}:{port}')
    httpd.serve_forever()
|
|
|
|
|
|
def parse_listen(listen_str):
    """Parse a ``[ip]:port`` listen specification.

    The string is split on its last colon. When the address part is empty
    (``':8953'`` or a bare ``'8953'``) the address defaults to ``0.0.0.0``.

    Args:
        listen_str: Listen specification such as ``'127.0.0.1:8953'``.

    Returns:
        A dict with the keys ``'ip'`` (str) and ``'port'`` (int).

    Raises:
        ValueError: if the port is empty, not an integer, or outside the
            range 0-65535.
    """
    ip, _sep, port = listen_str.rpartition(':')
    if ip == '':
        ip = '0.0.0.0'
    if port == '':
        raise ValueError('Port can not be empty')

    port_number = int(port)
    # Reject impossible ports here so the error is a ValueError raised while
    # parsing, rather than a failure when the socket is bound later.
    if not 0 <= port_number <= 65535:
        raise ValueError(f'Port must be between 0 and 65535, got {port_number}')
    return {'ip': ip, 'port': port_number}
|
|
|
|
|
|
def main():
    """Command-line entry point: parse ``--listen`` and start the web server."""
    cli = argparse.ArgumentParser(
        description='DNS probe that exports Prometheus-like data',
    )
    cli.add_argument(
        '-l',
        '--listen',
        type=parse_listen,
        default='127.0.0.1:8953',
        help='Address to listen to, default %(default)s',
    )
    options = cli.parse_args()
    # ``options.listen`` is the dict produced by ``parse_listen``.
    serve(**options.listen)
|
|
|
|
|
|
# Start the exporter only when this file is executed as a script.
if __name__ == '__main__':
    main()
|