From 7a62ddced946b0ee187afaf697f591378d24a608 Mon Sep 17 00:00:00 2001 From: Eddie Storm <5409647+vctrstrm@users.noreply.github.com> Date: Thu, 9 Dec 2021 21:00:43 -0800 Subject: [PATCH] Fixes es auth & ssl connections --- gamechangerml/src/search/ranking/ltr.py | 50 ++++++++++++------------- 1 file changed, 23 insertions(+), 27 deletions(-) diff --git a/gamechangerml/src/search/ranking/ltr.py b/gamechangerml/src/search/ranking/ltr.py index d4e592b6..60efa190 100644 --- a/gamechangerml/src/search/ranking/ltr.py +++ b/gamechangerml/src/search/ranking/ltr.py @@ -47,6 +47,9 @@ def __init__( @property def client(self) -> Elasticsearch: + if hasattr(self, "_client"): + return getattr(self, "_client") + host_args = dict( hosts=[ { @@ -57,11 +60,7 @@ def client(self) -> Elasticsearch: } ] ) - auth_args = ( - dict(http_auth=(self.user, self.password) - ) if self.enable_auth else {} - ) - + auth_args = dict(http_auth=(self.user, self.password)) if self.enable_auth else {} ssl_args = dict(use_ssl=self.enable_ssl) es_args = dict( @@ -70,14 +69,12 @@ def client(self) -> Elasticsearch: **ssl_args, ) - self._es_args: t.Dict[str, t.Any] = es_args - if hasattr(self, "_es_args"): - return Elasticsearch(**self._es_args) - return Elasticsearch(**self._es_args) + self._es_client = Elasticsearch(**es_args) + return self._es_client @property def auth_headers(self) -> t.Dict[str, str]: - return {"Authorization": f"Basic {self.auth_token}"} + return {"Authorization": f"Basic {self.auth_token}"} if self.enable_auth else {} @property def content_headers(self) -> t.Dict[str, str]: @@ -92,26 +89,25 @@ def default_headers(self) -> t.Dict[str, str]: @property def root_url(self) -> str: - return "https://" if self.enable_ssl else "" + f"http://" + return ("https" if self.enable_ssl else "http") + f"://{self.host}:{self.port}/" - def request(self, method: str, endpoint: str, **request_opts) -> requests.Response: - # url = urljoin(self.root_url, endpoint.lstrip("/")) - url = self.root_url + endpoint + def request(self, method: str, url: str, **request_opts) -> requests.Response: + complete_url = urljoin(self.root_url, url.lstrip("/")) return requests.request( - method=method, url=url, headers=self.default_headers, **request_opts + method=method, url=complete_url, headers=self.default_headers, **request_opts ) - def post(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method="POST", endpoint=endpoint, **request_opts) + def post(self, url: str, **request_opts) -> requests.Response: + return self.request(method="POST", url=url, **request_opts) - def put(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method="PUT", endpoint=endpoint, **request_opts) + def put(self, url: str, **request_opts) -> requests.Response: + return self.request(method="PUT", url=url, **request_opts) - def get(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method="GET", endpoint=endpoint, **request_opts) + def get(self, url: str, **request_opts) -> requests.Response: + return self.request(method="GET", url=url, **request_opts) - def delete(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method="DELETE", endpoint=endpoint, **request_opts) + def delete(self, url: str, **request_opts) -> requests.Response: + return self.request(method="DELETE", url=url, **request_opts) logger = logging.getLogger("gamechanger") @@ -228,7 +224,7 @@ def post_model(self, model, model_name): "model": {"type": "model/xgboost+json", "definition": model}, } } - endpoint = f"{self.esu.host}/_ltr/_featureset/doc_features/_createmodel" + endpoint = "/_ltr/_featureset/doc_features/_createmodel" r = self.esu.post(endpoint, data=json.dumps(query)) return r.content @@ -590,17 +586,17 @@ def post_features(self): ], } } - endpoint = f"{self.esu.host}/_ltr/_featureset/doc_features" + endpoint = "/_ltr/_featureset/doc_features" r = self.esu.post(endpoint, data=json.dumps(query)) return r.content def post_init_ltr(self): - endpoint = f"{self.esu.host}/_ltr" + endpoint = "/_ltr" r = self.esu.put(endpoint) return r.content def delete_ltr(self, model_name="ltr_model"): - endpoint = f"{self.esu.host}/_ltr/_model/{model_name}" + endpoint = "/_ltr/_model/{model_name}" r = self.esu.delete(endpoint) return r.content