|
1 |
| -from unittest.mock import MagicMock, patch |
| 1 | +import asyncio |
| 2 | +from typing import Any |
| 3 | +from unittest.mock import AsyncMock, MagicMock, patch |
2 | 4 |
|
3 | 5 | from okta.models.group_rule import GroupRule as OktaGroupRuleType
|
4 | 6 |
|
5 |
| -from api.services.okta_service import is_managed_group |
| 7 | +from api.services.okta_service import OktaService, is_managed_group |
6 | 8 |
|
7 | 9 |
|
8 | 10 | def test_is_managed_group_with_allow_discord_access_false() -> None:
|
@@ -60,3 +62,64 @@ def test_is_managed_group_with_allow_discord_access_undefined() -> None:
|
60 | 62 | # Call the function and assert the expected result
|
61 | 63 | result = is_managed_group(group, group_ids_with_group_rules, OKTA_GROUP_PROFILE_CUSTOM_ATTR)
|
62 | 64 | assert result is True
|
| 65 | + |
| 66 | + |
| 67 | +def test_update_group_preserves_custom_attributes() -> None: |
| 68 | + """Test that update_group preserves custom attributes when updating a group.""" |
| 69 | + # Create a new event loop for this test |
| 70 | + loop = asyncio.new_event_loop() |
| 71 | + asyncio.set_event_loop(loop) |
| 72 | + |
| 73 | + # Mock asyncio.run to use our event loop |
| 74 | + with patch("asyncio.run", side_effect=lambda x: loop.run_until_complete(x)): |
| 75 | + # Create OktaService instance with mock client |
| 76 | + service = OktaService() |
| 77 | + service.okta_client = MagicMock() |
| 78 | + |
| 79 | + # Set up the mocks for the existing group and the update call |
| 80 | + group_id = "test-group-id" |
| 81 | + |
| 82 | + # Create a mock group with a profile that has the custom attribute |
| 83 | + existing_group = MagicMock() |
| 84 | + # Instead of setting __dict__ directly, configure the mock properly |
| 85 | + existing_group.profile = MagicMock() |
| 86 | + existing_group.profile.name = "Old Name" |
| 87 | + existing_group.profile.description = "Old Description" |
| 88 | + existing_group.profile.allow_discord_access = True |
| 89 | + |
| 90 | + # Mock the get_group and update_group methods |
| 91 | + service.okta_client.get_group = AsyncMock(return_value=(existing_group, None, None)) |
| 92 | + service.okta_client.update_group = AsyncMock(return_value=(MagicMock(), None, None)) |
| 93 | + |
| 94 | + # Create a mock for the SessionedOktaRequestExecutor context manager |
| 95 | + # This is a special class that implements the async context manager protocol |
| 96 | + class MockSessionedExecutor: |
| 97 | + """Mock class for SessionedOktaRequestExecutor with async context manager methods""" |
| 98 | + |
| 99 | + async def __aenter__(self) -> None: |
| 100 | + """Mock for __aenter__ - called when entering an 'async with' block""" |
| 101 | + return None |
| 102 | + |
| 103 | + async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None: |
| 104 | + """Mock for __aexit__ - called when exiting an 'async with' block""" |
| 105 | + return None |
| 106 | + |
| 107 | + # Create a mock context manager instance |
| 108 | + mock_executor = MockSessionedExecutor() |
| 109 | + |
| 110 | + # Use patch to mock the _get_sessioned_okta_request_executor method |
| 111 | + # This avoids directly assigning to the method, which mypy doesn't like |
| 112 | + with patch.object(service, "_get_sessioned_okta_request_executor", return_value=mock_executor): |
| 113 | + # Call update_group |
| 114 | + service.update_group(group_id, "New Name", "New Description") |
| 115 | + |
| 116 | + # Verify update_group was called with a payload that preserved the custom attribute |
| 117 | + args, _ = service.okta_client.update_group.call_args |
| 118 | + assert len(args) == 2 |
| 119 | + assert args[0] == group_id |
| 120 | + |
| 121 | + # Check that the payload contains both the updated fields and the preserved custom attribute |
| 122 | + updated_payload = args[1] |
| 123 | + assert updated_payload.profile.name == "New Name" |
| 124 | + assert updated_payload.profile.description == "New Description" |
| 125 | + assert updated_payload.profile.allow_discord_access is True |
0 commit comments