Skip to content

Commit 4ab196e

Browse files
Francesco Mucio
authored and committed
rebased on 1.8.1
1 parent 491465f commit 4ab196e

File tree

7 files changed

+236
-78
lines changed

7 files changed

+236
-78
lines changed

dlt/sources/helpers/rest_client/client.py

Lines changed: 13 additions & 2 deletions
Original file line number | Diff line number | Diff line change
@@ -100,6 +100,7 @@ def _create_request(
100100
path_or_url: str,
101101
method: HTTPMethod,
102102
params: Optional[Dict[str, Any]] = None,
103+
headers: Optional[Dict[str, Any]] = None,
103104
json: Optional[Dict[str, Any]] = None,
104105
auth: Optional[AuthBase] = None,
105106
hooks: Optional[Hooks] = None,
@@ -110,10 +111,12 @@ def _create_request(
110111
else:
111112
url = join_url(self.base_url, path_or_url)
112113

114+
request_headers = (self.headers or {}) | (headers or {})
115+
113116
return Request(
114117
method=method,
115118
url=url,
116-
headers=self.headers,
119+
headers=request_headers,
117120
params=params,
118121
json=json,
119122
auth=auth or self.auth,
@@ -144,6 +147,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) ->
144147
path_or_url=path,
145148
method=method,
146149
params=kwargs.pop("params", None),
150+
headers=kwargs.pop("headers", None),
147151
json=kwargs.pop("json", None),
148152
auth=kwargs.pop("auth", None),
149153
hooks=kwargs.pop("hooks", None),
@@ -161,6 +165,7 @@ def paginate(
161165
path: str = "",
162166
method: HTTPMethodBasic = "GET",
163167
params: Optional[Dict[str, Any]] = None,
168+
headers: Optional[Dict[str, Any]] = None,
164169
json: Optional[Dict[str, Any]] = None,
165170
auth: Optional[AuthBase] = None,
166171
paginator: Optional[BasePaginator] = None,
@@ -213,7 +218,13 @@ def paginate(
213218
hooks["response"] = [raise_for_status]
214219

215220
request = self._create_request(
216-
path_or_url=path, method=method, params=params, json=json, auth=auth, hooks=hooks
221+
path_or_url=path,
222+
headers=headers,
223+
method=method,
224+
params=params,
225+
json=json,
226+
auth=auth,
227+
hooks=hooks,
217228
)
218229

219230
if paginator:

dlt/sources/rest_api/__init__.py

Lines changed: 32 additions & 14 deletions
Original file line number | Diff line number | Diff line change
@@ -1,4 +1,5 @@
11
"""Generic API Source"""
2+
23
from copy import deepcopy
34
from typing import Any, Dict, List, Optional, Generator, Callable, cast, Union
45
import graphlib
@@ -70,7 +71,11 @@ def rest_api(
7071
) -> List[DltResource]:
7172
"""Creates and configures a REST API source with default settings"""
7273
return rest_api_resources(
73-
{"client": client, "resources": resources, "resource_defaults": resource_defaults}
74+
{
75+
"client": client,
76+
"resources": resources,
77+
"resource_defaults": resource_defaults,
78+
}
7479
)
7580

7681

@@ -242,6 +247,7 @@ def create_resources(
242247
endpoint_config = cast(Endpoint, endpoint_resource["endpoint"])
243248
request_params = endpoint_config.get("params", {})
244249
request_json = endpoint_config.get("json", None)
250+
request_headers = endpoint_config.get("headers")
245251
paginator = create_paginator(endpoint_config.get("paginator"))
246252
processing_steps = endpoint_resource.pop("processing_steps", [])
247253

@@ -288,6 +294,7 @@ def process(
288294
def paginate_resource(
289295
method: HTTPMethodBasic,
290296
path: str,
297+
headers: Dict[str, Any],
291298
params: Dict[str, Any],
292299
json: Optional[Dict[str, Any]],
293300
paginator: Optional[BasePaginator],
@@ -323,6 +330,7 @@ def paginate_resource(
323330
yield from client.paginate(
324331
method=method,
325332
path=path,
333+
headers=headers,
326334
params=params,
327335
json=json,
328336
paginator=paginator,
@@ -336,6 +344,7 @@ def paginate_resource(
336344
)(
337345
method=endpoint_config.get("method", "get"),
338346
path=endpoint_config.get("path"),
347+
headers=request_headers,
339348
params=request_params,
340349
json=request_json,
341350
paginator=paginator,
@@ -355,6 +364,7 @@ def paginate_dependent_resource(
355364
items: List[Dict[str, Any]],
356365
method: HTTPMethodBasic,
357366
path: str,
367+
request_headers: Optional[Dict[str, Any]],
358368
params: Dict[str, Any],
359369
json: Optional[Dict[str, Any]],
360370
paginator: Optional[BasePaginator],
@@ -378,23 +388,29 @@ def paginate_dependent_resource(
378388
)
379389

380390
for item in items:
381-
formatted_path, expanded_params, updated_json, parent_record = (
382-
process_parent_data_item(
383-
path=path,
384-
item=item,
385-
params=params,
386-
request_json=json,
387-
resolved_params=resolved_params,
388-
include_from_parent=include_from_parent,
389-
incremental=incremental_object,
390-
incremental_value_convert=incremental_cursor_transform,
391-
)
391+
(
392+
formatted_path,
393+
expanded_params,
394+
updated_json,
395+
updated_headers,
396+
parent_record,
397+
) = process_parent_data_item(
398+
path=path,
399+
item=item,
400+
params=params,
401+
request_headers=request_headers,
402+
request_json=json,
403+
resolved_params=resolved_params,
404+
include_from_parent=include_from_parent,
405+
incremental=incremental_object,
406+
incremental_value_convert=incremental_cursor_transform,
392407
)
393408

394409
for child_page in client.paginate(
395410
method=method,
396411
path=formatted_path,
397412
params=expanded_params,
413+
headers=updated_headers,
398414
json=updated_json,
399415
paginator=paginator,
400416
data_selector=data_selector,
@@ -413,6 +429,7 @@ def paginate_dependent_resource(
413429
method=endpoint_config.get("method", "get"),
414430
path=endpoint_config.get("path"),
415431
params=base_params,
432+
request_headers=request_headers,
416433
json=request_json,
417434
paginator=paginator,
418435
data_selector=endpoint_config.get("data_selector"),
@@ -456,7 +473,8 @@ def _mask_secrets(auth_config: AuthConfig) -> AuthConfig:
456473
has_sensitive_key = any(key in auth_config for key in SENSITIVE_KEYS)
457474
if (
458475
isinstance(
459-
auth_config, (APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials)
476+
auth_config,
477+
(APIKeyAuth, BearerTokenAuth, HttpBasicAuth, OAuth2ClientCredentials),
460478
)
461479
or has_sensitive_key
462480
):
@@ -503,7 +521,7 @@ def identity_func(x: Any) -> Any:
503521

504522

505523
def _validate_param_type(
506-
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]
524+
request_params: Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]],
507525
) -> None:
508526
for _, value in request_params.items():
509527
if isinstance(value, dict) and value.get("type") not in PARAM_TYPES:

dlt/sources/rest_api/config_setup.py

Lines changed: 24 additions & 6 deletions
Original file line number | Diff line number | Diff line change
@@ -264,7 +264,7 @@ def setup_incremental_object(
264264

265265

266266
def parse_convert_or_deprecated_transform(
267-
config: Union[IncrementalConfig, Dict[str, Any]]
267+
config: Union[IncrementalConfig, Dict[str, Any]],
268268
) -> Optional[Callable[..., Any]]:
269269
convert = config.get("convert", None)
270270
deprecated_transform = config.get("transform", None)
@@ -317,15 +317,20 @@ def build_resource_dependency_graph(
317317
endpoint_resource["endpoint"]["path"], available_contexts
318318
)
319319

320-
# Find all expressions in params and json, but error if any of them is not in available_contexts
320+
# Find all expressions in params, json, and headers; raise an error if any of them is not in available_contexts
321321
params_expressions = _find_expressions(endpoint_resource["endpoint"].get("params", {}))
322322
_raise_if_any_not_in(params_expressions, available_contexts, message="params")
323323

324324
json_expressions = _find_expressions(endpoint_resource["endpoint"].get("json", {}))
325325
_raise_if_any_not_in(json_expressions, available_contexts, message="json")
326326

327+
headers_expressions = _find_expressions(endpoint_resource["endpoint"].get("headers", {}))
328+
_raise_if_any_not_in(headers_expressions, available_contexts, message="headers")
329+
327330
resolved_params += _expressions_to_resolved_params(
328-
_filter_resource_expressions(path_expressions | params_expressions | json_expressions)
331+
_filter_resource_expressions(
332+
path_expressions | params_expressions | json_expressions | headers_expressions
333+
)
329334
)
330335

331336
# set of resources in resolved params
@@ -723,11 +728,12 @@ def process_parent_data_item(
723728
item: Dict[str, Any],
724729
resolved_params: List[ResolvedParam],
725730
params: Optional[Dict[str, Any]] = None,
731+
request_headers: Optional[Dict[str, Any]] = None,
726732
request_json: Optional[Dict[str, Any]] = None,
727733
include_from_parent: Optional[List[str]] = None,
728734
incremental: Optional[Incremental[Any]] = None,
729735
incremental_value_convert: Optional[Callable[..., Any]] = None,
730-
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
736+
) -> Tuple[str, Dict[str, Any], Dict[str, Any], Dict[str, Any], Dict[str, Any]]:
731737
params_values = collect_resolved_values(
732738
item, resolved_params, incremental, incremental_value_convert
733739
)
@@ -737,10 +743,20 @@ def process_parent_data_item(
737743
None if request_json is None else expand_placeholders(request_json, params_values)
738744
)
739745

746+
expanded_headers = (
747+
None if request_headers is None else expand_placeholders(request_headers, params_values)
748+
)
749+
740750
parent_resource_name = resolved_params[0].resolve_config["resource"]
741751
parent_record = build_parent_record(item, parent_resource_name, include_from_parent)
742752

743-
return expanded_path, expanded_params, expanded_json, parent_record
753+
return (
754+
expanded_path,
755+
expanded_params,
756+
expanded_json,
757+
expanded_headers,
758+
parent_record,
759+
)
744760

745761

746762
def convert_incremental_values(
@@ -819,7 +835,9 @@ def expand_placeholders(obj: Any, placeholders: Dict[str, Any]) -> Any:
819835

820836

821837
def build_parent_record(
822-
item: Dict[str, Any], parent_resource_name: str, include_from_parent: Optional[List[str]]
838+
item: Dict[str, Any],
839+
parent_resource_name: str,
840+
include_from_parent: Optional[List[str]],
823841
) -> Dict[str, Any]:
824842
"""
825843
Builds a dictionary of the `include_from_parent` fields from the parent,

dlt/sources/rest_api/typing.py

Lines changed: 1 addition & 0 deletions
Original file line number | Diff line number | Diff line change
@@ -249,6 +249,7 @@ class Endpoint(TypedDict, total=False):
249249
response_actions: Optional[List[ResponseAction]]
250250
incremental: Optional[IncrementalConfig]
251251
auth: Optional[AuthConfig]
252+
headers: Optional[Dict[str, Any]]
252253

253254

254255
class ProcessingSteps(TypedDict):

0 commit comments

Comments
 (0)