// bind-conf/dns_zone/src/reader.rs
//
// (source-viewer metadata: 293 lines, 8.9 KiB, Rust)

use crate::error::{Result, Error, ErrorType};
use crate::name::Name;
use crate::parser::{Parser, Token, TokenStream};
use crate::rtype::RType;
use crate::context::CtxString;
use std::collections::{HashMap, HashSet};
/// Records read from a zone file, keyed by the text form (`to_string()`) of
/// each record's absolute owner name.
type Zone = HashMap<String, HashSet<ResourceRecord>>;
/// Reader that turns zone-file text into a [`Zone`].
///
/// Created with [`ZoneReader::from_str`] and optionally configured with the
/// `origin`, `class` and `ttl` builder methods. It is then driven either by
/// [`ZoneReader::read`] or one line at a time via [`ZoneReader::forward`].
#[derive(Debug, PartialEq)]
pub struct ZoneReader<'a> {
    /// Tokens of the input text, produced by `Parser::tokens`.
    tokens: TokenStream<'a>,
    /// Name that relative owner names are completed against; replaced by `$ORIGIN`.
    origin: Name,
    /// Absolute owner of the most recently inserted record, reused by record
    /// lines that start with a blank token.
    last_domain: Option<Name>,
    /// Class applied to records that omit one; overwritten by every inserted record.
    last_class: Option<String>,
    /// TTL applied to records that omit one; overwritten by `$TTL` and by
    /// every inserted record.
    last_ttl: Option<u32>,
    /// Records collected so far.
    zone: Zone,
}
/// A single resource record: its type and data, its class and its TTL.
///
/// The owner name is not stored here. It is the key under which the record is
/// filed in a [`Zone`].
#[derive(Debug, PartialEq, Hash, Eq)]
pub struct ResourceRecord {
    /// Record type together with its data, built by `RType::from_data`.
    pub rtype: RType,
    /// Class mnemonic: "IN", "CH" or "HS" when given on a record line,
    /// otherwise the inherited/configured default.
    pub class: String,
    /// Time to live as written in the zone file (presumably seconds).
    pub ttl: u32,
}
/// A leading field of a record line, classified as either a class mnemonic
/// or a numeric TTL.
enum ClassTtl {
    /// Class mnemonic such as "IN".
    Class(String),
    /// TTL parsed from an all-digit field.
    Ttl(u32)
}
impl<'a> ZoneReader<'a> {
    /// Creates a reader over the zone-file text `input`.
    ///
    /// The origin starts as the root name, and no default class or TTL is set.
    /// Configure them with [`ZoneReader::origin`], [`ZoneReader::class`] and
    /// [`ZoneReader::ttl`] before calling [`ZoneReader::read`].
    pub fn from_str(input: &'a str) -> Self {
        ZoneReader {
            tokens: Parser::from_str(input).tokens(),
            origin: Name::root(),
            last_domain: None,
            last_class: None,
            last_ttl: None,
            zone: HashMap::new(),
        }
    }

    /// Sets the origin that relative owner names are resolved against.
    ///
    /// The origin is also handed to `RType::from_data` when record data is built.
    pub fn origin(mut self, origin: Name) -> Self {
        self.origin = origin;
        self
    }

    /// Sets the class used by records that do not state one.
    ///
    /// Every inserted record replaces this default with its own class.
    pub fn class(mut self, class: &str) -> Self {
        self.last_class = Some(class.into());
        self
    }

    /// Sets the TTL used by records that do not state one.
    ///
    /// `$TTL` directives and every inserted record replace this default.
    pub fn ttl(mut self, ttl: u32) -> Self {
        self.last_ttl = Some(ttl);
        self
    }

    /// Reads the whole input and returns the collected zone.
    ///
    /// # Errors
    ///
    /// Stops at, and returns, the first error reported by [`ZoneReader::forward`].
    pub fn read(mut self) -> Result<Zone> {
        while self.forward()?.is_some() {}
        Ok(self.zone)
    }

    /// Processes the next line of input.
    ///
    /// Returns `Ok(None)` once the token stream is exhausted, and `Ok(Some(()))`
    /// after a line has been handled. A line may be:
    /// * empty;
    /// * a record with an explicit owner;
    /// * a record starting with a blank, which reuses the previous owner;
    /// * a `$` control entry.
    ///
    /// # Errors
    ///
    /// * tokenizer errors;
    /// * `NoRecord` when a line names an owner but holds no record;
    /// * `NoOwner` when a blank-led record has no previous owner;
    /// * any error from record parsing or control-entry handling.
    ///
    /// # Panics
    ///
    /// Panics if `RType::from_data` rejects a record's type or data.
    pub fn forward(&mut self) -> Result<Option<()>> {
        let token = match self.tokens.next() {
            Some(token) => token?,
            None => return Ok(None),
        };
        match token {
            Token::EndOfLine => {}
            Token::Blank => {
                // A blank-led line without any fields is just an empty line.
                if let Some(record) = self.read_record()? {
                    self.insert_record(None, record)?;
                }
            }
            Token::Value(domain) => {
                let record = self.read_record()?.ok_or(ErrorType::NoRecord)?;
                self.insert_record(Some(Name::from_ctx_string(&domain)?), record)?;
            }
            Token::Control(entry) => self.interpret_control_entry(entry)?,
        }
        Ok(Some(()))
    }

    /// Files `record` under its owner and remembers its owner, class and TTL
    /// as defaults for the following lines.
    ///
    /// An explicit `partial_domain` is resolved against the current origin.
    /// Without one, the previous record's (already absolute) owner is reused.
    ///
    /// # Errors
    ///
    /// `NoOwner` if no owner is given and no record has been inserted yet.
    fn insert_record(&mut self, partial_domain: Option<Name>, record: ResourceRecord) -> Result<()> {
        let domain = match partial_domain {
            Some(domain) => self.get_absolute_domain(domain),
            // `last_domain` is stored in absolute form, so no need to resolve again.
            None => self.last_domain.clone().ok_or(ErrorType::NoOwner)?,
        };
        self.last_class = Some(record.class.clone());
        self.last_ttl = Some(record.ttl);
        self.zone.entry(domain.to_string()).or_default().insert(record);
        self.last_domain = Some(domain);
        Ok(())
    }

    /// Resolves `domain` against the current origin.
    fn get_absolute_domain(&self, domain: Name) -> Name {
        self.origin.prefixed_by(domain)
    }

    /// Reads the remaining fields of a record line and builds the record.
    ///
    /// The line layout is `[ttl] [class] type data...`, where TTL and class may
    /// appear in either order. A missing TTL or class falls back to the current
    /// default. Returns `Ok(None)` for a line without fields.
    ///
    /// # Errors
    ///
    /// * tokenizer errors;
    /// * `BadParameter` for a TTL that does not fit in `u32`;
    /// * `NoRecordType` when no type follows the TTL/class fields;
    /// * `NoClass` / `NoTtl` when a value is omitted and no default exists.
    fn read_record(&mut self) -> Result<Option<ResourceRecord>> {
        let line = self.read_token_line()?;
        if line.is_empty() {
            return Ok(None);
        }
        // At most the first two fields can be a TTL and/or a class; stop at the
        // first field that is neither. `index` then points at the record type.
        let mut ttl: Option<u32> = None;
        let mut class: Option<String> = None;
        let mut index = 0;
        for field in line.iter().take(2) {
            match Self::read_class_ttl(field.as_str())? {
                Some(ClassTtl::Class(value)) => class = Some(value),
                Some(ClassTtl::Ttl(value)) => ttl = Some(value),
                None => break,
            }
            index += 1;
        }
        let rtype = line.get(index).ok_or(ErrorType::NoRecordType)?.to_string();
        let data = line[(index + 1)..].to_vec();
        let rr = ResourceRecord {
            // TODO: error handling (a failure in `RType::from_data` currently panics)
            rtype: RType::from_data(rtype, &self.origin, data).unwrap(),
            class: class.or_else(|| self.last_class.clone()).ok_or(ErrorType::NoClass)?,
            ttl: ttl.or(self.last_ttl).ok_or(ErrorType::NoTtl)?,
        };
        Ok(Some(rr))
    }

    /// Classifies a leading record field as a TTL, a class, or neither.
    ///
    /// A non-empty all-digit field is a TTL. "IN", "CH" and "HS" are matched
    /// case-insensitively and stored in upper case. Anything else yields
    /// `Ok(None)`.
    ///
    /// # Errors
    ///
    /// `BadParameter` if an all-digit field overflows `u32` (this used to panic).
    fn read_class_ttl(value: &str) -> Result<Option<ClassTtl>> {
        if !value.is_empty() && value.bytes().all(|b| b.is_ascii_digit()) {
            let ttl = value.parse().map_err(|_| ErrorType::BadParameter)?;
            return Ok(Some(ClassTtl::Ttl(ttl)));
        }
        let class = ["IN", "CH", "HS"]
            .iter()
            .copied()
            .find(|known| value.eq_ignore_ascii_case(known));
        Ok(class.map(|known| ClassTtl::Class(known.to_string())))
    }

    /// Collects the value tokens up to the end of the current line.
    ///
    /// The non-value token that ends the line is consumed and discarded.
    ///
    /// # Errors
    ///
    /// Tokenizer errors are propagated rather than silently ending the line.
    fn read_token_line(&mut self) -> Result<Vec<CtxString>> {
        let mut line = Vec::new();
        while let Some(token) = self.tokens.next() {
            match token? {
                Token::Value(value) => line.push(value),
                _ => break,
            }
        }
        Ok(line)
    }

    /// Handles a control entry, given its name without the leading `$`.
    ///
    /// * `$ORIGIN <name>` resolves `<name>` against the current origin and
    ///   makes the result the new origin.
    /// * `$TTL <ttl>` sets the default TTL.
    /// * `$INCLUDE <file> [<origin>]` is accepted but ignored.
    ///
    /// # Errors
    ///
    /// * `UnknownControlEntry` for any other entry name;
    /// * `BadParameter` for a wrong number of parameters or an invalid TTL;
    /// * errors from `Name::from_ctx_string` for an invalid `$ORIGIN` name.
    fn interpret_control_entry(&mut self, entry: CtxString) -> Result<()> {
        // Consume the rest of the line first, so that a rejected entry does not
        // leave its parameters behind in the token stream.
        let line = self.read_token_line()?;
        match (*entry).as_str() {
            "ORIGIN" if line.len() == 1 => {
                self.origin = self.get_absolute_domain(Name::from_ctx_string(&line[0])?);
            }
            "TTL" if line.len() == 1 => {
                let ttl = line[0].parse().map_err(|_| ErrorType::BadParameter)?;
                self.last_ttl = Some(ttl);
            }
            // TODO: implement `$INCLUDE`; until then it is silently discarded.
            "INCLUDE" if (1..=2).contains(&line.len()) => {}
            "ORIGIN" | "TTL" | "INCLUDE" => return Err(ErrorType::BadParameter.into()),
            _ => return Err(ErrorType::UnknownControlEntry.into()),
        }
        Ok(())
    }
}
#[cfg(test)]
impl ResourceRecord {
    /// Test helper: builds a record from its TTL, class, type mnemonic and data
    /// fields. `origin` is a name string (e.g. `"."`) that is parsed and passed
    /// to `RType::from_data`.
    ///
    /// Panics if the origin cannot be parsed or the record data is rejected.
    fn new(ttl: u32, class: &str, rtype: &str, data: &[&str], origin: &str) -> Self {
        let origin: Name = origin.parse().unwrap();
        let fields = data.iter().map(|&field| field.into()).collect();
        let rtype = RType::from_data(rtype.to_string(), &origin, fields).unwrap();
        ResourceRecord {
            rtype,
            class: class.to_string(),
            ttl,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Single record lines with the TTL and class given, swapped or omitted.
    /// Omitted values fall back to the reader defaults (TTL 600, class "IN").
    #[test]
    fn test_read_record() {
        let inputs = [
            "300 IN A 198.51.100.1",
            "IN 300 A 198.51.100.2",
            "TXT 300",
            "300 TXT hello",
            "A 198.51.100.3"
        ].iter();
        let expected = [
            ResourceRecord::new(300, "IN", "A", &["198.51.100.1"], "."),
            ResourceRecord::new(300, "IN", "A", &["198.51.100.2"], "."),
            ResourceRecord::new(600, "IN", "TXT", &["300"], "."),
            ResourceRecord::new(300, "IN", "TXT", &["hello"], "."),
            ResourceRecord::new(600, "IN", "A", &["198.51.100.3"], "."),
        ];
        for (input, expected) in inputs.zip(expected.iter()) {
            let mut reader = ZoneReader::from_str(input).ttl(600).class("IN");
            let record = reader.read_record();
            assert_eq!(*expected, record.unwrap().unwrap());
        }
    }

    /// A small zone exercising:
    /// * `$TTL` and `$ORIGIN`;
    /// * a parenthesised multi-line SOA;
    /// * owner reuse on a blank-led line, and `@`;
    /// * TTL inheritance from the previous record.
    #[test]
    fn test_read_zone() {
        // The `NS` line must start with whitespace. A blank-led line reuses the
        // previous owner, which is why `expected` files the NS record under
        // `example.com.`.
        let input = r#"
$TTL 300 ; 5 minutes
$ORIGIN .
example.com IN SOA ns.example.com. admin\.domain.example.com. (
2020250101 ; serial
28800 ; refresh (8 hours)
7200 ; retry (2 hours)
2419200 ; expire (4 weeks)
300 ; minimum (5 minutes)
)
    NS ns.example.com.
; an empty line
$ORIGIN example.com.
srv1 600 A 198.51.100.3
srv1.example.com. AAAA 2001:db8:cafe:bc68::2
www CNAME srv1
@ CNAME srv1"#;
        let mut expected = Zone::new();
        let mut record_list1 = HashSet::new();
        record_list1.insert(ResourceRecord::new(
            300, "IN", "SOA",
            &["ns.example.com.", "admin\\.domain.example.com.", "2020250101", "28800", "7200", "2419200", "300"], "."));
        record_list1.insert(ResourceRecord::new(300, "IN", "NS", &["ns.example.com."], "."));
        record_list1.insert(ResourceRecord::new(600, "IN", "CNAME", &["srv1"], "example.com."));
        expected.insert("example.com.".to_owned(), record_list1);
        let mut record_list2 = HashSet::new();
        record_list2.insert(ResourceRecord::new(600, "IN", "A", &["198.51.100.3"], "example.com."));
        record_list2.insert(ResourceRecord::new(600, "IN", "AAAA", &["2001:db8:cafe:bc68::2"], "example.com."));
        expected.insert("srv1.example.com.".to_owned(), record_list2);
        let mut record_list3 = HashSet::new();
        record_list3.insert(ResourceRecord::new(600, "IN", "CNAME", &["srv1"], "example.com."));
        expected.insert("www.example.com.".to_owned(), record_list3);
        let zone = ZoneReader::from_str(input)
            .origin("example.com.".parse::<Name>().unwrap())
            .read().unwrap();
        assert_eq!(expected, zone);
    }
}