Skip to content

Commit 8327347

Browse files
authored
Preserve custom attributes on Discord Access Okta Group UI updates (#249)
1 parent 3fb4f5a commit 8327347

File tree

3 files changed

+95
-3
lines changed

3 files changed

+95
-3
lines changed

api/services/okta_service.py

Lines changed: 29 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -164,10 +164,38 @@ async def _create_group(name: str, description: str) -> Group:
164164
def update_group(self, groupId: str, name: str, description: str) -> Group:
165165
async def _update_group(groupId: str, name: str, description: str) -> Group:
166166
async with self._get_sessioned_okta_request_executor() as _:
167+
# Fetch Existing Group Data
168+
existing_group_data, _, get_error = await OktaService._retry(self.okta_client.get_group, groupId)
169+
if get_error is not None:
170+
logger.error(f"Failed to fetch existing group {groupId} before update: {get_error}")
171+
raise Exception(f"Failed to fetch existing group {groupId} before update: {get_error}")
172+
if existing_group_data is None:
173+
logger.error(f"Group {groupId} not found in Okta before update.")
174+
raise Exception(f"Group {groupId} not found in Okta before update.")
175+
176+
# Extract Existing Profile
177+
existing_profile = {}
178+
if existing_group_data.profile:
179+
# Using __dict__ can be fragile if Okta changes internal representation.
180+
# If specific profile attributes are known, accessing them directly might be safer.
181+
# Filter out None values if needed by Okta API or desired.
182+
existing_profile = {k: v for k, v in existing_group_data.profile.__dict__.items() if v is not None}
183+
184+
# Merge Updated Profile Data
185+
new_profile = {**existing_profile} # Start with a copy of the existing profile
186+
new_profile["name"] = name # Update/set the name
187+
new_profile["description"] = (
188+
description if description is not None else ""
189+
) # Update/set the description (handle None)
190+
191+
# Construct the New Payload
192+
group_payload = OktaGroupType({"profile": new_profile})
193+
194+
# Modify the Update Call to use the new payload
167195
group, _, error = await OktaService._retry(
168196
self.okta_client.update_group,
169197
groupId,
170-
OktaGroupType({"profile": {"name": name, "description": description}}),
198+
group_payload,
171199
)
172200

173201
if error is not None:

tests/factories.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ class GroupProfileFactory(factory.Factory[GroupProfile]):
2929
{
3030
"name": factory.Faker("pystr"),
3131
"description": factory.Faker("pystr"),
32+
"allow_discord_access": None, # Default to None for the custom attribute
3233
}
3334
)
3435

tests/test_okta_service.py

Lines changed: 65 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,10 @@
1-
from unittest.mock import MagicMock, patch
1+
import asyncio
2+
from typing import Any
3+
from unittest.mock import AsyncMock, MagicMock, patch
24

35
from okta.models.group_rule import GroupRule as OktaGroupRuleType
46

5-
from api.services.okta_service import is_managed_group
7+
from api.services.okta_service import OktaService, is_managed_group
68

79

810
def test_is_managed_group_with_allow_discord_access_false() -> None:
@@ -60,3 +62,64 @@ def test_is_managed_group_with_allow_discord_access_undefined() -> None:
6062
# Call the function and assert the expected result
6163
result = is_managed_group(group, group_ids_with_group_rules, OKTA_GROUP_PROFILE_CUSTOM_ATTR)
6264
assert result is True
65+
66+
67+
def test_update_group_preserves_custom_attributes() -> None:
68+
"""Test that update_group preserves custom attributes when updating a group."""
69+
# Create a new event loop for this test
70+
loop = asyncio.new_event_loop()
71+
asyncio.set_event_loop(loop)
72+
73+
# Mock asyncio.run to use our event loop
74+
with patch("asyncio.run", side_effect=lambda x: loop.run_until_complete(x)):
75+
# Create OktaService instance with mock client
76+
service = OktaService()
77+
service.okta_client = MagicMock()
78+
79+
# Set up the mocks for the existing group and the update call
80+
group_id = "test-group-id"
81+
82+
# Create a mock group with a profile that has the custom attribute
83+
existing_group = MagicMock()
84+
# Instead of setting __dict__ directly, configure the mock properly
85+
existing_group.profile = MagicMock()
86+
existing_group.profile.name = "Old Name"
87+
existing_group.profile.description = "Old Description"
88+
existing_group.profile.allow_discord_access = True
89+
90+
# Mock the get_group and update_group methods
91+
service.okta_client.get_group = AsyncMock(return_value=(existing_group, None, None))
92+
service.okta_client.update_group = AsyncMock(return_value=(MagicMock(), None, None))
93+
94+
# Create a mock for the SessionedOktaRequestExecutor context manager
95+
# This is a special class that implements the async context manager protocol
96+
class MockSessionedExecutor:
97+
"""Mock class for SessionedOktaRequestExecutor with async context manager methods"""
98+
99+
async def __aenter__(self) -> None:
100+
"""Mock for __aenter__ - called when entering an 'async with' block"""
101+
return None
102+
103+
async def __aexit__(self, exc_type: Any, exc_val: Any, exc_tb: Any) -> None:
104+
"""Mock for __aexit__ - called when exiting an 'async with' block"""
105+
return None
106+
107+
# Create a mock context manager instance
108+
mock_executor = MockSessionedExecutor()
109+
110+
# Use patch to mock the _get_sessioned_okta_request_executor method
111+
# This avoids directly assigning to the method, which mypy doesn't like
112+
with patch.object(service, "_get_sessioned_okta_request_executor", return_value=mock_executor):
113+
# Call update_group
114+
service.update_group(group_id, "New Name", "New Description")
115+
116+
# Verify update_group was called with a payload that preserved the custom attribute
117+
args, _ = service.okta_client.update_group.call_args
118+
assert len(args) == 2
119+
assert args[0] == group_id
120+
121+
# Check that the payload contains both the updated fields and the preserved custom attribute
122+
updated_payload = args[1]
123+
assert updated_payload.profile.name == "New Name"
124+
assert updated_payload.profile.description == "New Description"
125+
assert updated_payload.profile.allow_discord_access is True

0 commit comments

Comments
 (0)