Skip to content

Commit

Permalink
Fixes es auth & ssl connections
Browse files Browse the repository at this point in the history
  • Loading branch information
vctrstrm committed Dec 10, 2021
1 parent b665c8f commit 7a62ddc
Showing 1 changed file with 23 additions and 27 deletions.
50 changes: 23 additions & 27 deletions gamechangerml/src/search/ranking/ltr.py
Original file line number Diff line number Diff line change
Expand Up @@ -47,6 +47,9 @@ def __init__(

@property
def client(self) -> Elasticsearch:
if hasattr(self, "_client"):
return getattr(self, "_client")

host_args = dict(
hosts=[
{
Expand All @@ -57,11 +60,7 @@ def client(self) -> Elasticsearch:
}
]
)
auth_args = (
dict(http_auth=(self.user, self.password)
) if self.enable_auth else {}
)

auth_args = dict(http_auth=(self.user, self.password)) if self.enable_auth else {}
ssl_args = dict(use_ssl=self.enable_ssl)

es_args = dict(
Expand All @@ -70,14 +69,12 @@ def client(self) -> Elasticsearch:
**ssl_args,
)

self._es_args: t.Dict[str, t.Any] = es_args
if hasattr(self, "_es_args"):
return Elasticsearch(**self._es_args)
return Elasticsearch(**self._es_args)
self._es_client = Elasticsearch(**es_args)
return self._es_client

@property
def auth_headers(self) -> t.Dict[str, str]:
return {"Authorization": f"Basic {self.auth_token}"}
return {"Authorization": f"Basic {self.auth_token}"} if self.enable_auth else {}

@property
def content_headers(self) -> t.Dict[str, str]:
Expand All @@ -92,26 +89,25 @@ def default_headers(self) -> t.Dict[str, str]:

@property
def root_url(self) -> str:
return "https://" if self.enable_ssl else "" + f"http://"
return ("https" if self.enable_ssl else "http") + f"://{self.host}:{self.port}/"

def request(self, method: str, endpoint: str, **request_opts) -> requests.Response:
# url = urljoin(self.root_url, endpoint.lstrip("/"))
url = self.root_url + endpoint
def request(self, method: str, url: str, **request_opts) -> requests.Response:
complete_url = urljoin(self.root_url, url.lstrip("/"))
return requests.request(
method=method, url=url, headers=self.default_headers, **request_opts
method=method, url=complete_url, headers=self.default_headers, **request_opts
)

def post(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method="POST", endpoint=endpoint, **request_opts)
def post(self, url: str, **request_opts) -> requests.Response:
return self.request(method="POST", url=url, **request_opts)

def put(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method="PUT", endpoint=endpoint, **request_opts)
def put(self, url: str, **request_opts) -> requests.Response:
return self.request(method="PUT", url=url, **request_opts)

def get(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method="GET", endpoint=endpoint, **request_opts)
def get(self, url: str, **request_opts) -> requests.Response:
return self.request(method="GET", url=url, **request_opts)

def delete(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method="DELETE", endpoint=endpoint, **request_opts)
def delete(self, url: str, **request_opts) -> requests.Response:
return self.request(method="DELETE", url=url, **request_opts)


logger = logging.getLogger("gamechanger")
Expand Down Expand Up @@ -228,7 +224,7 @@ def post_model(self, model, model_name):
"model": {"type": "model/xgboost+json", "definition": model},
}
}
endpoint = f"{self.esu.host}/_ltr/_featureset/doc_features/_createmodel"
endpoint = "/_ltr/_featureset/doc_features/_createmodel"
r = self.esu.post(endpoint, data=json.dumps(query))
return r.content

Expand Down Expand Up @@ -590,17 +586,17 @@ def post_features(self):
],
}
}
endpoint = f"{self.esu.host}/_ltr/_featureset/doc_features"
endpoint = "/_ltr/_featureset/doc_features"
r = self.esu.post(endpoint, data=json.dumps(query))
return r.content

def post_init_ltr(self):
endpoint = f"{self.esu.host}/_ltr"
endpoint = "/_ltr"
r = self.esu.put(endpoint)
return r.content

def delete_ltr(self, model_name="ltr_model"):
endpoint = f"{self.esu.host}/_ltr/_model/{model_name}"
endpoint = "/_ltr/_model/{model_name}"
r = self.esu.delete(endpoint)
return r.content

Expand Down

0 comments on commit 7a62ddc

Please sign in to comment.