dnsfixture.py 4.6 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130
# SPDX-FileCopyrightText: 2024 Espressif Systems (Shanghai) CO LTD
# SPDX-License-Identifier: Unlicense OR CC0-1.0
import logging
import re
import socket
import sys

import dns.exception
import dns.message
import dns.query
import dns.rdataclass
import dns.rdatatype
import dns.resolver
  12. # Configure logging
  13. logging.basicConfig(level=logging.INFO)
  14. logger = logging.getLogger(__name__)
  15. class DnsPythonWrapper:
  16. def __init__(self, server='224.0.0.251', port=5353, retries=3):
  17. self.server = server
  18. self.port = port
  19. self.retries = retries
  20. def send_and_receive_query(self, query, timeout=3):
  21. logger.info(f'Sending DNS query to {self.server}:{self.port}')
  22. try:
  23. # Create a UDP socket
  24. with socket.socket(socket.AF_INET, socket.SOCK_DGRAM) as sock:
  25. sock.settimeout(timeout)
  26. # Send the DNS query
  27. query_data = query.to_wire()
  28. sock.sendto(query_data, (self.server, self.port))
  29. # Receive the DNS response
  30. response_data, _ = sock.recvfrom(512) # 512 bytes is the typical size for a DNS response
  31. # Parse the response
  32. response = dns.message.from_wire(response_data)
  33. return response
  34. except socket.timeout as e:
  35. logger.warning(f'DNS query timed out: {e}')
  36. return None
  37. except dns.exception.DNSException as e:
  38. logger.error(f'DNS query failed: {e}')
  39. return None
  40. def run_query(self, name, query_type='PTR', timeout=3):
  41. logger.info(f'Running DNS query for {name} with type {query_type}')
  42. query = dns.message.make_query(name, dns.rdatatype.from_text(query_type), dns.rdataclass.IN)
  43. # Print the DNS question section
  44. logger.info(f'DNS question section: {query.question}')
  45. # Send and receive the DNS query
  46. response = None
  47. for attempt in range(1, self.retries + 1):
  48. logger.info(f'Attempt {attempt}/{self.retries}')
  49. response = self.send_and_receive_query(query, timeout)
  50. if response:
  51. break
  52. if response:
  53. logger.info(f'DNS query response:\n{response}')
  54. else:
  55. logger.warning('No response received or response was invalid.')
  56. return response
  57. def parse_answer_section(self, response, query_type):
  58. answers = []
  59. if response:
  60. for answer in response.answer:
  61. if dns.rdatatype.to_text(answer.rdtype) == query_type:
  62. for item in answer.items:
  63. full_answer = (
  64. f'{answer.name} {answer.ttl} '
  65. f'{dns.rdataclass.to_text(answer.rdclass)} '
  66. f'{dns.rdatatype.to_text(answer.rdtype)} '
  67. f'{item.to_text()}'
  68. )
  69. answers.append(full_answer)
  70. return answers
  71. def check_record(self, name, query_type, expected=True, expect=None):
  72. output = self.run_query(name, query_type=query_type)
  73. answers = self.parse_answer_section(output, query_type)
  74. logger.info(f'answers: {answers}')
  75. if expect is None:
  76. expect = name
  77. if expected:
  78. assert any(expect in answer for answer in answers), f"Expected record '{expect}' not found in answer section"
  79. else:
  80. assert not any(expect in answer for answer in answers), f"Unexpected record '{expect}' found in answer section"
  81. if __name__ == '__main__':
  82. if len(sys.argv) < 3:
  83. print('Usage: python dns_fixture.py <query_type> <name>')
  84. sys.exit(1)
  85. query_type = sys.argv[1]
  86. name = sys.argv[2]
  87. ip_only = len(sys.argv) > 3 and sys.argv[3] == '--ip_only'
  88. if ip_only:
  89. logger.setLevel(logging.WARNING)
  90. dns_wrapper = DnsPythonWrapper()
  91. if query_type == 'X' and '.' in name:
  92. # Sends an IPv4 reverse query
  93. reversed_ip = '.'.join(reversed(name.split('.')))
  94. name = f'{reversed_ip}.in-addr.arpa'
  95. query_type = 'PTR'
  96. response = dns_wrapper.run_query(name, query_type=query_type)
  97. answers = dns_wrapper.parse_answer_section(response, query_type)
  98. if answers:
  99. for answer in answers:
  100. logger.info(f'DNS query response: {answer}')
  101. if ip_only:
  102. ipv4_pattern = re.compile(r'\b(?:\d{1,3}\.){3}\d{1,3}\b')
  103. ipv4_addresses = ipv4_pattern.findall(answer)
  104. if ipv4_addresses:
  105. print(f"{', '.join(ipv4_addresses)}")
  106. else:
  107. logger.info(f'No response for {name} with query type {query_type}')
  108. exit(9) # Same as dig timeout