Source code for pkb_client.client.bind_file

import logging
from dataclasses import dataclass
from enum import Enum
from typing import Optional, List

from pkb_client.client.dns import DNSRecordType, DNS_RECORDS_WITH_PRIORITY


class RecordClass(str, Enum):
    """Record classes accepted in a BIND file.

    The enum mixes in ``str``, so members compare equal to their plain
    string value. Only ``IN`` is defined.
    """

    IN = "IN"

    def __str__(self) -> str:
        """Return the bare value (``"IN"``) rather than ``RecordClass.IN``."""
        return self.value
@dataclass
class BindRecord:
    """A single resource record line of a BIND file.

    ``str(record)`` renders the record as
    ``name ttl class type [prio] data [; comment]``.
    """

    name: str
    ttl: int
    record_class: RecordClass
    record_type: DNSRecordType
    data: str
    # Emitted between the record type and the data when it is not None.
    prio: Optional[int] = None
    # Appended as a trailing "; comment" when non-empty.
    comment: Optional[str] = None

    def __str__(self):
        """Render the record as one BIND file line (without a newline)."""
        tokens = [self.name, self.ttl, self.record_class, self.record_type]
        if self.prio is not None:
            tokens.append(self.prio)
        tokens.append(self.data)
        # Format through f-strings (not str()) so enum members are rendered
        # exactly the way direct interpolation renders them.
        rendered = " ".join(f"{token}" for token in tokens)
        if self.comment:
            return f"{rendered} ; {self.comment}"
        return rendered
class BindFile:
    """In-memory representation of a BIND file.

    Holds the ``$ORIGIN`` value, the optional ``$TTL`` default and the list
    of parsed records. ``str(bind_file)`` renders the file content again.
    """

    # Value of the ``$ORIGIN`` directive.
    origin: str
    # Value of the ``$TTL`` directive, or None when the file has none.
    ttl: Optional[int] = None
    # Records in the order they appear in the file.
    records: List[BindRecord]

    def __init__(
        self,
        origin: str,
        ttl: Optional[int] = None,
        records: Optional[List[BindRecord]] = None,
    ) -> None:
        """Create a BIND file object.

        :param origin: value for the ``$ORIGIN`` directive
        :param ttl: value for the ``$TTL`` directive; omitted when None
        :param records: initial records; defaults to an empty list
        """
        self.origin = origin
        self.ttl = ttl
        self.records = records or []

    @staticmethod
    def _split_comment(line: str) -> tuple:
        """Split a raw line into ``(content, comment)``.

        A comment starts at the first ``;`` that is neither inside a
        double-quoted string nor escaped with a backslash, so record data
        such as ``"v=DKIM1; k=rsa"`` is kept intact. Both parts are
        stripped; ``comment`` is None when the line has no comment.
        """
        in_quotes = False
        escaped = False
        for index, char in enumerate(line):
            if escaped:
                # the previous backslash consumes this character
                escaped = False
            elif char == "\\":
                escaped = True
            elif char == '"':
                in_quotes = not in_quotes
            elif char == ";" and not in_quotes:
                return line[:index].strip(), line[index + 1:].strip()
        return line.strip(), None

    @staticmethod
    def from_file(file_path: str) -> "BindFile":
        """Parse a BIND file into a :class:`BindFile`.

        Supported record layouts (TTL optional):

        1. ``name ttl record-class record-type record-data``
        2. ``name record-class ttl record-type record-data``
        3. ``name record-class record-type record-data``

        For layout 3 the ``$TTL`` value is used if one has been read,
        otherwise the TTL of the previously parsed record. ``@`` in a record
        name is replaced by the current origin. Records with an unsupported
        type or class, and lines with too few fields, are skipped with a
        logged warning.

        :param file_path: path of the file to read
        :return: the parsed file
        :raises ValueError: if no ``$ORIGIN`` precedes a record or the file
            has none at all, if a record without TTL has no default or
            previous TTL to fall back on, or if a priority field is missing
            or not an integer
        """
        with open(file_path, "r") as f:
            file_data = f.readlines()

        origin = None
        ttl = None
        records = []
        for raw_line in file_data:
            if raw_line.startswith("$ORIGIN"):
                origin = raw_line.split()[1]
                continue
            if raw_line.startswith("$TTL"):
                ttl = int(raw_line.split()[1])
                continue

            line, comment = BindFile._split_comment(raw_line)
            # skip empty and comment-only lines
            if not line:
                continue

            record_parts = line.split()
            if len(record_parts) < 3:
                logging.warning(f"Ignoring malformed record: {line}")
                continue

            # Work out which layout the line uses from the position of the
            # numeric TTL field.
            if record_parts[1].isdigit():
                # layout 1: name ttl class type data
                ttl_field, class_field, type_index = record_parts[1], record_parts[2], 3
            elif record_parts[2].isdigit():
                # layout 2: name class ttl type data
                ttl_field, class_field, type_index = record_parts[2], record_parts[1], 3
            else:
                # layout 3: name class type data (no ttl)
                ttl_field, class_field, type_index = None, record_parts[1], 2

            if len(record_parts) <= type_index:
                logging.warning(f"Ignoring malformed record: {line}")
                continue

            type_field = record_parts[type_index]
            if type_field not in DNSRecordType.__members__:
                logging.warning(f"Ignoring unsupported record type: {line}")
                continue
            if class_field not in RecordClass.__members__:
                logging.warning(f"Ignoring unsupported record class: {line}")
                continue

            if ttl_field is not None:
                record_ttl = int(ttl_field)
            elif ttl is not None:
                # explicit None check so that "$TTL 0" is honoured
                record_ttl = ttl
            elif records:
                record_ttl = records[-1].ttl
            else:
                raise ValueError("No TTL found in file")

            record_class = RecordClass[class_field]
            record_type = DNSRecordType[type_field]

            data_index = type_index + 1
            prio = None
            if record_type in DNS_RECORDS_WITH_PRIORITY:
                if len(record_parts) <= data_index:
                    raise ValueError(f"Missing priority in record: {line}")
                prio = int(record_parts[data_index])
                data_index += 1
            record_data = " ".join(record_parts[data_index:])

            # "@" can only be expanded once an origin is known
            if origin is None:
                raise ValueError("No origin found in file")
            record_name = record_parts[0].replace("@", origin)

            records.append(
                BindRecord(
                    record_name,
                    record_ttl,
                    record_class,
                    record_type,
                    record_data,
                    prio=prio,
                    comment=comment,
                )
            )

        if origin is None:
            raise ValueError("No origin found in file")

        return BindFile(origin, ttl, records)

    def to_file(self, file_path: str) -> None:
        """Write the rendered file content to ``file_path``, overwriting it."""
        with open(file_path, "w") as f:
            f.write(str(self))

    def __str__(self) -> str:
        """Render ``$ORIGIN``, the optional ``$TTL`` and one line per record."""
        lines = [f"$ORIGIN {self.origin}"]
        if self.ttl is not None:
            lines.append(f"$TTL {self.ttl}")
        lines.extend(f"{record}" for record in self.records)
        return "".join(f"{line}\n" for line in lines)