From f52b13ce4a6364898e7ca68c5a9971a8c6a7adb2 Mon Sep 17 00:00:00 2001 From: Chris Caron Date: Sun, 24 Jul 2016 00:04:18 -0400 Subject: [PATCH] massive refactoring; all tests pass now on python 2.7.10 --- newsreap/lib/NNTPConnection.py | 83 +-- newsreap/lib/SocketBase.py | 246 +++++-- newsreap/tests/NNTPConnection_Test.py | 124 +--- newsreap/tests/NNTPSocketServer.py | 866 ++++++++++++------------- newsreap/tests/NNTPYencArticle_Test.py | 51 +- newsreap/tests/TestBase.py | 1 - 6 files changed, 677 insertions(+), 694 deletions(-) diff --git a/newsreap/lib/NNTPConnection.py b/newsreap/lib/NNTPConnection.py index 3b31810..e5577ec 100644 --- a/newsreap/lib/NNTPConnection.py +++ b/newsreap/lib/NNTPConnection.py @@ -36,7 +36,6 @@ from lib.NNTPIOStream import NNTP_DEFAULT_ENCODING from lib.SocketBase import SocketBase -from lib.SocketBase import ConnectionType from lib.SocketBase import SocketException from lib.SocketBase import SignalCaughtException from lib.Utils import mkdir @@ -88,7 +87,7 @@ # Group Response (when we switch to a group) NNTP_GROUP_RESPONSE_RE = re.compile( - r'(?P[0-9]+)\s+(?P[0-9]+)\s+' + \ + r'(?P[0-9]+)\s+(?P[0-9]+)\s+' + r'(?P[0-9]+)\s+(?P.*[^\s]*[^.])(\.|\s)*$', ) @@ -132,14 +131,15 @@ # query fails NNTP_XOVER_RETRIES = 5 -#RAW_TEXT_MESSAGE = re.compile( -# r'^message-id:\s*]+)>?\s*$', -#) +# RAW_TEXT_MESSAGE = re.compile( +# r'^message-id:\s*]+)>?\s*$', +# ) # Used with BytesIO seek(). These variables # are part of python v2.7 but included here for python v2.6 # support too. + class NNTPConnection(SocketBase): """ NNTPConnection is a class that wraps and eases the comunication @@ -229,10 +229,10 @@ def __init__(self, username=None, password=None, secure=False, # get connection mode if secure: - kwargs['mode'] = ConnectionType.SECURE_CONNECT + kwargs['secure'] = True self.protocol = 'nntps' else: - kwargs['mode'] = ConnectionType.CONNECT + kwargs['secure'] = False self.protocol = 'nntp' # Default Filters @@ -274,24 +274,24 @@ def __init__(self, username=None, password=None, secure=False, if self._iostream not in NNTP_SUPPORTED_IO_STREAMS: # Default self._iostream = NNTPIOStream.RFC3977_GZIP - logger.warning('An unknown iostream was specified; ' + \ + logger.warning('An unknown iostream was specified; ' + "using default: '%s'." % self._iostream) except (TypeError, ValueError): # Default self._iostream = NNTPIOStream.RFC3977_GZIP - logger.warning('A malformed iostream was specified; ' + \ + logger.warning('A malformed iostream was specified; ' + "using default: '%s'." % self._iostream) elif iostream: # Default - logger.warning('An invalid iostream was specified; ' + \ + logger.warning('An invalid iostream was specified; ' + "using default: '%s'." % self._iostream) self._iostream = NNTPIOStream.RFC3977_GZIP else: # iostream is None (0, or False) # used RFC3977 Standards but without Compresssion - logger.info('An invalid iostream was specified; ' + \ + logger.info('An invalid iostream was specified; ' + "using default: '%s'." % self._iostream) self._iostream = NNTPIOStream.RFC3977 @@ -369,14 +369,6 @@ def __init__(self, username=None, password=None, secure=False, # Used to cache group list responses self._grouplist = None - - def __del__(self): - """ - Handle Deconstruction - """ - self.close() - - def append(self, connection, *args, **kwargs): """ Add a backup NNTP Server (Block Account) which is only @@ -390,7 +382,6 @@ def append(self, connection, *args, **kwargs): self._backups.append(connection) return True - def connect(self, *args, **kwargs): """ Establishes a connection to an NNTP Server @@ -405,7 +396,6 @@ def connect(self, *args, **kwargs): # call _connect() return self._connect(*args, **kwargs) - def _connect(self, *args, **kwargs): """ _connect() performs the actual connection with little @@ -476,7 +466,6 @@ def _connect(self, *args, **kwargs): return True - def post(self, payload): """ Allows posting content to a NNTP Server @@ -493,8 +482,7 @@ def post(self, payload): return response # The server - #header=('From: %s' % terry@richard.geek.org.au - + # header=('From: %s' % terry@richard.geek.org.au # STUB: TODO # payload should be a class that takes all the required @@ -504,7 +492,6 @@ def post(self, payload): # contents and closing it afterwards. return NNTPResponse(239, 'Article transferred OK') - def group(self, name): """ Changes to a specific group @@ -559,7 +546,6 @@ def group(self, name): logger.warning('Bad Group: %s' % name) return (None, None, None, self.group_name) - def groups(self, filters=None, lazy=True): """ Retrieves a list of groups from the server and returns them in an easy @@ -616,16 +602,16 @@ def groups(self, filters=None, lazy=True): elif isinstance(filter, basestring): try: - filter =r'^.*%s.*$' % re.escape(filter) - _filters.append(re.compile( - filter, flags=re.IGNORECASE), + filter = r'^.*%s.*$' % re.escape(filter) + _filters.append( + re.compile(filter, flags=re.IGNORECASE), ) logger.debug('Compiled group regex "%s"' % filter) except: logger.error( - 'Invalid group regular expression: "%s"' % filter, - ) + 'Invalid group regular expression: "%s"' % filter, + ) else: logger.error( 'Ignored group expression: "%s"' % filter, @@ -674,14 +660,12 @@ def groups(self, filters=None, lazy=True): )) return self._grouplist - def tell(self): """ Returns the current index """ return self.group_index - def seek_by_date(self, refdate, group=None): """ Similar to the seek() function in the sense it changes the @@ -740,7 +724,6 @@ def seek_by_date(self, refdate, group=None): logger.info('Matched index: %d' % (self.group_index)) return self.group_index - def _seek_by_date(self, refdate, head=None, tail=None): """ @@ -788,13 +771,13 @@ def _seek_by_date(self, refdate, head=None, tail=None): if response is None: # Nothing Retrieved return -1 - #if response.code != 423: + # if response.code != 423: # # 423 means there were no more items to fetch # # this is a common error that even this class produces # # and therefore we do not need to create a verbose # # message from it. # logger.error('NNTP Server responded %s' % response) - #return -1 + # return -1 # Deal with our response if len(response): @@ -807,12 +790,12 @@ def _seek_by_date(self, refdate, head=None, tail=None): # dict() _refkeys = response.keys() - #logger.debug('total=%d, left=%s, right=%s, ref=%s' % ( + # logger.debug('total=%d, left=%s, right=%s, ref=%s' % ( # end-start, # _refkeys[0], # _refkeys[-1], # _refdate, - #)) + # )) # # Decisions @@ -876,7 +859,6 @@ def _seek_by_date(self, refdate, head=None, tail=None): # direction; we need to return -1 - def seek(self, index, whence=None): """ Sets a default nntp index @@ -951,7 +933,8 @@ def prev(self, count=50000): return response - def xover(self, group=None, start=None, end=None, sort=XoverGrouping.BY_POSTER_TIME): + def xover(self, group=None, start=None, end=None, + sort=XoverGrouping.BY_POSTER_TIME): """ xover Returns a NNTPRequest object @@ -1678,12 +1661,15 @@ def _recv(self, decoders=None, timeout=None): # when retrieving server-side listings; it's best to # just alert the end user and move along logger.warning( - '_recv() %d byte(s) ZLIB decompression failure.' % ( - bytes, - )) + '_recv() %d byte(s) ZLIB decompression failure.' \ + % (bytes), + ) # Convert our response to that of an response Fetch # Error - return NNTPResponse(NNTPResponseCode.FETCH_ERROR, 'Fetch Error') + return NNTPResponse( + NNTPResponseCode.FETCH_ERROR, + 'Fetch Error', + ) else: # No compression self._data.write(self._buffer.read(tail_ptr-head_ptr)) @@ -1848,7 +1834,7 @@ def close(self): if self.connected: try: # Prevent Recursion by calling parent send() - super(NNTPConnection, self).send('QUIT') + super(NNTPConnection, self).send('QUIT' + EOL) except: pass @@ -1891,6 +1877,11 @@ def _hard_reset(self, wait=True): self.close() self.connect() + def __del__(self): + """ + Handle Deconstruction + """ + self.close() def __str__(self): return '%s://%s@%s:%d' % ( @@ -1901,7 +1892,6 @@ def __unicode__(self): return u'%s://%s@%s:%d' % ( self.protocol, self.username, self.host, self.port) - def __repr__(self): """ Return a printable object @@ -1913,4 +1903,3 @@ def __repr__(self): self.host, self.port, ) - diff --git a/newsreap/lib/SocketBase.py b/newsreap/lib/SocketBase.py index fd68388..0d0ba8c 100644 --- a/newsreap/lib/SocketBase.py +++ b/newsreap/lib/SocketBase.py @@ -40,16 +40,6 @@ # or fail to connect before failing out right socket.setdefaulttimeout(30.0) -class ConnectionType: - # Establish a remote connection - CONNECT = 'connect' - # Establish a secure remote connection - SECURE_CONNECT = 'ssl_connect' - # Listen for a connection - LISTEN = 'listen' - # Listen for a secure connection - SECURE_LISTEN = 'ssl_listen' - try: SECURE_PROTOCOL_PRIORITY = ( # The first element is the default priority @@ -114,13 +104,15 @@ class SocketBase(object): - connection timeout - mode Define how this socket will be used (see - ConnectionType class for more details) + secure Use encryption when managing the connection + Expects a True or False, but you can also + Specify the encryption Cypher to use here + too. """ def __init__(self, host=None, port=0, bindaddr=None, bindport=0, - mode=ConnectionType.CONNECT, *args, **kwargs): + secure=False, *args, **kwargs): try: self.port = int(port) @@ -132,18 +124,35 @@ def __init__(self, host=None, port=0, bindaddr=None, bindport=0, self.bindport = bindport self.connected = False - self.mode = mode + self.secure = secure + + if self.secure is None: + # a little qwirky, but allow users to set secure to + # None and have it treated the same way as False + self.secure = False self.socket = None # Track the current index of the secure protocol singleton to use self.secure_protocol_idx = 0 + if self.secure not in (True, False): + # If self.secure identifies an actual protocol, we want to use it + # The below simply looks for it's existance, and if it isn't + # present then we return None + self.secure_protocol_idx = next((i for i, v in \ + enumerate(SECURE_PROTOCOL_PRIORITY) \ + if v[0] == self.secure), None) + + if self.secure_protocol_idx is None: + # Protocol specified was not found and/or supported; alert the + # user with a loud bang; we're done here. + raise AttributeError("Invalid secure protocol specified.") + # A spot we can store our peer certificate; this is only used # if we're dealing with a secure connection self.peer_certificate = {} - # CA stands for Certificate Authority (for those reading this code) # This is the master list of servers that you trust for verifying # your certificates against. If you're using a self-signed key @@ -151,7 +160,7 @@ def __init__(self, host=None, port=0, bindaddr=None, bindport=0, # You'll need these if you want to verify your host self._ca_certs = kwargs.get('ca_certs', "/etc/ssl/certs/ca-bundle.crt") - # These keys are needed for hosting / listen type modes only + # These keys are needed for hosting / listen type connections only self._keyfile = kwargs.get('keyfile', None) self._certfile = kwargs.get('certfile', None) @@ -159,6 +168,12 @@ def __init__(self, host=None, port=0, bindaddr=None, bindport=0, self.bytes_in = 0 self.bytes_out = 0 + # Calculated through connections + self._local_addr = None + self._local_port = None + self._remote_addr = None + self._remote_port = None + def can_read(self, timeout=0.0): """ Checks if there is data that can be read from the @@ -264,6 +279,14 @@ def close(self): self.bytes_in = 0 self.bytes_out = 0 + # reset remote connection details only + # we keep the local ones so we can re-use them + # if possible (especially the port) + + # Calculated through connections + self._remote_addr = None + self._remote_port = None + # return any lingering data in buffer if data: return data @@ -292,19 +315,19 @@ def bind(self, timeout=None, retries=3, retry_wait=10.0): # We don't need to bind connections that are not configured with # bind information. if not (self.bindaddr or self.bindport): + # Store local connection details return True - bindaddr = self.bindaddr - bindport = self.bindport if not self.bindaddr: - bindaddr = '0.0.0.0' + self.bindaddr = '0.0.0.0' if not self.bindport: - bindport = 0 + self.bindport = 0 try: - bind_str = '%s:%s' % (bindaddr, bindport) - self.socket.bind((bindaddr, self.bindport)) + bind_str = '%s:%s' % (self.bindaddr, self.bindport) + self.socket.bind((self.bindaddr, self.bindport)) logger.debug("Socket bound to %s" % bind_str) + # Store local connection details return True except socket.error, e: @@ -337,6 +360,10 @@ def bind(self, timeout=None, retries=3, retry_wait=10.0): # no need to thrash sleep(retry_wait) + # Code can't ever reach here; but to satisfy lint + # we'll put a return statement here + return False + def connect(self, timeout=None, retry_wait=1.00): """ @@ -356,9 +383,11 @@ def connect(self, timeout=None, retry_wait=1.00): """ - connection_str = '%s:%s' % (self.host, self.port) # Blocking until a connection - logger.debug("Connecting to host: %s" % connection_str) + logger.debug("Connecting to host: %s:%d" % ( + self.host, + self.port, + )) if timeout: socket.setdefaulttimeout(timeout) @@ -387,14 +416,22 @@ def connect(self, timeout=None, retry_wait=1.00): # Disable Blocking self.socket.setblocking(False) - logger.info("Connection established to %s", connection_str) + # Store local details of our socket + (self._local_addr, self._local_port) = self.socket.getsockname() + (self._remote_addr, self._remote_port) = self.socket.getpeername() - if self.mode == ConnectionType.CONNECT: + logger.info( + "Connection established to %s:%d" % ( + self._remote_addr, + self._remote_port, + )) + + if self.secure is False: # A non secure connection; we're done break # Encrypt our socket (changing it into an SSLSocket Object) - self.__encrypt_socket(timeout=timeout) + self.__encrypt_socket(timeout=timeout, server_side=False) # If we get here, we were successful in encrypting the # connection; so let's go ahead and break out of our @@ -411,10 +448,20 @@ def connect(self, timeout=None, retry_wait=1.00): ), ) self.close() - # Fetch next - self.__ssl_version(try_next=True) - raise SocketException('Secure Connection Failed') + if self.secure is True: + # Fetch next (but only if nothing was explicitly + # specified) + self.__ssl_version(try_next=True) + raise SocketException('Secure Connection Failed') + + # If we reach here, we had a problem with our secure connection + # handshaking and we were explicitly told to only use 1 (one) + # protocol. Thus there is nothing more to retry. So throwing + # a SocketException() is not a good idea. Instead, we throw + # a SocketRetryLimit() so it can be handled differently + # upstream + raise SocketRetryLimit('There are no protocols left to try.') except socket.error, e: #logger.debug("Exception received: %s " % (e)); @@ -452,7 +499,7 @@ def connect(self, timeout=None, retry_wait=1.00): return True - def listen(self, timeout=None, retry_wait=1.00): + def listen(self, timeout=None, retry_wait=1.00, reuse_port=True): """ input Parameters: -timeout How long accept() should block for before giving up and @@ -470,9 +517,17 @@ def listen(self, timeout=None, retry_wait=1.00): """ + if reuse_port and self._local_port is not None and self.port == 0: + # Re-use the last port we acquired that way we can close a + # connection gracefully and not have to re-acquire a new + # ephemeral port + self.port = self._local_port + # Blocking until a connection - connection_str = '%s:%s' % (self.bindaddr, self.port) - logger.debug("Listening for a connection at: %s" % connection_str) + logger.debug("Listening for a connection at: %s:%d" % ( + self.bindaddr, + self.port, + )) if timeout: socket.setdefaulttimeout(timeout) logger.debug("Socket timeout set to :%ds" % (timeout)) @@ -481,25 +536,36 @@ def listen(self, timeout=None, retry_wait=1.00): if not self.bind(timeout=timeout, retries=3): return False - if timeout is None: - self.socket.setblocking(True) - else: - self.socket.setblocking(False) + # Never use blocking + self.socket.setblocking(False) # Listen Enabled, the 1 identifies the number of connections we # will accept; never handle more then 1 at a time. self.socket.listen(1) + # Store local details of our socket + (self._local_addr, self._local_port) = self.socket.getsockname() + # Get reference time cur_time = datetime.now() while True: try: - conn, self.host = self.socket.accept() + conn, (self._remote_addr, self._remote_port) = \ + self.socket.accept() # If we get here, we've got a connection break + except TypeError, e: # Timeout occurred pass + + except AttributeError, e: + # Usually means someone called close() while accept() was + # blocked. Happens when using this class with threads. + # No problem... we'll just finish up here + self.close() + raise SocketException('Connection broken abruptly') + except socket.error, e: if e[0] == errno.EAGAIN: # Timeout occurred @@ -529,18 +595,29 @@ def listen(self, timeout=None, retry_wait=1.00): self.close() raise SocketException('Connection timeout') - # Throttle retry - sleep(retry_wait) + # Throttle until data is available + if self.can_read(retry_wait) is None: + # Something very bad happened + self.close() + raise SocketException('Connection broken') # Close listening connection - self.close() + self.socket.close() # Swap socket with new self.socket = conn - logger.info("Connection established to %s", connection_str) + # Update our local information + (self._local_addr, self._local_port) = self.socket.getsockname() + (self._remote_addr, self._remote_port) = self.socket.getpeername() + + logger.info( + "Connection established to %s:%d" % ( + self._remote_addr, + self._remote_port, + )) - if self.mode == ConnectionType.LISTEN: + if self.secure is False: # A non secure connection; we're done # Toggle our connection flag @@ -551,7 +628,7 @@ def listen(self, timeout=None, retry_wait=1.00): try: # Encrypt our socket (changing it into an SSLSocket Object) - self.__encrypt_socket(timeout=timeout) + self.__encrypt_socket(timeout=timeout, server_side=True) # If we get here, we were successful in encrypting the connection; # so let's go ahead and break out of our connection loop @@ -566,10 +643,19 @@ def listen(self, timeout=None, retry_wait=1.00): ), ) self.close() - # Fetch next - self.__ssl_version(try_next=True) - raise SocketException('Secure Connection Failed') + if self.secure is True: + # Fetch next (but only if nothing was explicitly + # specified) + self.__ssl_version(try_next=True) + raise SocketException('Secure Connection Failed') + + # If we reach here, we had a problem with our secure connection + # handshaking and we were explicitly told to only use 1 (one) + # protocol. Thus there is nothing more to retry. So throwing + # a SocketException() is not a good idea. Instead, we throw + # a SocketRetryLimit() so it can be handled differently upstream + raise SocketRetryLimit('There are no protocols left to try.') except socket.error, e: #logger.debug("Exception received: %s " % (e)); @@ -661,8 +747,9 @@ def read(self, max_bytes=32768, timeout=None, retry_wait=0.25): self.bytes_in += len(data) total_data.append(data) - if not timeout: - raise SocketException('Connection lost') + #if not timeout: + # raise SocketException('Connection lost') + break if timeout and bytes_read == 0: # We're done @@ -675,6 +762,14 @@ def read(self, max_bytes=32768, timeout=None, retry_wait=0.25): continue + except ssl.SSLWantReadError, e: + # Raised by SSL Socket; This is okay data was received, but not + # all of it. Be patient and try again. + if self.can_read(retry_wait) is None: + self.close() + raise SocketException('Connection broken') + continue + except ssl.SSLZeroReturnError, e: # Raised by SSL Socket self.close() @@ -792,6 +887,36 @@ def send(self, data, max_bytes=None, retry_wait=0.25): return tot_bytes + def local_connection_info(self): + """ + Returns a tuple of current address of 'this' server + if listening, then it is the listing server. If performing a + remote connection, then it is the address that was made in + a bind() call + + If no connection has been established, the connection returns None. + """ + + if self.socket is None: + return None + + return (self._local_addr, self._local_port) + + def remote_connection_info(self): + """ + Returns a tuple of current address of 'this' server + if listening, then it is the listing server. If performing a + remote connection, then it is the address that was made in + a bind() call + + If no connection has been established, the connection returns None. + """ + if self.socket is None: + return None + + return (self._remote_addr, self._remote_port) + + def __ssl_version(self, try_next=False): """ Returns an SSL Context Object while handling the many @@ -810,7 +935,8 @@ def __ssl_version(self, try_next=False): # of SocketException() since we're at the end of the line now. raise SocketRetryLimit('There are no protocols left to try.') - def __encrypt_socket(self, timeout=None, retry_wait=1.00, verify=False): + def __encrypt_socket(self, timeout=None, retry_wait=1.00, verify=False, + server_side=False): """ Wrap an existing Python socket and return an SSLSocket Object. this is iternally called if we're dealing with a secure @@ -842,11 +968,11 @@ def __encrypt_socket(self, timeout=None, retry_wait=1.00, verify=False): 'suppress_ragged_eofs': True, } - if self.mode == ConnectionType.SECURE_LISTEN: + kwargs['server_side'] = server_side + if server_side: # We need to add a few more parameters kwargs['keyfile'] = self._keyfile kwargs['certfile'] = self._certfile - kwargs['server_side'] = True # Verify our certificates/keys exist or abort # These checks are nessisary otherwise you'll get strange errors like: @@ -863,7 +989,7 @@ def __encrypt_socket(self, timeout=None, retry_wait=1.00, verify=False): raise ValueError( 'Could not locate Private Key: %s' % self._keyfile) - if self.mode == ConnectionType.SECURE_CONNECT: + else: if verify: # Verify Certificate cert_reqs = ssl.CERT_REQUIRED @@ -897,7 +1023,7 @@ def __encrypt_socket(self, timeout=None, retry_wait=1.00, verify=False): )) # Store our certificate - if self.mode == ConnectionType.SECURE_CONNECT: + if not server_side: self.peer_certificate = \ self.socket.getpeercert(binary_form=False) @@ -952,13 +1078,13 @@ def __del__(self): self.close() def __str__(self): - if self.mode == ConnectionType.SECURE_CONNECT: - return 'tcps://%s:%d' % (self.host, self.port) + if self.secure is False: + return 'tcp%s:%d' % (self.host, self.port) # else - return 'tcp%s:%d' % (self.host, self.port) + return 'tcps://%s:%d' % (self.host, self.port) def __unicode__(self): - if self.mode == ConnectionType.SECURE_CONNECT: - return u'tcps://%s:%d' % (self.host, self.port) + if self.secure is False: + return u'tcp%s:%d' % (self.host, self.port) # else - return u'tcp%s:%d' % (self.host, self.port) + return u'tcps://%s:%d' % (self.host, self.port) diff --git a/newsreap/tests/NNTPConnection_Test.py b/newsreap/tests/NNTPConnection_Test.py index 973a2d4..4fa42c5 100644 --- a/newsreap/tests/NNTPConnection_Test.py +++ b/newsreap/tests/NNTPConnection_Test.py @@ -23,6 +23,7 @@ import gevent.monkey gevent.monkey.patch_all() + # Import threading after monkey patching # see: http://stackoverflow.com/questions/8774958/\ # keyerror-in-module-threading-after-a-successful-py-test-run @@ -40,7 +41,6 @@ from tests.TestBase import TestBase from tests.NNTPSocketServer import NNTPSocketServer -from tests.NNTPSocketServer import NNTPBaseRequestHandler from lib.NNTPConnection import NNTPConnection from lib.NNTPIOStream import NNTPIOStream @@ -60,44 +60,29 @@ def setUp(self): """ super(NNTPConnection_Test, self).setUp() - self.hostname = "localhost" - - ## Secure NNTP Server + # Secure NNTP Server self.nntps = NNTPSocketServer( - (self.hostname, 0), - NNTPBaseRequestHandler, secure=True, ) - ## Insecure NNTP Server + + # Insecure NNTP Server self.nntp = NNTPSocketServer( - (self.hostname, 0), - NNTPBaseRequestHandler, secure=False, ) - # Get our connection stats - self.nttps_ipaddr, self.nntps_portno = self.nntps.server_address - self.nttp_ipaddr, self.nntp_portno = self.nntp.server_address - - # Push DUMMY NTP Server To Thread - self.nntps_thread = threading.Thread( - target=self.nntps.serve_forever, - name='NNTPSServer', - ) - - self.nntp_thread = threading.Thread( - target=self.nntp.serve_forever, - name='NNTPServer', - ) - # Exit the server thread when the main thread terminates self.nntps.daemon = True self.nntp.daemon = True - # Start Threads - self.nntps_thread.start() - self.nntp_thread.start() + # Start Our Server Threads + self.nntps.start() + self.nntp.start() + # Acquire our configuration + self.nttp_ipaddr, self.nntp_portno = \ + self.nntp.local_connection_info() + self.nttps_ipaddr, self.nntps_portno = \ + self.nntps.local_connection_info() def tearDown(self): # Shutdown NNTP Dummy Servers Daemons @@ -106,7 +91,6 @@ def tearDown(self): super(NNTPConnection_Test, self).tearDown() - def test_authentication(self): sock = NNTPConnection( host=self.nttp_ipaddr, @@ -116,7 +100,7 @@ def test_authentication(self): secure=False, join_group=False, ) - assert sock.connect(timeout=5.0) == True + assert sock.connect(timeout=5.0) is True assert sock._iostream == NNTPIOStream.RFC3977_GZIP sock.close() @@ -129,7 +113,7 @@ def test_authentication(self): join_group=False, ) # Invalid Username - assert sock.connect(timeout=5.0) == False + assert sock.connect(timeout=5.0) is False sock = NNTPConnection( host=self.nttp_ipaddr, @@ -140,7 +124,7 @@ def test_authentication(self): join_group=False, ) # Invalid Password - assert sock.connect(timeout=5.0) == False + assert sock.connect(timeout=5.0) is False def test_secure_authentication(self): @@ -152,7 +136,7 @@ def test_secure_authentication(self): secure=True, join_group=False, ) - assert sock.connect(timeout=5.0) == True + assert sock.connect(timeout=5.0) is True assert sock._iostream == NNTPIOStream.RFC3977_GZIP sock.close() @@ -165,7 +149,7 @@ def test_secure_authentication(self): join_group=False, ) # Invalid Username - assert sock.connect(timeout=5.0) == False + assert sock.connect(timeout=5.0) is False sock = NNTPConnection( host=self.nttps_ipaddr, @@ -176,10 +160,10 @@ def test_secure_authentication(self): join_group=False, ) # Invalid Password - assert sock.connect(timeout=5.0) == False + assert sock.connect(timeout=5.0) is False - - def test_regular_expressions(self): + @classmethod + def test_regular_expressions(cls): """ Tests XOVER Regular Expressions @@ -306,71 +290,3 @@ def test_group_searching(self): # Again without the lazy flag set groups = sock.groups(filters='alt.binaries') assert len(groups) == 5270 - - -if __name__ == '__main__': - - hostname = "localhost" - - ## Secure NNTP Server - nntps = NNTPSocketServer( - (hostname, 0), - NNTPBaseRequestHandler, - secure=True, - ) - - ## Insecure NNTP Server - nntp = NNTPSocketServer( - (hostname, 0), - NNTPBaseRequestHandler, - secure=False, - ) - - # Get our connection stats - nttps_ipaddr, nntps_portno = nntps.server_address - nttp_ipaddr, nntp_portno = nntp.server_address - - # Push DUMMY NTP Server To Thread - nntps_thread = threading.Thread( - target=nntps.serve_forever, - name='NTPS_Server', - ) - - nntp_thread = threading.Thread( - target=nntp.serve_forever, - name='NTP_Server', - ) - - # Exit the server thread when the main thread terminates - nntps.daemon = True - nntp.daemon = True - - # Start Threads - nntps_thread.start() - nntp_thread.start() - - socket = NNTPConnection( - host=nttp_ipaddr, - port=nntp_portno, - username='valid', - password='user', - secure=False, - join_group=False, - ) - - ssocket = NNTPConnection( - host=nttps_ipaddr, - port=nntps_portno, - username='valid', - password='user', - secure=True, - join_group=False, - ) - - print 'DEBUG: CLIENT CONNECT' - ssocket.connect(timeout=20.0) - print 'DEBUG: CLIENT CONNECTED' - print 'DEBUG: CLIENT CLOSING CONNECTION' - nntp.shutdown() - nntps.shutdown() - exit(0) diff --git a/newsreap/tests/NNTPSocketServer.py b/newsreap/tests/NNTPSocketServer.py index c536e9c..6b3202b 100644 --- a/newsreap/tests/NNTPSocketServer.py +++ b/newsreap/tests/NNTPSocketServer.py @@ -26,25 +26,30 @@ from gevent import ssl from gevent import socket -from gevent.select import select gevent.monkey.patch_all() -#import socket # Import threading after monkey patching # see: http://stackoverflow.com/questions/8774958/\ # keyerror-in-module-threading-after-a-successful-py-test-run import threading import re -import SocketServer from pprint import pformat from io import BytesIO from zlib import compress from os.path import dirname from os.path import join -from os.path import isfile from os.path import abspath +try: + from lib.SocketBase import SocketBase + from lib.SocketBase import SocketException + +except ImportError: + sys.path.insert(0, dirname(dirname(abspath(__file__)))) + from lib.SocketBase import SocketBase + from lib.SocketBase import SocketException + # Our internal server is only used for testing, therefore we can get away with # having a really low timeout socket.setdefaulttimeout(10.0) @@ -63,6 +68,7 @@ # Article ID ARTICLE_ID_RE = re.compile(r'\s*<*\s*(?P[^>]+)>?.*') + # All of the default NNTP Responses are defined here by their # compiled regular expression NNTP_DEFAULT_MAP = { @@ -99,6 +105,8 @@ }, re.compile('^QUIT'): { 'response': '200 See you later!', + # Reset our current state of the map + 'reset': True, }, } @@ -107,17 +115,67 @@ NNTP_EOD = '.\r\n' -class NNTPSocketServer(SocketServer.TCPServer): - def __init__(self, server_address, RequestHandlerClass, - bind_and_activate=True, secure=True, join_group=True): +class NNTPClient(SocketBase): + """ + An NNTPClient is produced from a NNTPSocketServer by simply calling + get_client(). + + With an NNTPClient() you call send() to push your commands and + will get the results returned to you + + """ + def __init__(self, *args, **kwargs): + # Initialize the Socket Base Class + super(NNTPClient, self).__init__(*args, **kwargs) + + def put(self, line, eol=True): + """ + A Simple put() script to simplify transmitting data to the fake server + + You can directly interface through here, or you can use the + get_client() to spin an actual a TCP/IP Client + """ + + if eol: + line = line + NNTP_EOL + + print "Sent: %s" % line.strip() + self.send(line) + response = self.read() + print "Received: %s" % response.strip() + print - # Hostname is used for SSL Verification (if set) - self.hostname = server_address[0] + return response + + def close(self): + """ + Gracefully disconnects from the server + """ + try: + # Prevent Recursion by calling parent send() + super(NNTPClient, self).send('QUIT' + EOL) + except: + # well.. we tried at least + pass + + try: + # close the port + self.socket.close() + except: + pass + + +class NNTPSocketServer(threading.Thread): + def __init__(self, join_group=True, host='localhost', port=0, + secure=None, *args, **kwargs): + + # Handle Threading + threading.Thread.__init__(self) self._can_post = True self._has_yenc = True - self._secure = secure + self._active = threading.Event() self._io_wait = threading.Event() self._maplock = threading.Lock() @@ -126,33 +184,10 @@ def __init__(self, server_address, RequestHandlerClass, self._join_group = join_group # Server (self-signed) Certificates for SSL Testing - self.certfile = abspath( - join(dirname(__file__), 'var', 'ssl','localhost.crt')) - self.keyfile = abspath( - join(dirname(__file__), 'var', 'ssl','localhost.key')) - - # These checks are very nessisary; you'll get strange errors like: - # _ssl.c:341: error:140B0002:SSL \ - # routines:SSL_CTX_use_PrivateKey_file:system lib - # - # The error itself will surface during the call to wrap_socket() which - # will throw the exception ssl.SSLError - # - # it doesn't hurt to just check ahead of time. - if not isfile(self.certfile): - raise ValueError( - 'Could not locate Certificate: %s' % self.certfile) - if not isfile(self.keyfile): - raise ValueError( - 'Could not locate Private Key: %s' % self.keyfile) - - # Secure Protocol to use - try: - # Python v2.7+ - self.ssl_version = ssl.PROTOCOL_TLSv1_2 - except AttributeError: - # Python v2.6+ - self.ssl_version = ssl.PROTOCOL_TLSv1 + kwargs['certfile'] = abspath( + join(dirname(__file__), 'var', 'ssl', 'localhost.crt')) + kwargs['keyfile'] = abspath( + join(dirname(__file__), 'var', 'ssl', 'localhost.key')) # sent welcome self.sent_welcome = False @@ -160,16 +195,6 @@ def __init__(self, server_address, RequestHandlerClass, # Override Map self.override_map = {} - # Monkey Patch so that we can toggle the reuse_address - SocketServer.TCPServer.allow_reuse_address = True - - # Initialize Server - SocketServer.TCPServer.__init__( - self, server_address=server_address, - RequestHandlerClass=RequestHandlerClass, - bind_and_activate=bind_and_activate, - ) - # A map of id's to filenames (stored locally on disk) # if a fetch is made to an item that contains a map, then # it is retrieved from disk and delivered @@ -186,190 +211,263 @@ def __init__(self, server_address, RequestHandlerClass, # If you set this to None, then nothing is returned. self.default_fetch = DEFAULT_EMPTY_FILE + # Initialize the Socket Base Class + self.socket = SocketBase( + host=host, port=port, secure=secure, *args, **kwargs) - def shutdown(self): + def local_connection_info(self, timeout=3.0): """ - Handle shutdown + Returns local configuration (listening info) """ - # Clear io wait flag - #print('DEBUG: SERVER GOT SHUTDOWN') - self._io_wait.clear() - try: - self.socket.close() - except: - pass - # Reset welcome flag - self.sent_welcome = False - return SocketServer.TCPServer.shutdown(self) + if self.socket is None: + return None + # Block until server is active + if not self._active.wait(timeout): + return None - def set_override(self, override=None): + return self.socket.local_connection_info() + + def get_client(self, timeout=3.0): """ - Sets an override map (or resets it to nothing) + Returns a client after establishing a connection to the server. + """ + connection_info = self.local_connection_info(timeout=timeout) + if connection_info: + _ipaddr, _portno = connection_info + + # create a socket + sock = NNTPClient( + host=_ipaddr, + port=_portno, + secure=self.socket.secure, + ) - if not override: - override = {} + # connect + sock.connect(5.0) - # Store copy of passed in override + # return the socket + return sock + + def put(self, line, eol=True): + """ + Pushes directly to the NNTPSocketServer without need of a remote + connection and acquires the response + """ + # print 'Scanning Against: "%s"' % line + + # cur_thread = threading.current_thread() + # response = "{}: {}".format(cur_thread.name, data) + # self.socket.send(response) + + # Process over-ride map self._maplock.acquire() - self.override_map = dict(override) + override = self.override_map.items() self._maplock.release() + response = None + for k, v in override + NNTP_DEFAULT_MAP.items(): + result = k.search(line) + if result: + # we matched + if 'response' in v: + response = v['response'] - def reset(self): - """ - This function is called to let the server handle disconnects faster - """ - self._io_wait.clear() + if 'reset' in v: + # Reset our current state + self.reset() - # Reset the current group - self.current_group = None + if 'stat' in v: + entry = str(result.group(v['stat'])) + if not self.current_group: + response = '412 No newsgroup selected' - # sent welcome - self.sent_welcome = False + elif not entry: + response = '423 No article with that number' + else: + response = '223 %s Article exists' % entry - def map(self, id, groups, filepath=None): - """ - Maps an article (and groups) id to a filepath - """ - if isinstance(groups, basestring): - # expect a list of groups, but allow single - # entries too; just convert them before - # moving on - groups = [groups, ] + break - self._maplock.acquire() - for group in groups: - if group not in self.group_map: - # Create Group Entry - self.group_map[group] = [0, 0, 0, group] - # Create fetch_map entry (empty) - self.fetch_map[group] = {} + if 'group' in v: + entry = str(result.group(v['group'])) + if not entry: + response = '423 No such article in this group' + self.current_group = None + + elif entry not in self.group_map: + response = '423 No such article in this group' + self.current_group = None + + else: + response = '211 %d %d %d %s' % ( + self.group_map[entry][0], + self.group_map[entry][1], + self.group_map[entry][2], + self.group_map[entry][3], + ) - if filepath: - # If a file was specified, update our details - self.fetch_map[group][str(id)] = str(filepath) - # Increment tail - self.group_map[group][1] += 1 - # Increment count - self.group_map[group][2] += 1 + # Set Group + self.current_group = entry - self._maplock.release() + # We're done handling GROUP command + break + # checking that we're good to go that way + if 'article' in v: + # Tidy up our article id + _result = ARTICLE_ID_RE.match( + str(result.group(v['article'])), + ) + if not _result: + response = '423 No article with that number' + break -class NNTPBaseRequestHandler(SocketServer.BaseRequestHandler): - """ - The RequestHandler class for our server. + if self._join_group: + # A group join is required; perform some overhead + if not self.current_group: + response = '412 No newsgroup selected' + break - It is instantiated once per connection to the server, and must - override the handle() method to implement communication to the - client. - """ + elif self.current_group \ + not in self.fetch_map: + # Not found + response = '423 No article with that number' + break - def pending_data(self, timeout=0.0): - """ - Checks if there is data that can be read from the - socket (if open). Returns True if there is data and - False if not. - """ - # rs = Read Sockets - # ws = Write Sockets - # es = Error Sockets - if self.request: - rs, _, _ = select([self.request] , [], [self.request], timeout) - return len(rs) > 0 - return None + # create a file from our fetch map + entry = self.fetch_map\ + [self.current_group].get( + str(_result.group('id')), + self.default_fetch, + ) + else: + try: + # If we are in a group, test it first + entry = self.fetch_map\ + [self.current_group].get( + str(_result.group('id')), + self.default_fetch, + ) - def apply_ssl(self): - """ - Wraps connection with SSL - """ + except KeyError: + # Otherwise, iterate through our list and find + # match since there is no join_group + # requirement + found = False + for g in self.fetch_map.iterkeys(): + entry = self.fetch_map[g].get( + str(_result.group('id')), + False, + ) + if entry is False: + # No match + continue - # Swap out old socket - self._request = self.request + # Toggle found flag + found = True - try: - # Python 2.7.9 - context = ssl.SSLContext(self.server.ssl_version) - #context.check_hostname = True - context.check_hostname = False - #context.load_verify_locations(ca_cert) - #context.load_default_certs() - context.load_cert_chain( - certfile=self.server.certfile, - keyfile=self.server.keyfile, - ) - - # Save new SSL Socket - self.request = context.wrap_socket( - self.request, - server_side=True, - ) - self.request.connect(self.server.server_address) - - except (ValueError, AttributeError, TypeError): - try: - # <=Python 2.7.8 - self.request = ssl.wrap_socket( - self.request, - server_side=True, - certfile=self.server.certfile, - keyfile=self.server.keyfile, - ssl_version=self.server.ssl_version, - ) - #except ssl.SSLError, e: - except ssl.SSLError: - #print 'DEBUG: SERVER DENIED CLIENT SSL (wrong version)' - #print str(e) - return False - - #except ssl.SSLError, e: - except ssl.SSLError: - #print 'DEBUG: SERVER DENIED CLIENT SSL (wrong version)' - #print str(e) - return False + if entry is None: + # Assign Default + entry = self.default_fetch - return True + # We're done + break + if not found: + # Not found + response = '423 No article with that number' + break - def handle(self): - # self.request is the TCP socket connected to the client + if isinstance(entry, basestring): + response = '230 Retrieving article.' + # Store our file we mapped to + v['file'] = entry - #print 'DEBUG: HANDLE IN' - if self.server._secure: - #print 'DEBUG: SECURING CONNECTION' - # Applying SSL - if not self.apply_ssl(): - self.server._io_wait.clear() - return + # Fall through to handle file + else: + # Not found + response = '423 No article with that number' + break + + if 'file' in v: + try: + # Read in a file and send it + fd = open(v['file'], 'rb') + response += NNTP_EOL + fd.read() + fd.close() + + except IOError: + response = "501 file '%s' is missing." % v['file'] + break + + elif 'gzip' in v: + # If the file isn't gzipped and you 'want' to gzip it + # then you use this + try: + # Read in a gzipped file and send it + fd = open(v['gzip'], 'rb') + if response is None: + response = '230 Retrieving article.' + + response += " [COMPRESS=GZIP]" + NNTP_EOL + \ + compress(fd.read()) + fd.close() + + except IOError: + response = "502 file '%s' is missing." % v['file'] + break + + # We're done + break + + if response is None: + response = "503 No handler for the request" + + return response + + def is_ready(self, timeout=5.0): + """ + A thread safe way of finding out if we're up and listen for a + connection + """ + return self._active.wait(timeout) + + def nntp_server(self): + """ + A fake nntp server that generates responses like a real one + + It lets us test the protocol by simulating different responses. + """ + + # Set io_wait flag + self._io_wait.set() # Send Welcome Message - if not self.server.sent_welcome: + if not self.sent_welcome: welcome_str = "200 l2g.caronc.dummy NNRP Service Ready" - if self.server._can_post: + if self._can_post: welcome_str += " (posting ok)" - if self.server._has_yenc: + if self._has_yenc: welcome_str += " (yEnc enabled)" try: - self.request.sendall(welcome_str + NNTP_EOD) + self.socket.send(welcome_str + NNTP_EOD) except: # connection lost - #print 'DEBUG: SOCKET ERROR DURING SEND (EXITING)....' + # print 'DEBUG: SOCKET ERROR DURING SEND (EXITING)....' return - self.server.sent_welcome = True + self.sent_welcome = True data = BytesIO() - # Set the io_wait() flag for we're waiting on data now - self.server._io_wait.set() d_len = data.tell() - while self.server._io_wait.is_set(): - #print 'DEBUG: SERVER LOOP' + while self._active.is_set() and self.socket.connected: + # print 'DEBUG: SERVER LOOP' # ptr manipulation d_ptr = data.tell() @@ -380,339 +478,209 @@ def handle(self): data.seek(d_ptr) try: - #print 'DEBUG: SERVER BLOCKING FOR DATA' - pending = self.pending_data(0.8) + # print 'DEBUG: SERVER BLOCKING FOR DATA' + pending = self.socket.can_read(0.8) if pending is None: # No more data - self.server._io_wait.clear() continue if not pending: # nothing pending; back to io_wait continue - while self.pending_data(): - #print 'DEBUG: SERVER BLOCKING FOR DATA....' - _data = self.request.recv(4096) + while self.socket.can_read(): + # print 'DEBUG: SERVER BLOCKING FOR DATA....' + _data = self.socket.read() if not _data: - #print 'DEBUG: SERVER NO DATA (EXITING)....' - # Reset our sent_welcome flag - self.server.sent_welcome = False + # print 'DEBUG: SERVER NO DATA (EXITING)....' + # Reset our settings to prepare for another connection + self.reset() return # Buffer response data.write(_data) d_len = data.tell() - except socket.error, e: + except socket.error: # Socket Issue - self.server._io_wait.clear() - #print 'DEBUG: SOCKET ERROR (EXITING)....' - #print 'DEBUG: ERROR %s' % str(e) + # print 'DEBUG: SOCKET ERROR (EXITING)....' + # print 'DEBUG: ERROR %s' % str(e) # Reset our sent_welcome flag - self.server.sent_welcome = False + self.sent_welcome = False return - # Seek End for size if d_ptr == d_len: continue data.seek(d_ptr) + # Acquire our line line = data.readline() - #print 'Scanning Against: "%s"' % line - - #cur_thread = threading.current_thread() - #response = "{}: {}".format(cur_thread.name, data) - #self.request.sendall(response) - - # Process over-ride map - self.server._maplock.acquire() - override = self.server.override_map.items() - self.server._maplock.release() - - response = None - for k, v in override + NNTP_DEFAULT_MAP.items(): - result = k.search(line) - if result: - # we matched - if 'response' in v: - response = v['response'] - - if 'stat' in v: - entry = str(result.group(v['stat'])) - if not self.server.current_group: - response = '412 No newsgroup selected' - elif not entry: - response = '423 No article with that number' + # Build our response + response = self.put(line) - else: - response = '223 %s Article exists' % entry - - break - - if 'group' in v: - entry = str(result.group(v['group'])) - if not entry: - response = '423 No such article in this group' - self.server.current_group = None - - elif entry not in self.server.group_map: - response = '423 No such article in this group' - self.server.current_group = None - - else: - response = '211 %d %d %d %s' % ( - self.server.group_map[entry][0], - self.server.group_map[entry][1], - self.server.group_map[entry][2], - self.server.group_map[entry][3], - ) - - # Set Group - self.server.current_group = entry - - # We're done handling GROUP command - break - - # checking that we're good to go that way - if 'article' in v: - # Tidy up our article id - _result = ARTICLE_ID_RE.match( - str(result.group(v['article'])), - ) - if not _result: - response = '423 No article with that number' - break - - if self.server._join_group: - # A group join is required; perform some overhead - #import pdb - #pdb.set_trace() - if not self.server.current_group: - response = '412 No newsgroup selected' - break - - elif self.server.current_group \ - not in self.server.fetch_map: - # Not found - response = '423 No article with that number' - break + # Return it on the socket + try: + self.socket.send(response + NNTP_EOD) + except: + # connection lost + # print 'DEBUG: SOCKET ERROR DURING SEND (EXITING)....' + return - # create a file from our fetch map - entry = self.server.fetch_map\ - [self.server.current_group].get( - str(_result.group('id')), - self.server.default_fetch, - ) + # print 'DEBUG: handle() (EXITING)....' - else: - try: - # If we are in a group, test it first - entry = self.server.fetch_map\ - [self.server.current_group].get( - str(_result.group('id')), - self.server.default_fetch, - ) + def run(self): + """ + Run thread + """ + # Enable Server + self._active.set() - except KeyError: - # Otherwise, iterate through our list and find - # match since there is no join_group - # requirement - found = False - for g in self.server.fetch_map.iterkeys(): - entry = self.server.fetch_map[g].get( - str(_result.group('id')), - False, - ) - if entry is False: - # No match - continue - - # Toggle found flag - found = True - - if entry is None: - # Assign Default - entry = self.server.default_fetch - - # We're done - break - - if not found: - # Not found - response = '423 No article with that number' - break - - if isinstance(entry, basestring): - response = '230 Retrieving article.' - # Store our file we mapped to - v['file'] = entry + while self._active.is_set(): + # Thread Main Loop + try: + if not self.socket.listen(): + # Wait for a connection + continue + except SocketException: + # Lost the connection; loop + continue - # Fall through to handle file - else: - # Not found - response = '423 No article with that number' - break + # print 'DEBUG: SERVERSIDE CONNECTION ESTABLISHED!' - if 'file' in v: - try: - # Read in a file and send it - fd = open(v['file'], 'rb') - response += NNTP_EOL + fd.read() - fd.close() + # If we reach we have a connection + self.nntp_server() - except IOError: - response = "501 file '%s' is missing." % v['file'] - break + # We're probably here because we lost our connection + self.reset() + # We're finished (close our socket if not already done so) + self.socket.close() - elif 'gzip' in v: - # If the file isn't gzipped and you 'want' to gzip it - # then you use this - try: - # Read in a gzipped file and send it - fd = open(v['gzip'], 'rb') - if response is None: - response = '230 Retrieving article.' + def shutdown(self): + """ + Handle shutdown + """ + # Clear active flag + self._active.clear() + # print('DEBUG: SERVER GOT SHUTDOWN') + try: + self.socket.close() + except: + pass - response += " [COMPRESS=GZIP]" + NNTP_EOL + \ - compress(fd.read()) - fd.close() + # Reset welcome flag + self.sent_welcome = False - except IOError: - response = "502 file '%s' is missing." % v['file'] - break + # We're done + return True - # We're done - break + def set_override(self, override=None): + """ + Sets an override map (or resets it to nothing) + """ - if response is None: - response = "503 No handler for the request" + # Clear io_wait flag + self._io_wait.clear() - try: - self.request.sendall(response + NNTP_EOD) - except: - # connection lost - #print 'DEBUG: SOCKET ERROR DURING SEND (EXITING)....' - return - #print 'DEBUG: handle() (EXITING)....' + if not override: + override = {} + # Store copy of passed in override + self._maplock.acquire() + self.override_map = dict(override) + self._maplock.release() -def ssl_client(hostname, port, message, version=ssl.PROTOCOL_TLSv1): + def reset(self): + """ + This function is called to let the server handle disconnects faster + """ + # Clear io_wait flag + self._io_wait.clear() - # Possible Verify Modes: - # - ssl.CERT_NONE - # - ssl.CERT_OPTIONAL - # - ssl.CERT_REQUIRED + # Reset the current group + self.current_group = None - # We don't want to veryify our key since it's just a - # localhost self signed one - cert_reqs = ssl.CERT_NONE + # sent welcome + self.sent_welcome = False - _sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - _sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) + def map(self, article_id, groups, filepath=None): + """ + Maps an article (and groups) article_id to a filepath + """ + if isinstance(groups, basestring): + # expect a list of groups, but allow single + # entries too; just convert them before + # moving on + groups = [groups, ] - sock = ssl.wrap_socket( - _sock, - ca_certs="/etc/ssl/certs/ca-bundle.crt", - cert_reqs=cert_reqs, - ssl_version=version, - ) - try: - sock.connect((hostname, port)) - except ssl.SSLError, e: - #print 'DEBUG: CLIENT DENIED SSL BY SERVER' - #print str(e) - sock.close() - return - - #print repr(sock.getpeername()) - #print sock.cipher() - #print pformat(sock.getpeercert()) - - try: - sock.sendall(message) - print "Sent: %s" % message.strip() - response = sock.recv(4096) - print "Received: %s" % response.strip() - print - finally: - sock.close() + self._maplock.acquire() + for group in groups: + if group not in self.group_map: + # Create Group Entry + self.group_map[group] = [0, 0, 0, group] + # Create fetch_map entry (empty) + self.fetch_map[group] = {} + if filepath: + # If a file was specified, update our details + self.fetch_map[group][str(article_id)] = str(filepath) + # Increment tail + self.group_map[group][1] += 1 + # Increment count + self.group_map[group][2] += 1 -def client(ip, port, message): - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) - sock.connect((ip, port)) + self._maplock.release() - try: - sock.sendall(message) - print "Sent: %s" % message.strip() - response = sock.recv(4096) - print "Received: %s" % response.strip() - print - finally: - sock.close() if __name__ == "__main__": - hostname, portno = "localhost", 0 - ## SSL Checking + # SSL Checking nntp_server = NNTPSocketServer( - (hostname, portno), - NNTPBaseRequestHandler, - secure=True, + secure=ssl.PROTOCOL_TLSv1, + #secure=False, ) - # Get our connection stats - ipaddr, portno = nntp_server.server_address + # Launch thread + # nntp_server.daemon = True + nntp_server.start() - # Push DUMMY NTP Server To Thread - t = threading.Thread( - target=nntp_server.serve_forever, - name='NTPServer', - ) - # Exit the server thread when the main thread terminates - t.daemon = True - t.start() + # Acquire a client connection + socket = nntp_server.get_client() - ssl_client(hostname, portno, "AUTHINFO USER valid\r\n", ssl.PROTOCOL_TLSv1) - ssl_client(hostname, portno, "AUTHINFO USER valid\r\n", ssl.PROTOCOL_SSLv3) + socket.put("AUTHINFO USER valid") + # client(hostname, portno, "AUTHINFO USER valid\r\n") + # client(hostname, portno, "AUTHINFO USER valid\r\n", ssl.PROTOCOL_TLSv1) + # client(hostname, portno, "AUTHINFO USER valid\r\n", ssl.PROTOCOL_SSLv3) nntp_server.shutdown() ## NON SSL nntp_server = NNTPSocketServer( - (hostname, portno), - NNTPBaseRequestHandler, secure=False, ) - # Get our connection stats - ipaddr, portno = nntp_server.server_address - - # Push DUMMY NTP Server To Thread - t = threading.Thread( - target=nntp_server.serve_forever, - name='NTPServer', - ) # Append file to map - nntp_server.map('3', 'alt.bin.test', join(NNTP_TEST_VAR_PATH, '00000005.ntx')) + nntp_server.map( + '3', 'alt.bin.test', + join(NNTP_TEST_VAR_PATH, '00000005.ntx'), + ) # Exit the server thread when the main thread terminates - t.daemon = True - t.start() + # nntp_server.daemon = True + nntp_server.start() - client(ipaddr, portno, "AUTHINFO USER valid\r\n") - #nntp_server.reset() + # Acquire a client connection + socket = nntp_server.get_client() - client(ipaddr, portno, "AUTHINFO PASS user\r\n") - #nntp_server.reset() + socket.put("AUTHINFO USER valid") + socket.put("AUTHINFO PASS user") + socket.put("READ FILE 3") + socket.put("GROUP alt.bin.test") + socket.put("ARTICLE 3") + socket.put("Hello World 3") - client(ipaddr, portno, "READ FILE 3\r\n") - client(ipaddr, portno, "GROUP alt.bin.test\r\n") - client(ipaddr, portno, "ARTICLE 3\r\n") - client(ipaddr, portno, "Hello World 3\r\n") - #nntp_server.reset() + # Close our client + socket.close() + # Shutdown our server nntp_server.shutdown() diff --git a/newsreap/tests/NNTPYencArticle_Test.py b/newsreap/tests/NNTPYencArticle_Test.py index 9aaa0bc..b5ab251 100644 --- a/newsreap/tests/NNTPYencArticle_Test.py +++ b/newsreap/tests/NNTPYencArticle_Test.py @@ -44,7 +44,6 @@ from tests.NNTPSocketServer import NNTPSocketServer -from tests.NNTPSocketServer import NNTPBaseRequestHandler from tests.NNTPSocketServer import NNTP_TEST_VAR_PATH as VAR_PATH from lib.NNTPConnection import NNTPConnection @@ -60,70 +59,50 @@ def setUp(self): """ super(NNTPYencArticle_Test, self).setUp() - self.hostname = "localhost" - - ## Secure NNTP Server + # Secure NNTP Server self.nntps = NNTPSocketServer( - (self.hostname, 0), - NNTPBaseRequestHandler, secure=True, join_group=True, ) - ## Insecure NNTP Server + + # Insecure NNTP Server self.nntp = NNTPSocketServer( - (self.hostname, 0), - NNTPBaseRequestHandler, secure=False, join_group=True, ) - # Get our connection stats - self.nttps_ipaddr, self.nntps_portno = self.nntps.server_address - self.nttp_ipaddr, self.nntp_portno = self.nntp.server_address - - # Push DUMMY NTP Server To Thread - self.nntps_thread = threading.Thread( - target=self.nntps.serve_forever, - name='NTPServer', - ) - - self.nntp_thread = threading.Thread( - target=self.nntp.serve_forever, - name='NTPServer', - ) - # Common Group Name self.common_group = 'alt.binaries.test' # Map Articles (to groups) for fetching self.nntp.map( - id='5', + article_id='5', groups=(self.common_group, ), filepath=join(VAR_PATH, '00000005.ntx'), ) self.nntps.map( - id='5', + article_id='5', groups=(self.common_group, ), filepath=join(VAR_PATH, '00000005.ntx'), ) self.nntp.map( - id='20', + article_id='20', groups=(self.common_group, ), filepath=join(VAR_PATH, '00000020.ntx'), ) self.nntp.map( - id='21', + article_id='21', groups=(self.common_group, ), filepath=join(VAR_PATH, '00000021.ntx'), ) self.nntps.map( - id='20', + article_id='20', groups=(self.common_group, ), filepath=join(VAR_PATH, '00000020.ntx'), ) self.nntps.map( - id='21', + article_id='21', groups=(self.common_group, ), filepath=join(VAR_PATH, '00000021.ntx'), ) @@ -132,9 +111,15 @@ def setUp(self): self.nntps.daemon = True self.nntp.daemon = True - # Start Threads - self.nntps_thread.start() - self.nntp_thread.start() + # Start Our Server Threads + self.nntps.start() + self.nntp.start() + + # Acquire our configuration + self.nttp_ipaddr, self.nntp_portno = \ + self.nntp.local_connection_info() + self.nttps_ipaddr, self.nntps_portno = \ + self.nntps.local_connection_info() def tearDown(self): diff --git a/newsreap/tests/TestBase.py b/newsreap/tests/TestBase.py index a32153f..bdad177 100644 --- a/newsreap/tests/TestBase.py +++ b/newsreap/tests/TestBase.py @@ -38,7 +38,6 @@ from lib.codecs import CodecBase except ImportError: - print 'importing %s' % dirname(dirname(abspath(__file__))) sys.path.insert(0, dirname(dirname(abspath(__file__)))) from lib.codecs import CodecBase