|
14 | 14 | import time
|
15 | 15 | import urllib
|
16 | 16 | import uuid
|
17 |
| -from typing import Dict, Optional |
| 17 | +from contextlib import nullcontext as does_not_raise |
| 18 | +from typing import Any, Dict, Optional |
18 | 19 | from unittest import mock
|
19 | 20 | from urllib.parse import urlparse
|
20 | 21 |
|
| 22 | +import gssapi |
21 | 23 | import httpretty
|
22 | 24 | import pytest
|
23 | 25 | import requests
|
@@ -865,6 +867,73 @@ def test_extra_credential_value_encoding(mock_get_and_post):
|
865 | 867 | assert headers[constants.HEADER_EXTRA_CREDENTIAL] == "foo=bar+%E7%9A%84"
|
866 | 868 |
|
867 | 869 |
|
| 870 | +class MockGssapiCredentials: |
| 871 | + def __init__(self, name: gssapi.Name, usage: str): |
| 872 | + self.name = name |
| 873 | + self.usage = usage |
| 874 | + |
| 875 | + def __eq__(self, other: Any) -> bool: |
| 876 | + if not isinstance(other, MockGssapiCredentials): |
| 877 | + return False |
| 878 | + return ( |
| 879 | + self.name == other.name, |
| 880 | + self.usage == other.usage, |
| 881 | + ) |
| 882 | + |
| 883 | + |
| 884 | +@pytest.fixture |
| 885 | +def mock_gssapi_creds(monkeypatch): |
| 886 | + monkeypatch.setattr("gssapi.Credentials", MockGssapiCredentials) |
| 887 | + |
| 888 | + |
| 889 | +def _gssapi_uname(spn: str): |
| 890 | + return gssapi.Name(spn, gssapi.NameType.user) |
| 891 | + |
| 892 | + |
| 893 | +def _gssapi_sname(principal: str): |
| 894 | + return gssapi.Name(principal, gssapi.NameType.hostbased_service) |
| 895 | + |
| 896 | + |
| 897 | +@pytest.mark.parametrize( |
| 898 | + "options, expected_credentials, expected_hostname, expected_exception", |
| 899 | + [ |
| 900 | + ( |
| 901 | + {}, None, None, does_not_raise(), |
| 902 | + ), |
| 903 | + ( |
| 904 | + {"hostname_override": "foo"}, None, "foo", does_not_raise(), |
| 905 | + ), |
| 906 | + ( |
| 907 | + {"service_name": "bar"}, None, None, |
| 908 | + pytest.raises(ValueError, match=r"must be used together with hostname_override"), |
| 909 | + ), |
| 910 | + ( |
| 911 | + {"hostname_override": "foo", "service_name": "bar"}, None, _gssapi_sname("bar@foo"), does_not_raise(), |
| 912 | + ), |
| 913 | + ( |
| 914 | + {"principal": "foo"}, MockGssapiCredentials(_gssapi_uname("foo"), "initial"), None, does_not_raise(), |
| 915 | + ), |
| 916 | + ] |
| 917 | +) |
| 918 | +def test_authentication_gssapi_init_arguments( |
| 919 | + options, |
| 920 | + expected_credentials, |
| 921 | + expected_hostname, |
| 922 | + expected_exception, |
| 923 | + mock_gssapi_creds, |
| 924 | + monkeypatch, |
| 925 | +): |
| 926 | + auth = GSSAPIAuthentication(**options) |
| 927 | + |
| 928 | + session = requests.Session() |
| 929 | + |
| 930 | + with expected_exception: |
| 931 | + auth.set_http_session(session) |
| 932 | + |
| 933 | + assert session.auth.target_name == expected_hostname |
| 934 | + assert session.auth.creds == expected_credentials |
| 935 | + |
| 936 | + |
868 | 937 | class RetryRecorder(object):
|
869 | 938 | def __init__(self, error=None, result=None):
|
870 | 939 | self.__name__ = "RetryRecorder"
|
|
0 commit comments