diff --git a/sickle/app.py b/sickle/app.py
index 6fa3b1b..4f9526c 100644
--- a/sickle/app.py
+++ b/sickle/app.py
@@ -12,6 +12,8 @@
 import time
 
 import requests
+from requests.adapters import HTTPAdapter
+from requests.packages.urllib3.util.retry import Retry
 
 from sickle.iterator import BaseOAIIterator, OAIItemIterator
 from sickle.response import OAIResponse
@@ -56,11 +58,13 @@ class Sickle(object):
                         use the value from the retry-after header (if present) and will wait the specified number of
                         seconds between retries.
     :type max_retries: int
-    :param retry_status_codes: HTTP status codes to retry (default will only retry on 503)
+    :param retry_status_codes: HTTP status codes to retry (default will retry on 429, 500, 502, 503 and 504)
     :type retry_status_codes: iterable
-    :param default_retry_after: default number of seconds to wait between retries in case no retry-after header is found
-                                on the response (defaults to 60 seconds)
-    :type default_retry_after: int
+    :param default_retry_after: No longer supported; passing a value only logs a warning.
+                                Use ``retry_backoff_factor`` instead.
+    :type default_retry_after: int
+    :param retry_backoff_factor: Backoff factor to apply between retries after the second try,
+                                 if no Retry-After header is sent by the server. Default: 2.0
+    :type retry_backoff_factor: float
     :type protocol_version: str
     :param class_mapping: A dictionary that maps OAI verbs to classes representing
                           OAI items. If not provided,
@@ -86,7 +90,8 @@ def __init__(self, endpoint,
                  iterator=OAIItemIterator,
                  max_retries=0,
                  retry_status_codes=None,
-                 default_retry_after=60,
+                 default_retry_after=None,
+                 retry_backoff_factor=2,
                  class_mapping=None,
                  encoding=None,
                  **request_args):
@@ -104,9 +111,28 @@ def __init__(self, endpoint,
         else:
             raise TypeError(
                 "Argument 'iterator' must be subclass of %s" % BaseOAIIterator.__name__)
-        self.max_retries = max_retries
-        self.retry_status_codes = retry_status_codes or [503]
-        self.default_retry_after = default_retry_after
+
+        if default_retry_after is not None:
+            logger.warning("default_retry_after is no longer supported, please use retry_backoff_factor instead.")
+
+        # urllib3 1.26 renamed ``method_whitelist`` to ``allowed_methods`` and
+        # 2.0 removed the old name, so pick whichever this urllib3 understands.
+        methods_kwarg = ('allowed_methods'
+                         if hasattr(Retry, 'DEFAULT_ALLOWED_METHODS')
+                         else 'method_whitelist')
+        retry_adapter = HTTPAdapter(max_retries=Retry(
+            total=max_retries,
+            backoff_factor=retry_backoff_factor,
+            status_forcelist=retry_status_codes or [429, 500, 502, 503, 504],
+            # Hand the final response back instead of raising RetryError so
+            # that _request() still fails with HTTPError via raise_for_status().
+            raise_on_status=False,
+            **{methods_kwarg: frozenset(['GET', 'POST'])}
+        ))
+        self.session = requests.Session()
+        self.session.mount('https://', retry_adapter)
+        self.session.mount('http://', retry_adapter)
+
         self.oai_namespace = OAI_NAMESPACE % self.protocol_version
         self.class_mapping = class_mapping or DEFAULT_CLASS_MAP
         self.encoding = encoding
@@ -119,23 +145,17 @@ def harvest(self, **kwargs):  # pragma: no cover
         :rtype: :class:`sickle.OAIResponse`
         """
         http_response = self._request(kwargs)
-        for _ in range(self.max_retries):
-            if self._is_error_code(http_response.status_code) \
-                    and http_response.status_code in self.retry_status_codes:
-                retry_after = self.get_retry_after(http_response)
-                logger.warning(
-                    "HTTP %d! Retrying after %d seconds..." % (http_response.status_code, retry_after))
-                time.sleep(retry_after)
-                http_response = self._request(kwargs)
-        http_response.raise_for_status()
         if self.encoding:
             http_response.encoding = self.encoding
         return OAIResponse(http_response, params=kwargs)
 
     def _request(self, kwargs):
         if self.http_method == 'GET':
-            return requests.get(self.endpoint, params=kwargs, **self.request_args)
-        return requests.post(self.endpoint, data=kwargs, **self.request_args)
+            response = self.session.get(self.endpoint, params=kwargs, **self.request_args)
+        else:
+            response = self.session.post(self.endpoint, data=kwargs, **self.request_args)
+        response.raise_for_status()
+        return response
 
     def ListRecords(self, ignore_deleted=False, **kwargs):
         """Issue a ListRecords request.
diff --git a/sickle/tests/test_harvesting.py b/sickle/tests/test_harvesting.py
index b0c2886..81b80d5 100644
--- a/sickle/tests/test_harvesting.py
+++ b/sickle/tests/test_harvesting.py
@@ -238,14 +238,10 @@ class TestCaseWrongEncoding(unittest.TestCase):
 
     def __init__(self, methodName='runTest'):
         super(TestCaseWrongEncoding, self).__init__(methodName)
-        self.patch = mock.patch('sickle.app.requests.get', mock_get)
 
     def setUp(self):
-        self.patch.start()
         self.sickle = Sickle('http://localhost')
-
-    def tearDown(self):
-        self.patch.stop()
+        self.sickle.session.get = mock_get
 
     def test_GetRecord(self):
         oai_id = 'oai:test.example.com:1996652'
diff --git a/sickle/tests/test_sickle.py b/sickle/tests/test_sickle.py
index 62b75f5..4c423f6 100644
--- a/sickle/tests/test_sickle.py
+++ b/sickle/tests/test_sickle.py
@@ -33,66 +33,50 @@ def test_invalid_iterator(self):
     def test_pass_request_args(self):
         mock_response = Mock(text=u'', content='', status_code=200)
         mock_get = Mock(return_value=mock_response)
-        with patch('sickle.app.requests.get', mock_get):
-            sickle = Sickle('url', timeout=10, proxies=dict(),
-                            auth=('user', 'password'))
-            sickle.ListRecords()
-            mock_get.assert_called_once_with('url',
-                                             params={'verb': 'ListRecords'},
-                                             timeout=10, proxies=dict(),
-                                             auth=('user', 'password'))
+        sickle = Sickle('url', timeout=10, proxies=dict(),
+                        auth=('user', 'password'))
+        sickle.session.get = mock_get
+        sickle.ListRecords()
+        mock_get.assert_called_once_with('url',
+                                         params={'verb': 'ListRecords'},
+                                         timeout=10, proxies=dict(),
+                                         auth=('user', 'password'))
 
     def test_override_encoding(self):
         mock_response = Mock(text='', content='', status_code=200)
         mock_get = Mock(return_value=mock_response)
-        with patch('sickle.app.requests.get', mock_get):
-            sickle = Sickle('url', encoding='encoding')
-            sickle.ListSets()
-            mock_get.assert_called_once_with('url',
-                                             params={'verb': 'ListSets'})
+        sickle = Sickle('url', encoding='encoding')
+        sickle.session.get = mock_get
+        sickle.ListSets()
+        mock_get.assert_called_once_with('url',
+                                         params={'verb': 'ListSets'})
 
     def test_no_retry(self):
         mock_response = Mock(status_code=503,
                              headers={'retry-after': '10'},
                              raise_for_status=Mock(side_effect=HTTPError))
         mock_get = Mock(return_value=mock_response)
-        with patch('sickle.app.requests.get', mock_get):
-            sickle = Sickle('url')
-            try:
-                sickle.ListRecords()
-            except HTTPError:
-                pass
-            self.assertEqual(1, mock_get.call_count)
+        sickle = Sickle('url')
+        sickle.session.get = mock_get
+        try:
+            sickle.ListRecords()
+        except HTTPError:
+            pass
+        self.assertEqual(1, mock_get.call_count)
 
-    def test_retry_on_503(self):
-        mock_response = Mock(status_code=503,
-                             headers={'retry-after': '10'},
-                             raise_for_status=Mock(side_effect=HTTPError))
-        mock_get = Mock(return_value=mock_response)
-        sleep_mock = Mock()
-        with patch('time.sleep', sleep_mock):
-            with patch('sickle.app.requests.get', mock_get):
-                sickle = Sickle('url', max_retries=3, default_retry_after=0)
-                try:
-                    sickle.ListRecords()
-                except HTTPError:
-                    pass
-                mock_get.assert_called_with('url',
-                                            params={'verb': 'ListRecords'})
-                self.assertEqual(4, mock_get.call_count)
-                self.assertEqual(3, sleep_mock.call_count)
-                sleep_mock.assert_called_with(10)
+    def test_retry_arguments(self):
+        sickle = Sickle('url', retry_backoff_factor=1.1234, max_retries=99, retry_status_codes=(418,))
 
-    def test_retry_on_custom_code(self):
-        mock_response = Mock(status_code=500,
-                             raise_for_status=Mock(side_effect=HTTPError))
-        mock_get = Mock(return_value=mock_response)
-        with patch('sickle.app.requests.get', mock_get):
-            sickle = Sickle('url', max_retries=3, default_retry_after=0, retry_status_codes=(503, 500))
-            try:
-                sickle.ListRecords()
-            except HTTPError:
-                pass
-            mock_get.assert_called_with('url',
-                                        params={'verb': 'ListRecords'})
-            self.assertEqual(4, mock_get.call_count)
+        adapter = sickle.session.get_adapter('https://localhost/oai')
+        retries = adapter.max_retries
+        # urllib3 >= 1.26 exposes ``allowed_methods``; older releases only
+        # know ``method_whitelist``.
+        methods = getattr(retries, 'allowed_methods', None)
+        if methods is None:
+            methods = retries.method_whitelist
+
+        self.assertEqual(99, retries.total)
+        self.assertEqual(1.1234, retries.backoff_factor)
+        self.assertEqual((418,), retries.status_forcelist)
+        self.assertEqual(frozenset(['POST', 'GET']), methods)
+        self.assertFalse(retries.raise_on_status)