Skip to content

Commit aca8568

Browse files
committed
Add localhost to default resolver
1 parent b980eb4 commit aca8568

File tree

4 files changed

+104
-0
lines changed

4 files changed

+104
-0
lines changed

dns_cache/__init__.py

Lines changed: 9 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -7,9 +7,15 @@
77

88
from .expiration import _NO_EXPIRY as NO_EXPIRY
99
from .expiration import FIVE_MINS, MinExpirationCache, NoExpirationCache
10+
from .persistence import _LayeredCache
1011
from .pickle import PickableCache
1112
from .resolver import AggressiveCachingResolver, ExceptionCachingResolver
1213

14+
try:
15+
from .hosts import HostsCache
16+
except ImportError:
17+
HostsCache = None
18+
1319
__version__ = "0.2.0"
1420

1521

@@ -49,6 +55,9 @@ def override_system_resolver(
4955
else:
5056
cache = MinExpirationCache(min_ttl=min_ttl)
5157

58+
if HostsCache:
59+
cache = _LayeredCache(HostsCache(filename=None), cache)
60+
5261
if not resolver:
5362
resolver = Resolver(configure=False)
5463
try:

tests/test_hosts.py

Lines changed: 58 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,9 +1,67 @@
1+
import socket
12
import unittest
23

4+
from dns.name import from_text
5+
from dns.rdataclass import IN
6+
from dns.rdatatype import A
7+
from dns.resolver import restore_system_resolver
8+
9+
from dns_cache import override_system_resolver
10+
from dns_cache.block import dnspython_resolver_socket_block
311
from dns_cache.hosts import loads
12+
from dns_cache.persistence import _LayeredCache
13+
from dns_cache.resolver import DNSPYTHON_2
14+
15+
from .test_upstream import orig_gethostbyname
416

517

618
class TestHostsSerializer(unittest.TestCase):
719
def test_loads(self):
820
data = loads()
921
assert data
22+
23+
24+
class TestHostsCache(unittest.TestCase):
25+
26+
def test_hit_localhost(self):
27+
name = "localhost"
28+
assert socket.gethostbyname == orig_gethostbyname
29+
30+
try:
31+
socket.gethostbyname(name)
32+
except Exception as e:
33+
raise unittest.SkipTest("gethostbyname: {}".format(e))
34+
35+
resolver = override_system_resolver()
36+
assert isinstance(resolver.cache, _LayeredCache)
37+
38+
if DNSPYTHON_2:
39+
query = resolver.resolve
40+
else:
41+
query = resolver.query
42+
43+
q1 = query(name)
44+
45+
assert len(resolver.cache._read_only_cache.data) >= 1
46+
# The layering does a put, which pushes localhost into cache2
47+
# TODO this needs to be blocked
48+
assert len(resolver.cache._writable_cache.data) == 1
49+
50+
assert q1
51+
52+
name = from_text(name)
53+
54+
assert (name, A, IN) in resolver.cache.data
55+
assert resolver.cache.get((name, A, IN))
56+
57+
with dnspython_resolver_socket_block():
58+
q2 = query(name)
59+
60+
assert q2 is q1
61+
62+
with dnspython_resolver_socket_block():
63+
ip = socket.gethostbyname(name)
64+
assert ip == "127.0.0.1"
65+
66+
restore_system_resolver()
67+
assert socket.gethostbyname == orig_gethostbyname

tests/test_persistence.py

Lines changed: 7 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@
33
import os.path
44
import pickle
55
import shutil
6+
import socket
67
import sys
78
import unittest
89

@@ -26,6 +27,8 @@
2627
DNSPYTHON_2,
2728
dnspython_resolver_socket_block,
2829
get_test_resolver,
30+
orig_gethostbyname,
31+
restore_system_resolver,
2932
)
3033

3134
try:
@@ -68,6 +71,10 @@ class _TestPersistentCacheBase(object):
6871
cache_cls = None
6972
kwargs = {}
7073

74+
def tearDown(self):
75+
restore_system_resolver()
76+
assert socket.gethostbyname == orig_gethostbyname
77+
7178
def is_jsonpickle(self):
7279
serializer = self.kwargs.get("serializer", None)
7380
if hasattr(serializer, "startswith") and serializer.startswith(

tests/test_upstream.py

Lines changed: 30 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -23,6 +23,7 @@
2323
Resolver,
2424
_getaddrinfo,
2525
override_system_resolver,
26+
restore_system_resolver,
2627
)
2728

2829
import pubdns
@@ -45,6 +46,7 @@
4546
PY2 = sys.version_info < (3, 0)
4647

4748
pd = pubdns.PubDNS()
49+
orig_gethostbyname = socket.gethostbyname
4850

4951

5052
def compare_response(a, b):
@@ -144,6 +146,10 @@ class _TestCacheBase(object):
144146
cache_cls = Cache
145147
expiration = 60 * 5
146148

149+
def tearDown(self):
150+
restore_system_resolver()
151+
assert socket.gethostbyname == orig_gethostbyname
152+
147153
def get_test_resolver(self, nameserver=None):
148154
resolver = get_test_resolver(self.resolver_cls, nameserver)
149155
resolver.cache = self.cache_cls()
@@ -571,6 +577,30 @@ def test_no_answer(self, expected_extra=0):
571577

572578
return resolver
573579

580+
def test_miss_localhost(self):
581+
name = "localhost"
582+
583+
assert socket.gethostbyname == orig_gethostbyname
584+
585+
try:
586+
socket.gethostbyname(name)
587+
except Exception as e:
588+
raise unittest.SkipTest("gethostbyname: {}".format(e))
589+
590+
resolver = self.get_test_resolver()
591+
592+
if DNSPYTHON_2:
593+
query = resolver.resolve
594+
else:
595+
query = resolver.query
596+
597+
with self.assertRaises(_SocketBlockedError):
598+
with dnspython_resolver_socket_block():
599+
query(name)
600+
601+
with self.assertRaises((NXDOMAIN, NoAnswer)):
602+
query(name)
603+
574604

575605
@expand
576606
class TestCache(_TestCacheBase, unittest.TestCase):

0 commit comments

Comments
 (0)