diff --git a/gamechangerml/scripts/download_dependencies.sh b/gamechangerml/scripts/download_dependencies.sh index cf6e1166..aa5ca512 100755 --- a/gamechangerml/scripts/download_dependencies.sh +++ b/gamechangerml/scripts/download_dependencies.sh @@ -1,9 +1,9 @@ -#!/bin/bash +#!/usr/bin/env bash echo "Be sure to set up environment variables for s3 by sourcing setup_env.sh if running this manually" function download_and_unpack_deps() { - local pkg_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../../" >/dev/null 2>&1 && pwd )" + local pkg_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../" >/dev/null 2>&1 && pwd )" local models_dest_dir="$pkg_dir/models/" local data_dest_dir="$pkg_dir/data/" @@ -34,15 +34,14 @@ function download_and_unpack_deps() { aws s3 cp "$S3_ML_DATA_PATH" "$data_dest_dir" echo "Uncompressing all tar files in models" - for f in ./gamechangerml/models/*.tar.gz; do - tar kxvfz "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$models_dest_dir"; - done + find "$models_dest_dir" -maxdepth 1 -type f -name "*.tar.gz" | while IFS=$'\n' read -r f; do + tar kxzvf "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$models_dest_dir" + done echo "Uncompressing all tar files in data" - for f in ./gamechangerml/data/*.tar.gz; do - tar kxvfz "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$data_dest_dir"; + find "$data_dest_dir" -maxdepth 1 -type f -name "*.tar.gz" | while IFS=$'\n' read -r f; do + tar kxzvf "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$data_dest_dir" done - } -download_and_unpack_deps \ No newline at end of file +download_and_unpack_deps diff --git a/gamechangerml/src/search/ranking/ltr.py b/gamechangerml/src/search/ranking/ltr.py index 415edd0a..60efa190 100644 --- a/gamechangerml/src/search/ranking/ltr.py +++ b/gamechangerml/src/search/ranking/ltr.py @@ -21,14 +21,18 @@ ES_INDEX = os.environ.get("ES_INDEX", "gamechanger") + + class ESUtils: - def __init__(self, + def __init__( + self, host: str = os.environ.get("ES_HOST", "localhost"), - port: str = os.environ.get("ES_PORT", 9200), + port: str = os.environ.get("ES_PORT", 443), user: str = os.environ.get("ES_USER", ""), password: str = os.environ.get("ES_PASSWORD", ""), - enable_ssl: bool = os.environ.get("ES_ENABLE_SSL", "False").lower() == "true", - enable_auth: bool = os.environ.get("ES_ENABLE_AUTH", "False").lower() == "true"): + enable_ssl: bool = os.environ.get("ES_ENABLE_SSL", "True").lower() == "true", + enable_auth: bool = os.environ.get("ES_ENABLE_AUTH", "False").lower() == "true", + ): self.host = host self.port = port @@ -36,33 +40,28 @@ def __init__(self, self.password = password self.enable_ssl = enable_ssl self.enable_auth = enable_auth - - self.auth_token = base64.b64encode(f"{self.user}:{self.password}".encode()).decode() + + self.auth_token = base64.b64encode( + f"{self.user}:{self.password}".encode() + ).decode() @property def client(self) -> Elasticsearch: - if hasattr(self, '_es_args'): - return Elasticsearch(**self._es_args) + if hasattr(self, "_client"): + return getattr(self, "_client") host_args = dict( - hosts=[{ - 'host': self.host, - 'port': self.port, - 'http_compress': True, - 'timeout': 60 - }] - ) - - auth_args = dict( - http_auth=( - self.user, - self.password - ) - ) if self.enable_auth else {} - - ssl_args = dict( - use_ssl=self.enable_ssl + hosts=[ + { + "host": self.host, + "port": self.port, + "http_compress": True, + "timeout": 60, + } + ] ) + auth_args = dict(http_auth=(self.user, self.password)) if self.enable_auth else {} + ssl_args = dict(use_ssl=self.enable_ssl) es_args = dict( **host_args, @@ -70,47 +69,45 @@ def client(self) -> Elasticsearch: **ssl_args, ) - self._es_args: t.Dict[str, t.Any] = es_args - return Elasticsearch(**self._es_args) + self._es_client = Elasticsearch(**es_args) + return self._es_client @property def auth_headers(self) -> t.Dict[str, str]: - return { - "Authorization": f"Basic {self.auth_token}" - } + return {"Authorization": f"Basic {self.auth_token}"} if self.enable_auth else {} @property def content_headers(self) -> t.Dict[str, str]: - return { - "Content-Type": "application/json" - } - + return {"Content-Type": "application/json"} + @property def default_headers(self) -> t.Dict[str, str]: - return dict( - **self.auth_headers, - **self.content_headers - ) + if self.enable_auth: + return dict(**self.auth_headers, **self.content_headers) + else: + return dict(**self.content_headers) @property def root_url(self) -> str: - return "http" + "s" if self.enable_ssl else "" + f"://{self.host}:{self.port}/" + return ("https" if self.enable_ssl else "http") + f"://{self.host}:{self.port}/" - def request(self, method: str, endpoint: str, **request_opts) -> requests.Response: - url = urljoin(self.root_url, endpoint.lstrip("/")) - return requests.request(method=method, url=url, headers=self.default_headers, **request_opts) + def request(self, method: str, url: str, **request_opts) -> requests.Response: + complete_url = urljoin(self.root_url, url.lstrip("/")) + return requests.request( + method=method, url=complete_url, headers=self.default_headers, **request_opts + ) - def post(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method='POST', endpoint=endpoint, **request_opts) + def post(self, url: str, **request_opts) -> requests.Response: + return self.request(method="POST", url=url, **request_opts) - def put(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method='PUT', endpoint=endpoint, **request_opts) + def put(self, url: str, **request_opts) -> requests.Response: + return self.request(method="PUT", url=url, **request_opts) - def get(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method='GET', endpoint=endpoint, **request_opts) + def get(self, url: str, **request_opts) -> requests.Response: + return self.request(method="GET", url=url, **request_opts) - def delete(self, endpoint: str, **request_opts) -> requests.Response: - return self.request(method='DELETE', endpoint=endpoint, **request_opts) + def delete(self, url: str, **request_opts) -> requests.Response: + return self.request(method="DELETE", url=url, **request_opts) logger = logging.getLogger("gamechanger") @@ -120,6 +117,7 @@ def delete(self, endpoint: str, **request_opts) -> requests.Response: os.makedirs(LTR_MODEL_PATH, exist_ok=True) os.makedirs(LTR_DATA_PATH, exist_ok=True) + class LTR: def __init__( self, @@ -598,7 +596,7 @@ def post_init_ltr(self): return r.content def delete_ltr(self, model_name="ltr_model"): - endpoint = f"/_ltr/_model/{model_name}" + endpoint = "/_ltr/_model/{model_name}" r = self.esu.delete(endpoint) return r.content