Skip to content

Commit

Permalink
Merge pull request #75 from dod-advana/hotfix/es_client
Browse files Browse the repository at this point in the history
some fixes for the url paths and hosts
  • Loading branch information
vctrstrm committed Dec 10, 2021
2 parents e70b5c2 + 7a62ddc commit 14ed032
Show file tree
Hide file tree
Showing 2 changed files with 57 additions and 60 deletions.
17 changes: 8 additions & 9 deletions gamechangerml/scripts/download_dependencies.sh
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
#!/bin/bash
#!/usr/bin/env bash
echo "Be sure to set up environment variables for s3 by sourcing setup_env.sh if running this manually"

function download_and_unpack_deps() {

local pkg_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../../" >/dev/null 2>&1 && pwd )"
local pkg_dir="$( cd "$( dirname "${BASH_SOURCE[0]}" )/../" >/dev/null 2>&1 && pwd )"
local models_dest_dir="$pkg_dir/models/"
local data_dest_dir="$pkg_dir/data/"

Expand Down Expand Up @@ -34,15 +34,14 @@ function download_and_unpack_deps() {
aws s3 cp "$S3_ML_DATA_PATH" "$data_dest_dir"

echo "Uncompressing all tar files in models"
for f in ./gamechangerml/models/*.tar.gz; do
tar kxvfz "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$models_dest_dir";
done
find "$models_dest_dir" -maxdepth 1 -type f -name "*.tar.gz" | while IFS=$'\n' read -r f; do
tar kxzvf "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$models_dest_dir"
done

echo "Uncompressing all tar files in data"
for f in ./gamechangerml/data/*.tar.gz; do
tar kxvfz "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$data_dest_dir";
find "$data_dest_dir" -maxdepth 1 -type f -name "*.tar.gz" | while IFS=$'\n' read -r f; do
tar kxzvf "$f" --exclude '*/.git/*' --exclude '*/.DS_Store/*' -C "$data_dest_dir"
done

}

download_and_unpack_deps
download_and_unpack_deps
100 changes: 49 additions & 51 deletions gamechangerml/src/search/ranking/ltr.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,96 +21,93 @@


ES_INDEX = os.environ.get("ES_INDEX", "gamechanger")


class ESUtils:
def __init__(self,
def __init__(
self,
host: str = os.environ.get("ES_HOST", "localhost"),
port: str = os.environ.get("ES_PORT", 9200),
port: str = os.environ.get("ES_PORT", 443),
user: str = os.environ.get("ES_USER", ""),
password: str = os.environ.get("ES_PASSWORD", ""),
enable_ssl: bool = os.environ.get("ES_ENABLE_SSL", "False").lower() == "true",
enable_auth: bool = os.environ.get("ES_ENABLE_AUTH", "False").lower() == "true"):
enable_ssl: bool = os.environ.get("ES_ENABLE_SSL", "True").lower() == "true",
enable_auth: bool = os.environ.get("ES_ENABLE_AUTH", "False").lower() == "true",
):

self.host = host
self.port = port
self.user = user
self.password = password
self.enable_ssl = enable_ssl
self.enable_auth = enable_auth

self.auth_token = base64.b64encode(f"{self.user}:{self.password}".encode()).decode()

self.auth_token = base64.b64encode(
f"{self.user}:{self.password}".encode()
).decode()

@property
def client(self) -> Elasticsearch:
if hasattr(self, '_es_args'):
return Elasticsearch(**self._es_args)
if hasattr(self, "_client"):
return getattr(self, "_client")

host_args = dict(
hosts=[{
'host': self.host,
'port': self.port,
'http_compress': True,
'timeout': 60
}]
)

auth_args = dict(
http_auth=(
self.user,
self.password
)
) if self.enable_auth else {}

ssl_args = dict(
use_ssl=self.enable_ssl
hosts=[
{
"host": self.host,
"port": self.port,
"http_compress": True,
"timeout": 60,
}
]
)
auth_args = dict(http_auth=(self.user, self.password)) if self.enable_auth else {}
ssl_args = dict(use_ssl=self.enable_ssl)

es_args = dict(
**host_args,
**auth_args,
**ssl_args,
)

self._es_args: t.Dict[str, t.Any] = es_args
return Elasticsearch(**self._es_args)
self._es_client = Elasticsearch(**es_args)
return self._es_client

@property
def auth_headers(self) -> t.Dict[str, str]:
return {
"Authorization": f"Basic {self.auth_token}"
}
return {"Authorization": f"Basic {self.auth_token}"} if self.enable_auth else {}

@property
def content_headers(self) -> t.Dict[str, str]:
return {
"Content-Type": "application/json"
}

return {"Content-Type": "application/json"}

@property
def default_headers(self) -> t.Dict[str, str]:
return dict(
**self.auth_headers,
**self.content_headers
)
if self.enable_auth:
return dict(**self.auth_headers, **self.content_headers)
else:
return dict(**self.content_headers)

@property
def root_url(self) -> str:
return "http" + "s" if self.enable_ssl else "" + f"://{self.host}:{self.port}/"
return ("https" if self.enable_ssl else "http") + f"://{self.host}:{self.port}/"

def request(self, method: str, endpoint: str, **request_opts) -> requests.Response:
url = urljoin(self.root_url, endpoint.lstrip("/"))
return requests.request(method=method, url=url, headers=self.default_headers, **request_opts)
def request(self, method: str, url: str, **request_opts) -> requests.Response:
complete_url = urljoin(self.root_url, url.lstrip("/"))
return requests.request(
method=method, url=complete_url, headers=self.default_headers, **request_opts
)

def post(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method='POST', endpoint=endpoint, **request_opts)
def post(self, url: str, **request_opts) -> requests.Response:
return self.request(method="POST", url=url, **request_opts)

def put(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method='PUT', endpoint=endpoint, **request_opts)
def put(self, url: str, **request_opts) -> requests.Response:
return self.request(method="PUT", url=url, **request_opts)

def get(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method='GET', endpoint=endpoint, **request_opts)
def get(self, url: str, **request_opts) -> requests.Response:
return self.request(method="GET", url=url, **request_opts)

def delete(self, endpoint: str, **request_opts) -> requests.Response:
return self.request(method='DELETE', endpoint=endpoint, **request_opts)
def delete(self, url: str, **request_opts) -> requests.Response:
return self.request(method="DELETE", url=url, **request_opts)


logger = logging.getLogger("gamechanger")
Expand All @@ -120,6 +117,7 @@ def delete(self, endpoint: str, **request_opts) -> requests.Response:
os.makedirs(LTR_MODEL_PATH, exist_ok=True)
os.makedirs(LTR_DATA_PATH, exist_ok=True)


class LTR:
def __init__(
self,
Expand Down Expand Up @@ -598,7 +596,7 @@ def post_init_ltr(self):
return r.content

def delete_ltr(self, model_name="ltr_model"):
endpoint = f"/_ltr/_model/{model_name}"
endpoint = "/_ltr/_model/{model_name}"
r = self.esu.delete(endpoint)
return r.content

Expand Down

0 comments on commit 14ed032

Please sign in to comment.