Skip to content

Commit 273c9b8

Browse files
qiaodevcopybara-github
authored andcommitted
feat: support extra_body in HttpOptions
PiperOrigin-RevId: 772616357
1 parent ab8da34 commit 273c9b8

File tree

6 files changed

+363
-1
lines changed

6 files changed

+363
-1
lines changed

google/genai/_api_client.py

Lines changed: 8 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -741,6 +741,14 @@ def _build_request(
741741
else:
742742
base_url = patched_http_options.base_url
743743

744+
if (
745+
hasattr(patched_http_options, 'extra_body')
746+
and patched_http_options.extra_body
747+
):
748+
_common.recursive_dict_update(
749+
request_dict, patched_http_options.extra_body
750+
)
751+
744752
url = _join_url_path(
745753
base_url,
746754
versioned_path,

google/genai/_common.py

Lines changed: 73 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -19,6 +19,7 @@
1919
import datetime
2020
import enum
2121
import functools
22+
import logging
2223
import typing
2324
from typing import Any, Callable, Optional, Union, get_origin, get_args
2425
import uuid
@@ -30,6 +31,7 @@
3031
from . import _api_client
3132
from . import errors
3233

34+
logger = logging.getLogger('google_genai._common')
3335

3436
def set_value_by_path(data: Optional[dict[Any, Any]], keys: list[str], value: Any) -> None:
3537
"""Examples:
@@ -365,3 +367,74 @@ def wrapper(*args: Any, **kwargs: Any) -> Any:
365367
return wrapper
366368
return decorator
367369

370+
371+
def _normalize_key_for_matching(key_str: str) -> str:
372+
"""Normalizes a key for case-insensitive and snake/camel matching."""
373+
return key_str.replace("_", "").lower()
374+
375+
376+
def align_key_case(target_dict: dict[str, Any], update_dict: dict[str, Any]) -> dict[str, Any]:
377+
"""Aligns the keys of update_dict to the case of target_dict keys.
378+
379+
Args:
380+
target_dict: The dictionary with the target key casing.
381+
update_dict: The dictionary whose keys need to be aligned.
382+
383+
Returns:
384+
A new dictionary with keys aligned to target_dict's key casing.
385+
"""
386+
aligned_update_dict: dict[str, Any] = {}
387+
target_keys_map = {_normalize_key_for_matching(key): key for key in target_dict.keys()}
388+
389+
for key, value in update_dict.items():
390+
normalized_update_key = _normalize_key_for_matching(key)
391+
392+
if normalized_update_key in target_keys_map:
393+
aligned_key = target_keys_map[normalized_update_key]
394+
else:
395+
aligned_key = key
396+
397+
if isinstance(value, dict) and isinstance(target_dict.get(aligned_key), dict):
398+
aligned_update_dict[aligned_key] = align_key_case(target_dict[aligned_key], value)
399+
elif isinstance(value, list) and isinstance(target_dict.get(aligned_key), list):
400+
# Direct assign as we treat update_dict list values as golden source.
401+
aligned_update_dict[aligned_key] = value
402+
else:
403+
aligned_update_dict[aligned_key] = value
404+
return aligned_update_dict
405+
406+
407+
def recursive_dict_update(
408+
target_dict: dict[str, Any], update_dict: dict[str, Any]
409+
) -> None:
410+
"""Recursively updates a target dictionary with values from an update dictionary.
411+
412+
We don't enforce the updated dict values to have the same type with the
413+
target_dict values except log warnings.
414+
Users providing the update_dict should be responsible for constructing correct
415+
data.
416+
417+
Args:
418+
target_dict (dict): The dictionary to be updated.
419+
update_dict (dict): The dictionary containing updates.
420+
"""
421+
# Python SDK http request may change in camel case or snake case:
422+
# If the field is directly set via setv() function, then it is camel case;
423+
# otherwise it is snake case.
424+
# Align the update_dict key case to target_dict to ensure correct dict update.
425+
aligned_update_dict = align_key_case(target_dict, update_dict)
426+
for key, value in aligned_update_dict.items():
427+
if (
428+
key in target_dict
429+
and isinstance(target_dict[key], dict)
430+
and isinstance(value, dict)
431+
):
432+
recursive_dict_update(target_dict[key], value)
433+
elif key in target_dict and not isinstance(target_dict[key], type(value)):
434+
logger.warning(
435+
f"Type mismatch for key '{key}'. Existing type:"
436+
f' {type(target_dict[key])}, new type: {type(value)}. Overwriting.'
437+
)
438+
target_dict[key] = value
439+
else:
440+
target_dict[key] = value

google/genai/tests/client/test_http_options.py

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,7 @@ def test_patch_http_options_with_copies_all_fields():
2929
timeout=10000,
3030
client_args={'http2': True},
3131
async_client_args={'http1': True},
32+
extra_body={'key': 'value'},
3233
)
3334
options = types.HttpOptions()
3435
patched = _api_client._patch_http_options(options, patch_options)

google/genai/tests/common/test_common.py

Lines changed: 246 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -18,6 +18,7 @@
1818

1919
import warnings
2020
import inspect
21+
import logging
2122
import typing
2223

2324
import pytest
@@ -60,4 +61,248 @@ def test_is_struct_type():
6061
assert not _common._is_struct_type(typing.List[typing.Dict[str, int]])
6162
assert not _common._is_struct_type(typing.List[typing.Dict[int, typing.Any]])
6263
assert not _common._is_struct_type(typing.List[str])
63-
assert not _common._is_struct_type(typing.Dict[str, typing.Any])
64+
assert not _common._is_struct_type(typing.Dict[str, typing.Any])
65+
66+
67+
68+
@pytest.mark.parametrize(
69+
"test_id, initial_target, update_dict, expected_target",
70+
[
71+
(
72+
"simple_update",
73+
{"a": 1, "b": 2},
74+
{"b": 3, "c": 4},
75+
{"a": 1, "b": 3, "c": 4},
76+
),
77+
(
78+
"nested_update",
79+
{"a": 1, "b": {"x": 10, "y": 20}},
80+
{"b": {"y": 30, "z": 40}, "c": 3},
81+
{"a": 1, "b": {"x": 10, "y": 30, "z": 40}, "c": 3},
82+
),
83+
(
84+
"add_new_nested_dict",
85+
{"a": 1},
86+
{"b": {"x": 10, "y": 20}},
87+
{"a": 1, "b": {"x": 10, "y": 20}},
88+
),
89+
(
90+
"empty_target",
91+
{},
92+
{"a": 1, "b": {"x": 10}},
93+
{"a": 1, "b": {"x": 10}},
94+
),
95+
(
96+
"empty_update",
97+
{"a": 1, "b": {"x": 10}},
98+
{},
99+
{"a": 1, "b": {"x": 10}},
100+
),
101+
(
102+
"overwrite_non_dict_with_dict",
103+
{"a": 1, "b": 2},
104+
{"b": {"x": 10}},
105+
{"a": 1, "b": {"x": 10}},
106+
),
107+
(
108+
"overwrite_dict_with_non_dict",
109+
{"a": 1, "b": {"x": 10}},
110+
{"b": 2},
111+
{"a": 1, "b": 2},
112+
),
113+
(
114+
"deeper_nesting",
115+
{"a": {"b": {"c": 1, "d": 2}, "e": 3}},
116+
{"a": {"b": {"d": 4, "f": 5}, "g": 6}, "h": 7},
117+
{"a": {"b": {"c": 1, "d": 4, "f": 5}, "e": 3, "g": 6}, "h": 7},
118+
),
119+
(
120+
"different_value_types",
121+
{"key1": "string_val", "key2": {"nested_int": 100}},
122+
{"key1": 123, "key2": {"nested_list": [1, 2, 3]}, "key3": True},
123+
{
124+
"key1": 123,
125+
"key2": {"nested_int": 100, "nested_list": [1, 2, 3]},
126+
"key3": True,
127+
},
128+
),
129+
(
130+
"update_with_empty_nested_dict", # Existing nested dict in target should not be cleared
131+
{"a": {"b": 1}},
132+
{"a": {}},
133+
{"a": {"b": 1}},
134+
),
135+
(
136+
"target_with_empty_nested_dict",
137+
{"a": {}},
138+
{"a": {"b": 1}},
139+
{"a": {"b": 1}},
140+
),
141+
(
142+
"key_case_alignment_check",
143+
{"first_name": "John", "contact_info": {"email_address": "[email protected]"}},
144+
{"firstName": "Jane", "contact_info": {"email_address": "[email protected]", "phone_number": "123"}},
145+
{"first_name": "Jane", "contact_info": {"email_address": "[email protected]", "phone_number": "123"}},
146+
)
147+
],
148+
)
149+
def test_recursive_dict_update(
150+
test_id: str, initial_target: dict, update_dict: dict, expected_target: dict
151+
):
152+
_common.recursive_dict_update(initial_target, update_dict)
153+
assert initial_target == expected_target
154+
155+
156+
@pytest.mark.parametrize(
157+
"test_id, initial_target, update_dict, expected_target, expect_warning, expected_log_message_part",
158+
[
159+
(
160+
"type_match_int",
161+
{"a": 1},
162+
{"a": 2},
163+
{"a": 2},
164+
False,
165+
"",
166+
),
167+
(
168+
"type_match_dict",
169+
{"a": {"b": 1}},
170+
{"a": {"b": 2}},
171+
{"a": {"b": 2}},
172+
False,
173+
"",
174+
),
175+
(
176+
"type_mismatch_int_to_str",
177+
{"a": 1},
178+
{"a": "hello"},
179+
{"a": "hello"},
180+
True,
181+
"Type mismatch for key 'a'. Existing type: <class 'int'>, new type: <class 'str'>. Overwriting.",
182+
),
183+
(
184+
"type_mismatch_dict_to_int",
185+
{"a": {"b": 1}},
186+
{"a": 100},
187+
{"a": 100},
188+
True,
189+
"Type mismatch for key 'a'. Existing type: <class 'dict'>, new type: <class 'int'>. Overwriting.",
190+
),
191+
(
192+
"type_mismatch_int_to_dict",
193+
{"a": 100},
194+
{"a": {"b": 1}},
195+
{"a": {"b": 1}},
196+
True,
197+
"Type mismatch for key 'a'. Existing type: <class 'int'>, new type: <class 'dict'>. Overwriting.",
198+
),
199+
("add_new_key", {"a": 1}, {"b": "new"}, {"a": 1, "b": "new"}, False, ""),
200+
],
201+
)
202+
def test_recursive_dict_update_type_warnings(test_id, initial_target, update_dict, expected_target, expect_warning, expected_log_message_part, caplog):
203+
_common.recursive_dict_update(initial_target, update_dict)
204+
assert initial_target == expected_target
205+
if expect_warning:
206+
assert len(caplog.records) == 1
207+
assert caplog.records[0].levelname == "WARNING"
208+
assert expected_log_message_part in caplog.records[0].message
209+
else:
210+
for record in caplog.records:
211+
if record.levelname == "WARNING" and expected_log_message_part in record.message:
212+
pytest.fail(f"Unexpected warning logged for {test_id}: {record.message}")
213+
214+
215+
@pytest.mark.parametrize(
216+
"test_id, target_dict, update_dict, expected_aligned_dict",
217+
[
218+
(
219+
"simple_snake_to_camel",
220+
{"first_name": "John", "last_name": "Doe"},
221+
{"firstName": "Jane", "lastName": "Doe"},
222+
{"first_name": "Jane", "last_name": "Doe"},
223+
),
224+
(
225+
"simple_camel_to_snake",
226+
{"firstName": "John", "lastName": "Doe"},
227+
{"first_name": "Jane", "last_name": "Doe"},
228+
{"firstName": "Jane", "lastName": "Doe"},
229+
),
230+
(
231+
"nested_dict_alignment",
232+
{"user_info": {"contact_details": {"email_address": ""}}},
233+
{"userInfo": {"contactDetails": {"emailAddress": "[email protected]"}}},
234+
{"user_info": {"contact_details": {"email_address": "[email protected]"}}},
235+
),
236+
(
237+
"list_of_dicts_alignment",
238+
{"users_list": [{"user_id": 0, "user_name": ""}]},
239+
{"usersList": [{"userId": 1, "userName": "Alice"}]},
240+
{"users_list": [{"userId": 1, "userName": "Alice"}]},
241+
),
242+
(
243+
"list_of_dicts_alignment_mixed_case_in_update",
244+
{"users_list": [{"user_id": 0, "user_name": ""}]},
245+
{"usersList": [{"user_id": 1, "UserName": "Alice"}]},
246+
{"users_list": [{"user_id": 1, "UserName": "Alice"}]},
247+
),
248+
(
249+
"list_of_dicts_different_lengths_update_longer",
250+
{"items_data": [{"item_id": 0}]},
251+
{"itemsData": [{"itemId": 1}, {"item_id": 2, "itemName": "Extra"}]},
252+
{"items_data": [{"itemId": 1}, {"item_id": 2, "itemName": "Extra"}]},
253+
),
254+
(
255+
"list_of_dicts_different_lengths_target_longer",
256+
{"items_data": [{"item_id": 0, "item_name": ""}, {"item_id": 1}]},
257+
{"itemsData": [{"itemId": 10}]},
258+
{"items_data": [{"itemId": 10}]},
259+
),
260+
(
261+
"no_matching_keys_preserves_update_case",
262+
{"key_one": 1},
263+
{"KEY_TWO": 2, "keyThree": 3},
264+
{"KEY_TWO": 2, "keyThree": 3},
265+
),
266+
(
267+
"mixed_match_and_no_match",
268+
{"first_name": "John", "age_years": 30},
269+
{"firstName": "Jane", "AGE_YEARS": 28, "occupation_title": "Engineer"},
270+
{"first_name": "Jane", "age_years": 28, "occupation_title": "Engineer"},
271+
),
272+
(
273+
"empty_target_dict",
274+
{},
275+
{"new_key": "new_value", "anotherKey": "anotherValue"},
276+
{"new_key": "new_value", "anotherKey": "anotherValue"},
277+
),
278+
(
279+
"empty_update_dict",
280+
{"existing_key": "value"},
281+
{},
282+
{},
283+
),
284+
(
285+
"target_has_non_dict_value_for_nested_key",
286+
{"config_settings": 123},
287+
{"configSettings": {"themeName": "dark"}},
288+
{"config_settings": {"themeName": "dark"}}, # Overwrites as per recursive_dict_update logic
289+
),
290+
(
291+
"update_has_non_dict_value_for_nested_key",
292+
{"config_settings": {"theme_name": "light"}},
293+
{"configSettings": "dark_theme_string"},
294+
{"config_settings": "dark_theme_string"}, # Overwrites
295+
),
296+
(
297+
"deeply_nested_with_lists",
298+
{"level_one": {"list_items": [{"item_name": "", "item_value": 0}]}},
299+
{"levelOne": {"listItems": [{"itemName": "Test", "itemValue": 100}, {"itemName": "Test2", "itemValue": 200}]}},
300+
{"level_one": {"list_items": [{"itemName": "Test", "itemValue": 100}, {"itemName": "Test2", "itemValue": 200}]}},
301+
),
302+
],
303+
)
304+
def test_align_key_case(
305+
test_id: str, target_dict: dict, update_dict: dict, expected_aligned_dict: dict
306+
):
307+
aligned_dict = _common.align_key_case(target_dict, update_dict)
308+
assert aligned_dict == expected_aligned_dict, f"Test failed for: {test_id}"

0 commit comments

Comments
 (0)