Skip to content

Commit

Permalink
Further improve CVE fix coverage to 100% for sync and async.
Browse files Browse the repository at this point in the history
  • Loading branch information
rthalley committed Feb 16, 2024
1 parent ac6763f commit a1a9989
Show file tree
Hide file tree
Showing 2 changed files with 204 additions and 1 deletion.
184 changes: 183 additions & 1 deletion tests/test_async.py
Original file line number Diff line number Diff line change
Expand Up @@ -18,7 +18,6 @@
import asyncio
import random
import socket
import sys
import time
import unittest

Expand All @@ -28,6 +27,7 @@
import dns.message
import dns.name
import dns.query
import dns.rcode
import dns.rdataclass
import dns.rdatatype
import dns.resolver
Expand Down Expand Up @@ -664,3 +664,185 @@ def async_run(self, afunc):

except ImportError:
pass


class MockSock:
def __init__(self, wire1, from1, wire2, from2):
self.family = socket.AF_INET
self.first_time = True
self.wire1 = wire1
self.from1 = from1
self.wire2 = wire2
self.from2 = from2

async def sendto(self, data, where, timeout):
return len(data)

async def recvfrom(self, bufsize, expiration):
if self.first_time:
self.first_time = False
return self.wire1, self.from1
else:
return self.wire2, self.from2


class IgnoreErrors(unittest.TestCase):
def setUp(self):
self.q = dns.message.make_query("example.", "A")
self.good_r = dns.message.make_response(self.q)
self.good_r.set_rcode(dns.rcode.NXDOMAIN)
self.good_r_wire = self.good_r.to_wire()
dns.asyncbackend.set_default_backend("asyncio")

def async_run(self, afunc):
return asyncio.run(afunc())

async def mock_receive(
self,
wire1,
from1,
wire2,
from2,
ignore_unexpected=True,
ignore_errors=True,
):
s = MockSock(wire1, from1, wire2, from2)
(r, when, _) = await dns.asyncquery.receive_udp(
s,
("127.0.0.1", 53),
time.time() + 2,
ignore_unexpected=ignore_unexpected,
ignore_errors=ignore_errors,
query=self.q,
)
self.assertEqual(r, self.good_r)

def test_good_mock(self):
async def run():
await self.mock_receive(self.good_r_wire, ("127.0.0.1", 53), None, None)

self.async_run(run)

def test_bad_address(self):
async def run():
await self.mock_receive(
self.good_r_wire, ("127.0.0.2", 53), self.good_r_wire, ("127.0.0.1", 53)
)

self.async_run(run)

def test_bad_address_not_ignored(self):
async def abad():
await self.mock_receive(
self.good_r_wire,
("127.0.0.2", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_unexpected=False,
)

def bad():
self.async_run(abad)

self.assertRaises(dns.query.UnexpectedSource, bad)

def test_not_response_not_ignored_udp_level(self):
async def abad():
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()
s = MockSock(
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)
await dns.asyncquery.udp(self.good_r, "127.0.0.1", sock=s)

def bad():
self.async_run(abad)

self.assertRaises(dns.query.BadResponse, bad)

def test_bad_id(self):
async def run():
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()
await self.mock_receive(
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

self.async_run(run)

def test_bad_id_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()

async def abad():
(r, wire) = await self.mock_receive(
bad_r_wire,
("127.0.0.1", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_errors=False,
)

def bad():
self.async_run(abad)

self.assertRaises(AssertionError, bad)

def test_bad_wire(self):
async def run():
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()
await self.mock_receive(
bad_r_wire[:10], ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

self.async_run(run)

def test_bad_wire_not_ignored(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()

async def abad():
await self.mock_receive(
bad_r_wire[:10],
("127.0.0.1", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_errors=False,
)

def bad():
self.async_run(abad)

self.assertRaises(dns.message.ShortHeader, bad)

def test_trailing_wire(self):
async def run():
wire = self.good_r_wire + b"abcd"
await self.mock_receive(
wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
)

self.async_run(run)

def test_trailing_wire_not_ignored(self):
wire = self.good_r_wire + b"abcd"

async def abad():
await self.mock_receive(
wire,
("127.0.0.1", 53),
self.good_r_wire,
("127.0.0.1", 53),
ignore_errors=False,
)

def bad():
self.async_run(abad)

self.assertRaises(dns.message.TrailingJunk, bad)
21 changes: 21 additions & 0 deletions tests/test_query.py
Original file line number Diff line number Diff line change
Expand Up @@ -683,6 +683,14 @@ def mock(sock, max_size, expiration):
dns.query._udp_recv = saved


class MockSock:
def __init__(self):
self.family = socket.AF_INET

def sendto(self, data, where):
return len(data)


class IgnoreErrors(unittest.TestCase):
def setUp(self):
self.q = dns.message.make_query("example.", "A")
Expand Down Expand Up @@ -758,6 +766,19 @@ def bad():

self.assertRaises(AssertionError, bad)

def test_not_response_not_ignored_udp_level(self):
def bad():
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
bad_r_wire = bad_r.to_wire()
with mock_udp_recv(
bad_r_wire, ("127.0.0.1", 53), self.good_r_wire, ("127.0.0.1", 53)
):
s = MockSock()
dns.query.udp(self.good_r, "127.0.0.1", sock=s)

self.assertRaises(dns.query.BadResponse, bad)

def test_bad_wire(self):
bad_r = dns.message.make_response(self.q)
bad_r.id += 1
Expand Down

0 comments on commit a1a9989

Please sign in to comment.