Skip to content

Commit

Permalink
Add EDE retrieval helper [#969] and a get_options() helper. (#1056)
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Feb 20, 2024
1 parent 45c8652 commit 6e4bb27
Show file tree
Hide file tree
Showing 2 changed files with 18 additions and 1 deletion.
10 changes: 9 additions & 1 deletion dns/message.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@
import contextlib
import io
import time
from typing import Any, Dict, List, Optional, Tuple, Union
from typing import Any, Dict, List, Optional, Tuple, Union, cast

import dns.edns
import dns.entropy
Expand Down Expand Up @@ -912,6 +912,14 @@ def set_opcode(self, opcode: dns.opcode.Opcode) -> None:
self.flags &= 0x87FF
self.flags |= dns.opcode.to_flags(opcode)

def get_options(self, otype: dns.edns.OptionType) -> List[dns.edns.Option]:
"""Return the list of options of the specified type."""
return [option for option in self.options if option.otype == otype]

def extended_errors(self) -> List[dns.edns.EDEOption]:
"""Return the list of Extended DNS Error (EDE) options in the message"""
return cast(List[dns.edns.EDEOption], self.get_options(dns.edns.OptionType.EDE))

def _get_one_rr_per_rrset(self, value):
# What the caller picked is fine.
return value
Expand Down
9 changes: 9 additions & 0 deletions tests/test_message.py
Original file line number Diff line number Diff line change
Expand Up @@ -991,6 +991,15 @@ def test_section_count_update(self):
self.assertEqual(update.section_count(dns.update.UpdateSection.PREREQ), 5)
self.assertEqual(update.section_count(dns.update.UpdateSection.UPDATE), 7)

def test_extended_errors(self):
options = [
dns.edns.EDEOption(dns.edns.EDECode.NETWORK_ERROR, "tubes not tubing"),
dns.edns.EDEOption(dns.edns.EDECode.OTHER, "catch all code"),
]
r = dns.message.make_query("example", "A", use_edns=0, options=options)
r.flags |= dns.flags.QR
self.assertEqual(r.extended_errors(), options)


if __name__ == "__main__":
unittest.main()

0 comments on commit 6e4bb27

Please sign in to comment.