Skip to content

Commit

Permalink
Merge pull request #74 from dod-advana/hotfix/es-auth-support
Browse files Browse the repository at this point in the history
Hotfix/es auth support
  • Loading branch information
vctrstrm committed Dec 9, 2021
2 parents ad8da24 + c06881a commit e70b5c2
Show file tree
Hide file tree
Showing 3 changed files with 130 additions and 26 deletions.
22 changes: 20 additions & 2 deletions gamechangerml/setup_env.sh
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,13 @@ function setup_prod() {
export S3_ML_DATA_PATH="${S3_ML_DATA_PATH:-s3://advana-data-zone/bronze/gamechanger/ml-data/v1/data_20211018.tar.gz}"

export DOWNLOAD_DEP="${DOWNLOAD_DEP:-true}"

export ES_HOST="${ES_HOST:-}"
export ES_PORT="${ES_PORT:-443}"
export ES_USER="${ES_USER:-}"
export ES_PASSWORD="${ES_PASSWORD:-}"
export ES_ENABLE_SSL="${ES_ENABLE_SSL:-true}"
export ES_ENABLE_AUTH="${ES_ENABLE_AUTH:-true}"

export DEV_ENV="PROD"
}
Expand All @@ -48,7 +54,13 @@ function setup_dev() {
export MLFLOW_TRACKING_URI="http://${MLFLOW_HOST}:5050/"
export DOWNLOAD_DEP="${DOWNLOAD_DEP:-false}"
export MODEL_LOAD="${MODEL_LOAD:-True}"
export ES_HOST="${ES_HOST:-https://vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"

export ES_HOST="${ES_HOST:-vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"
export ES_PORT="${ES_PORT:-443}"
export ES_USER="${ES_USER:-}"
export ES_PASSWORD="${ES_PASSWORD:-}"
export ES_ENABLE_SSL="${ES_ENABLE_SSL:-true}"
export ES_ENABLE_AUTH="${ES_ENABLE_AUTH:-false}"
}


Expand All @@ -61,7 +73,13 @@ function setup_devlocal() {
export S3_SENT_INDEX_PATH="${S3_SENT_INDEX_PATH:-s3://advana-data-zone/bronze/gamechanger/models/sentence_index/v4/sent_index_20210422.tar.gz}"
export S3_ML_DATA_PATH="${S3_ML_DATA_PATH:-s3://advana-data-zone/bronze/gamechanger/ml-data/v1/data_20211018.tar.gz}"

export ES_HOST="${ES_HOST:-https://vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"
export ES_HOST="${ES_HOST:-vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"
export ES_PORT="${ES_PORT:-443}"
export ES_USER="${ES_USER:-}"
export ES_PASSWORD="${ES_PASSWORD:-}"
export ES_ENABLE_SSL="${ES_ENABLE_SSL:-true}"
export ES_ENABLE_AUTH="${ES_ENABLE_AUTH:-false}"

export DEV_ENV="DEVLOCAL"
}

Expand Down
132 changes: 109 additions & 23 deletions gamechangerml/src/search/ranking/ltr.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,34 +5,121 @@
from gamechangerml.src.search.ranking import search_data as meta
from gamechangerml.src.search.ranking import rank
from gamechangerml import REPO_PATH
import datetime
import pandas as pd
from tqdm import tqdm
import argparse
import logging
import os
from elasticsearch import Elasticsearch, helpers
import pickle
from elasticsearch import Elasticsearch
import xgboost as xgb
import matplotlib
import math
import requests
import json
from sklearn.preprocessing import LabelEncoder
from gamechangerml import MODEL_PATH, DATA_PATH
import typing as t
import base64
from urllib.parse import urljoin


ES_INDEX = os.environ.get("ES_INDEX", "gamechanger")
class ESUtils:
    """Connection settings and thin HTTP helpers for one Elasticsearch host.

    Holds host/port/credentials plus SSL and auth switches, builds an
    ``Elasticsearch`` client from them, and offers ``requests``-based
    ``get``/``post``/``put``/``delete`` helpers that target the same host.

    NOTE: the constructor defaults read the ``ES_*`` environment variables
    once, when the class is defined (i.e. at import time), not on each
    instantiation. Pass explicit arguments to override them later.
    """

    def __init__(
        self,
        host: str = os.environ.get("ES_HOST", "localhost"),
        port: t.Union[str, int] = os.environ.get("ES_PORT", 9200),
        user: str = os.environ.get("ES_USER", ""),
        password: str = os.environ.get("ES_PASSWORD", ""),
        enable_ssl: bool = os.environ.get("ES_ENABLE_SSL", "False").lower() == "true",
        enable_auth: bool = os.environ.get("ES_ENABLE_AUTH", "False").lower() == "true",
    ):
        """Store the connection settings.

        Args:
            host: bare hostname (no scheme); the scheme is derived from
                ``enable_ssl`` in ``root_url``.
            port: port number; a string when it comes from the environment,
                the int ``9200`` otherwise.
            user: user name for HTTP basic auth.
            password: password for HTTP basic auth.
            enable_ssl: use ``https`` / ``use_ssl=True`` when true.
            enable_auth: send credentials when true.
        """
        self.host = host
        self.port = port
        self.user = user
        self.password = password
        self.enable_ssl = enable_ssl
        self.enable_auth = enable_auth

        # Pre-computed "user:password" token for the Basic Authorization header.
        self.auth_token = base64.b64encode(
            f"{self.user}:{self.password}".encode()
        ).decode()

    # Annotations naming third-party types are quoted so they are not
    # evaluated when the class body is executed.
    @property
    def client(self) -> "Elasticsearch":
        """Return a new ``Elasticsearch`` client built from these settings.

        The keyword arguments are assembled on first access and cached on
        the instance; a fresh client object is created on every access.
        """
        if not hasattr(self, "_es_args"):
            es_args: t.Dict[str, t.Any] = {
                "hosts": [
                    {
                        "host": self.host,
                        "port": self.port,
                        "http_compress": True,
                        "timeout": 60,
                    }
                ],
                "use_ssl": self.enable_ssl,
            }
            # Credentials are only handed to the client when auth is enabled.
            if self.enable_auth:
                es_args["http_auth"] = (self.user, self.password)
            self._es_args = es_args

        return Elasticsearch(**self._es_args)

    @property
    def auth_headers(self) -> t.Dict[str, str]:
        """Return the HTTP Basic ``Authorization`` header for user/password."""
        return {"Authorization": f"Basic {self.auth_token}"}

    @property
    def content_headers(self) -> t.Dict[str, str]:
        """Return the JSON ``Content-Type`` header."""
        return {"Content-Type": "application/json"}

    @property
    def default_headers(self) -> t.Dict[str, str]:
        """Return the headers sent with every request made via ``request``.

        The ``Authorization`` header is included only when ``enable_auth``
        is true, mirroring how ``client`` treats credentials.
        """
        headers: t.Dict[str, str] = (
            dict(self.auth_headers) if self.enable_auth else {}
        )
        headers.update(self.content_headers)
        return headers

    @property
    def root_url(self) -> str:
        """Return ``http[s]://<host>:<port>/`` for the configured host."""
        # Choose the scheme first: a conditional expression binds looser
        # than ``+``, so building the URL inside one concatenation chain
        # would drop either the scheme or the host part.
        scheme = "https" if self.enable_ssl else "http"
        return f"{scheme}://{self.host}:{self.port}/"

    def request(self, method: str, endpoint: str, **request_opts) -> "requests.Response":
        """Send an HTTP request to ``endpoint`` relative to ``root_url``.

        Args:
            method: HTTP verb, e.g. ``"GET"``.
            endpoint: path below the root URL; leading slashes are ignored.
            **request_opts: forwarded to ``requests.request``. A ``headers``
                mapping, if given, is merged over ``default_headers``.

        Returns:
            The ``requests.Response`` from the call.
        """
        url = urljoin(self.root_url, endpoint.lstrip("/"))
        # Merge caller headers rather than passing ``headers`` twice.
        extra_headers = request_opts.pop("headers", None) or {}
        headers = {**self.default_headers, **extra_headers}
        return requests.request(method=method, url=url, headers=headers, **request_opts)

    def post(self, endpoint: str, **request_opts) -> "requests.Response":
        """Send a POST request to ``endpoint``."""
        return self.request(method='POST', endpoint=endpoint, **request_opts)

    def put(self, endpoint: str, **request_opts) -> "requests.Response":
        """Send a PUT request to ``endpoint``."""
        return self.request(method='PUT', endpoint=endpoint, **request_opts)

    def get(self, endpoint: str, **request_opts) -> "requests.Response":
        """Send a GET request to ``endpoint``."""
        return self.request(method='GET', endpoint=endpoint, **request_opts)

    def delete(self, endpoint: str, **request_opts) -> "requests.Response":
        """Send a DELETE request to ``endpoint``."""
        return self.request(method='DELETE', endpoint=endpoint, **request_opts)


ES_HOST = os.environ.get("ES_HOST", default="localhost")
ES_INDEX = os.environ.get("ES_INDEX", default="gamechanger")

client = Elasticsearch([ES_HOST], timeout=60)
logger = logging.getLogger("gamechanger")

LTR_MODEL_PATH = os.path.join(MODEL_PATH, "ltr")
LTR_DATA_PATH = os.path.join(DATA_PATH, "ltr")
os.makedirs(LTR_MODEL_PATH, exist_ok=True)
os.makedirs(LTR_DATA_PATH, exist_ok=True)


class LTR:
def __init__(
self,
Expand Down Expand Up @@ -61,6 +148,7 @@ def __init__(
"rmse",
"error",
]
self.esu = ESUtils()

def write_model(self, model):
"""write model: writes model to file
Expand Down Expand Up @@ -89,7 +177,7 @@ def read_xg_data(self, path=os.path.join(LTR_DATA_PATH, "xgboost.csv")):
except Exception as e:
logger.error("Could not read in data for training")

def read_mappings(self, path="gamechangerml/data/SearchPdfMapping.csv"):
def read_mappings(self, path=os.path.join(DATA_PATH, "SearchPdfMapping.csv")):
"""read mappings: reads search pdf mappings
params: path to file
returns:
Expand Down Expand Up @@ -132,15 +220,14 @@ def post_model(self, model, model_name):
returns:
r: results
"""
headers = {"Content-Type": "application/json"}
query = {
"model": {
"name": model_name,
"model": {"type": "model/xgboost+json", "definition": model},
}
}
endpoint = ES_HOST + "/_ltr/_featureset/doc_features/_createmodel"
r = requests.post(endpoint, data=json.dumps(query), headers=headers)
endpoint = "/_ltr/_featureset/doc_features/_createmodel"
r = self.esu.post(endpoint, data=json.dumps(query))
return r.content

def search(self, terms, rescore=True):
Expand Down Expand Up @@ -222,7 +309,7 @@ def search(self, terms, rescore=True):
}
}
}
r = client.search(index="gamechanger", body=dict(query))
r = self.esu.client.search(index=ES_INDEX, body=dict(query))
return r

def generate_judgement(self, mappings):
Expand Down Expand Up @@ -290,7 +377,7 @@ def query_es_fts(self, df):
query_list.append(json.dumps({"index": ES_INDEX}))
query_list.append(json.dumps(q))
query = "\n".join(query_list)
res = client.msearch(body=query)
res = self.esu.client.msearch(body=query)
ltr_log = [x["hits"]["hits"] for x in res["responses"]]
return ltr_log

Expand Down Expand Up @@ -501,19 +588,18 @@ def post_features(self):
],
}
}
headers = {"Content-Type": "application/json"}
endpoint = ES_HOST + "/_ltr/_featureset/doc_features"
r = requests.post(endpoint, data=json.dumps(query), headers=headers)
endpoint = "/_ltr/_featureset/doc_features"
r = self.esu.post(endpoint, data=json.dumps(query))
return r.content

def post_init_ltr(self):
endpoint = ES_HOST + "/_ltr"
r = requests.put(endpoint)
endpoint = "/_ltr"
r = self.esu.put(endpoint)
return r.content

def delete_ltr(self, model_name="ltr_model"):
endpoint = ES_HOST + f"/_ltr/_model/{model_name}"
r = requests.delete(endpoint)
endpoint = f"/_ltr/_model/{model_name}"
r = self.esu.delete(endpoint)
return r.content

def normalize(self, arr, start=0, end=4):
Expand Down
2 changes: 1 addition & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -64,7 +64,7 @@ def parse_readme(readme: Path) -> str:
python_requires=">=3.8.0",
install_requires=[
p for p in parse_requirements(REQUIREMENTS_PATH)
if re.split(r'\s*[@=]\s*')[0].lower()
if re.split(r'\s*[@=]\s*', p)[0].lower()
not in EXCLUDE_PACKAGES
] + SUBSTITUTE_PACKAGES,
include_package_data=True,
Expand Down

0 comments on commit e70b5c2

Please sign in to comment.