Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

FEAT: parameterised headers rest_api_source #2084

Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 8 additions & 1 deletion dlt/common/libs/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -628,7 +628,14 @@ def row_tuples_to_arrow(
" extracting an SQL VIEW that selects with cast."
)
json_str_array = pa.array(
[None if s is None else json.dumps(s) if not issubclass(type(s), set) else json.dumps(list(s)) for s in columnar_known_types[field.name]]
[
(
None
if s is None
else json.dumps(s) if not issubclass(type(s), set) else json.dumps(list(s))
)
for s in columnar_known_types[field.name]
]
)
columnar_known_types[field.name] = json_str_array

Expand Down
19 changes: 17 additions & 2 deletions dlt/sources/helpers/rest_client/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -100,6 +100,7 @@ def _create_request(
path_or_url: str,
method: HTTPMethod,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
hooks: Optional[Hooks] = None,
Expand All @@ -110,10 +111,14 @@ def _create_request(
else:
url = join_url(self.base_url, path_or_url)

headers = headers or {}
if self.headers:
headers.update(self.headers)

return Request(
method=method,
url=url,
headers=self.headers,
headers=headers,
params=params,
json=json,
auth=auth or self.auth,
Expand All @@ -124,6 +129,7 @@ def _send_request(self, request: Request, **kwargs: Any) -> Response:
logger.info(
f"Making {request.method.upper()} request to {request.url}"
f" with params={request.params}, json={request.json}"
f" with headers={request.headers}"
)

prepared_request = self.session.prepare_request(request)
Expand All @@ -143,6 +149,7 @@ def request(self, path: str = "", method: HTTPMethod = "GET", **kwargs: Any) ->
prepared_request = self._create_request(
path_or_url=path,
method=method,
headers=kwargs.pop("headers", None),
params=kwargs.pop("params", None),
json=kwargs.pop("json", None),
auth=kwargs.pop("auth", None),
Expand All @@ -161,6 +168,7 @@ def paginate(
path: str = "",
method: HTTPMethodBasic = "GET",
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
auth: Optional[AuthBase] = None,
paginator: Optional[BasePaginator] = None,
Expand All @@ -176,6 +184,7 @@ def paginate(
be used instead of the base_url + path.
method (HTTPMethodBasic): HTTP method for the request, defaults to 'get'.
params (Optional[Dict[str, Any]]): URL parameters for the request.
headers (Optional[Dict[str, Any]]): Headers for the request.
json (Optional[Dict[str, Any]]): JSON payload for the request.
auth (Optional[AuthBase]): Authentication configuration for the request.
paginator (Optional[BasePaginator]): Paginator instance for handling
Expand Down Expand Up @@ -213,7 +222,13 @@ def paginate(
hooks["response"] = [raise_for_status]

request = self._create_request(
path_or_url=path, method=method, params=params, json=json, auth=auth, hooks=hooks
path_or_url=path,
headers=headers,
method=method,
params=params,
json=json,
auth=auth,
hooks=hooks,
)

if paginator:
Expand Down
10 changes: 8 additions & 2 deletions dlt/sources/rest_api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -291,6 +291,7 @@ def process(
def paginate_resource(
method: HTTPMethodBasic,
path: str,
headers: Dict[str, Any],
params: Dict[str, Any],
json: Optional[Dict[str, Any]],
paginator: Optional[BasePaginator],
Expand All @@ -313,6 +314,7 @@ def paginate_resource(

yield from client.paginate(
method=method,
headers=headers,
path=path,
params=params,
json=json,
Expand All @@ -327,6 +329,7 @@ def paginate_resource(
)(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
headers=endpoint_config.get("headers"),
params=request_params,
json=request_json,
paginator=paginator,
Expand All @@ -346,6 +349,7 @@ def paginate_dependent_resource(
items: List[Dict[str, Any]],
method: HTTPMethodBasic,
path: str,
headers: Dict[str, Any],
params: Dict[str, Any],
paginator: Optional[BasePaginator],
data_selector: Optional[jsonpath.TJsonPath],
Expand All @@ -368,12 +372,13 @@ def paginate_dependent_resource(
)

for item in items:
formatted_path, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent
formatted_path, formatted_headers, parent_record = process_parent_data_item(
path, item, resolved_params, include_from_parent, headers
)

for child_page in client.paginate(
method=method,
headers=formatted_headers,
path=formatted_path,
params=params,
paginator=paginator,
Expand All @@ -392,6 +397,7 @@ def paginate_dependent_resource(
)(
method=endpoint_config.get("method", "get"),
path=endpoint_config.get("path"),
headers=endpoint_config.get("headers"),
params=base_params,
paginator=paginator,
data_selector=endpoint_config.get("data_selector"),
Expand Down
120 changes: 91 additions & 29 deletions dlt/sources/rest_api/config_setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,6 +64,7 @@
ResponseActionDict,
Endpoint,
EndpointResource,
ResolveParamLocation,
)


Expand Down Expand Up @@ -330,6 +331,7 @@ def expand_and_index_resources(
assert isinstance(endpoint_resource["endpoint"], dict)
_setup_single_entity_endpoint(endpoint_resource["endpoint"])
_bind_path_params(endpoint_resource)
_bind_header_params(endpoint_resource)

resource_name = endpoint_resource["name"]
assert isinstance(
Expand Down Expand Up @@ -375,50 +377,84 @@ def _make_endpoint_resource(
return _merge_resource_endpoints(default_config, resource)


def _bind_path_params(resource: EndpointResource) -> None:
"""Binds params declared in path to params available in `params`. Pops the
bound params, but params of type `resolve` and `incremental` are skipped
and bound later.
"""
path_params: Dict[str, Any] = {}
assert isinstance(resource["endpoint"], dict) # type guard
resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"])]
path = resource["endpoint"]["path"]
for format_ in string.Formatter().parse(path):
def _bind_string(
target: str,
resource: EndpointResource,
resolve_params: List[str],
location_params: Dict[str, Any],
location: str,
) -> None:
for format_ in string.Formatter().parse(target):
name = format_[1]
if name:
params = resource["endpoint"].get("params", {})
if name not in params and name not in path_params:
params = resource["endpoint"].get("params", {}) # type: ignore[union-attr]
if name not in params and name not in location_params:
raise ValueError(
f"The path {path} defined in resource {resource['name']} requires param with"
f" name {name} but it is not found in {params}"
f"The {location} {target} defined in resource {resource['name']} requires param"
f" with name {name} but it is not found in {params}"
)
if name in resolve_params:
resolve_params.remove(name)
if name in params:
if not isinstance(params[name], dict):
# bind resolved param and pop it from endpoint
path_params[name] = params.pop(name)
location_params[name] = params.pop(name)
else:
param_type = params[name].get("type")
if param_type != "resolve":
raise ValueError(
f"The path {path} defined in resource {resource['name']} tries to bind"
f" param {name} with type {param_type}. Paths can only bind 'resolve'"
" type params."
f"The {location} {target} defined in resource {resource['name']} tries"
f" to bind param {name} with type {param_type}. {location} can only"
" bind 'resolve' type params."
)
# resolved params are bound later
path_params[name] = "{" + name + "}"
location_params[name] = "{" + name + "}"


def _bind_path_params(resource: EndpointResource) -> None:
"""Binds params declared in path to params available in `params`. Pops the
bound params, but params of type `resolve` and `incremental` are skipped
and bound later.
"""
path_params: Dict[str, Any] = {}
assert isinstance(resource["endpoint"], dict) # type guard
resolve_params = [r.param_name for r in _find_resolved_params(resource["endpoint"], "path")]
path = resource["endpoint"]["path"]
_bind_string(str(path), resource, resolve_params, path_params, "path")

if len(resolve_params) > 0:
raise NotImplementedError(
f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
f" bound in path {path}. Resolve query params not supported yet."
f" bound in path {path}."
)

resource["endpoint"]["path"] = path.format(**path_params)


def _bind_header_params(resource: EndpointResource) -> None:
    """Bind placeholders used in the endpoint's headers to values from `params`.

    Every header name and value is scanned for `{name}` placeholders via
    `_bind_string`. Plain (non-dict) params are popped from the endpoint's
    `params` and substituted right away; params of type `resolve` are kept as
    `{name}` placeholders so they can be filled in later. The formatted result
    replaces `resource["endpoint"]["headers"]` (an empty dict when no headers
    are configured).

    Args:
        resource: Endpoint resource whose `endpoint` entry is a dict.

    Raises:
        ValueError: Propagated from `_bind_string` when a placeholder has no
            matching param or refers to a param of an unsupported type.
        NotImplementedError: If a `resolve` param with location `header` is
            not referenced by any header name or value.
    """
    endpoint = resource["endpoint"]
    assert isinstance(endpoint, dict)  # type guard

    header_params: Dict[str, Any] = {}
    resolve_params = [r.param_name for r in _find_resolved_params(endpoint, "header")]

    # `headers` may be missing or explicitly set to None; both mean "no headers".
    headers = endpoint.get("headers") or {}

    formatted_headers: Dict[str, str] = {}
    for header_name, header_value in headers.items():
        # Names and values are not guaranteed to be strings (e.g. an int value),
        # so coerce once and use the same text for both parsing and formatting.
        name_template = str(header_name)
        value_template = str(header_value)
        _bind_string(name_template, resource, resolve_params, header_params, "header")
        _bind_string(value_template, resource, resolve_params, header_params, "header")
        formatted_headers[name_template.format(**header_params)] = value_template.format(
            **header_params
        )

    # `_bind_string` removes every resolve param it encounters; leftovers were
    # declared for headers but never referenced by any of them.
    if resolve_params:
        raise NotImplementedError(
            f"Resource {resource['name']} defines resolve params {resolve_params} that are not"
            " bound in headers."
        )

    endpoint["headers"] = formatted_headers


def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint:
"""Tries to guess if the endpoint refers to a single entity and when detected:
* if `data_selector` was not specified (or is None), "$" is selected
Expand All @@ -435,18 +471,26 @@ def _setup_single_entity_endpoint(endpoint: Endpoint) -> Endpoint:
return endpoint


def _find_resolved_params(endpoint_config: Endpoint) -> List[ResolvedParam]:
def _find_resolved_params(
endpoint_config: Endpoint, location: Optional[ResolveParamLocation] = None
) -> List[ResolvedParam]:
"""
Find all resolved params in the endpoint configuration and return
a list of ResolvedParam objects.

Param:
location: Optional[ResolveParamLocation] = None - filter resolved params by location if provided.

Resolved params are of type ResolveParamConfig (bound param with a key "type" set to "resolve".)
"""
return [
resolved_params = [
ResolvedParam(key, value) # type: ignore[arg-type]
for key, value in endpoint_config.get("params", {}).items()
if (isinstance(value, dict) and value.get("type") == "resolve")
if isinstance(value, dict) and value.get("type") == "resolve"
]
if location is None:
return resolved_params
return list(filter(lambda rp: rp.resolve_config.get("location") == location, resolved_params))


def _action_type_unless_custom_hook(
Expand Down Expand Up @@ -574,10 +618,14 @@ def process_parent_data_item(
item: Dict[str, Any],
resolved_params: List[ResolvedParam],
include_from_parent: List[str],
) -> Tuple[str, Dict[str, Any]]:
headers: Optional[Dict[str, Any]] = None,
) -> Tuple[str, Dict[str, Any], Dict[str, Any]]:
parent_resource_name = resolved_params[0].resolve_config["resource"]

param_values = {}
param_values: Dict[str, Dict[str, str]] = {
"path": {},
"header": {},
}

for resolved_param in resolved_params:
field_values = jsonpath.find_values(resolved_param.field_path, item)
Expand All @@ -591,9 +639,15 @@ def process_parent_data_item(
f" {', '.join(item.keys())}"
)

param_values[resolved_param.param_name] = field_values[0]
location = resolved_param.resolve_config.get("location")
if location == "path":
param_values["path"][resolved_param.param_name] = field_values[0]
elif location == "header":
param_values["header"][resolved_param.param_name] = field_values[0]
else:
param_values["path"][resolved_param.param_name] = field_values[0]

bound_path = path.format(**param_values)
bound_path = path.format(**param_values["path"])

parent_record: Dict[str, Any] = {}
if include_from_parent:
Expand All @@ -607,7 +661,15 @@ def process_parent_data_item(
)
parent_record[child_key] = item[parent_key]

return bound_path, parent_record
if headers is not None:
formatted_headers = {
k.format(**param_values["header"]) if isinstance(k, str) else str(k): (
v.format(**param_values["header"]) if isinstance(v, str) else str(v)
)
for k, v in headers.items()
}
return bound_path, formatted_headers, parent_record
return bound_path, {}, parent_record


def _merge_resource_endpoints(
Expand Down
9 changes: 8 additions & 1 deletion dlt/sources/rest_api/typing.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,7 @@
Optional,
Union,
)
from enum import Enum

from dlt.common import jsonpath
from dlt.common.schema.typing import (
Expand Down Expand Up @@ -224,9 +225,13 @@ class ParamBindConfig(TypedDict):
type: ParamBindType # noqa


class ResolveParamConfig(ParamBindConfig):
ResolveParamLocation = Literal["path", "header"]


class ResolveParamConfig(ParamBindConfig, total=False):
# Configuration of a `resolve` param, i.e. a param whose value is taken from
# an item of another (parent) resource.
# `total=False` makes the keys declared here optional for type checkers.
# NOTE(review): `ResolvedParam.__post_init__` indexes `resolve_config["field"]`
# directly, so `field` is effectively required at runtime — confirm that the
# relaxed totality is intended for `resource` and `field` too.
# Name of the parent resource the value is resolved from.
resource: str
# JSONPath into the parent item; compiled with `jsonpath.compile_path`.
field: str
# Where the resolved value is substituted: "path" or "header".
# `ResolvedParam.__post_init__` falls back to "path" when the key is absent.
location: Optional[ResolveParamLocation]


class IncrementalParamConfig(ParamBindConfig, IncrementalRESTArgs):
Expand All @@ -243,6 +248,7 @@ class ResolvedParam:

def __post_init__(self) -> None:
self.field_path = jsonpath.compile_path(self.resolve_config["field"])
self.resolve_config["location"] = self.resolve_config.get("location", "path")


class ResponseActionDict(TypedDict, total=False):
Expand All @@ -259,6 +265,7 @@ class Endpoint(TypedDict, total=False):
method: Optional[HTTPMethodBasic]
params: Optional[Dict[str, Union[ResolveParamConfig, IncrementalParamConfig, Any]]]
json: Optional[Dict[str, Any]]
headers: Optional[Dict[str, Any]]
paginator: Optional[PaginatorConfig]
data_selector: Optional[jsonpath.TJsonPath]
response_actions: Optional[List[ResponseAction]]
Expand Down
Loading