diff --git a/gamechangerml/setup_env.sh b/gamechangerml/setup_env.sh
index 893911cb..c58f6aa1 100755
--- a/gamechangerml/setup_env.sh
+++ b/gamechangerml/setup_env.sh
@@ -25,7 +25,13 @@ function setup_prod() {
     export S3_ML_DATA_PATH="${S3_ML_DATA_PATH:-s3://advana-data-zone/bronze/gamechanger/ml-data/v1/data_20211018.tar.gz}"
     export DOWNLOAD_DEP="${DOWNLOAD_DEP:-true}"
 
+    export ES_HOST="${ES_HOST:-}"
+    export ES_PORT="${ES_PORT:-443}"
+    export ES_USER="${ES_USER:-}"
+    export ES_PASSWORD="${ES_PASSWORD:-}"
+    export ES_ENABLE_SSL="${ES_ENABLE_SSL:-true}"
+    export ES_ENABLE_AUTH="${ES_ENABLE_AUTH:-true}"
     export DEV_ENV="PROD"
 }
 
 
@@ -48,7 +54,13 @@ function setup_dev() {
     export MLFLOW_TRACKING_URI="http://${MLFLOW_HOST}:5050/"
     export DOWNLOAD_DEP="${DOWNLOAD_DEP:-false}"
     export MODEL_LOAD="${MODEL_LOAD:-True}"
-    export ES_HOST="${ES_HOST:-https://vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"
+
+    export ES_HOST="${ES_HOST:-vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"
+    export ES_PORT="${ES_PORT:-443}"
+    export ES_USER="${ES_USER:-}"
+    export ES_PASSWORD="${ES_PASSWORD:-}"
+    export ES_ENABLE_SSL="${ES_ENABLE_SSL:-true}"
+    export ES_ENABLE_AUTH="${ES_ENABLE_AUTH:-false}"
 }
 
 
@@ -61,7 +73,13 @@ function setup_devlocal() {
     export S3_SENT_INDEX_PATH="${S3_SENT_INDEX_PATH:-s3://advana-data-zone/bronze/gamechanger/models/sentence_index/v4/sent_index_20210422.tar.gz}"
     export S3_ML_DATA_PATH="${S3_ML_DATA_PATH:-s3://advana-data-zone/bronze/gamechanger/ml-data/v1/data_20211018.tar.gz}"
 
-    export ES_HOST="${ES_HOST:-https://vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"
+    export ES_HOST="${ES_HOST:-vpc-gamechanger-dev-es-ms4wkfqyvlyt3gmiyak2hleqyu.us-east-1.es.amazonaws.com}"
+    export ES_PORT="${ES_PORT:-443}"
+    export ES_USER="${ES_USER:-}"
+    export ES_PASSWORD="${ES_PASSWORD:-}"
+    export ES_ENABLE_SSL="${ES_ENABLE_SSL:-true}"
+    export ES_ENABLE_AUTH="${ES_ENABLE_AUTH:-false}"
+    export DEV_ENV="DEVLOCAL"
 }
 
 
diff --git a/gamechangerml/src/search/ranking/ltr.py b/gamechangerml/src/search/ranking/ltr.py
index 9ba1ab9e..415edd0a 100644
--- a/gamechangerml/src/search/ranking/ltr.py
+++ b/gamechangerml/src/search/ranking/ltr.py
@@ -5,26 +5,114 @@
 from gamechangerml.src.search.ranking import search_data as meta
 from gamechangerml.src.search.ranking import rank
 from gamechangerml import REPO_PATH
-import datetime
 import pandas as pd
 from tqdm import tqdm
-import argparse
 import logging
 import os
-from elasticsearch import Elasticsearch, helpers
-import pickle
+from elasticsearch import Elasticsearch
 import xgboost as xgb
-import matplotlib
-import math
 import requests
 import json
 from sklearn.preprocessing import LabelEncoder
 from gamechangerml import MODEL_PATH, DATA_PATH
+import typing as t
+import base64
+from urllib.parse import urljoin
+
+
+ES_INDEX = os.environ.get("ES_INDEX", "gamechanger")
+class ESUtils:
+    def __init__(self,
+                 host: str = os.environ.get("ES_HOST", "localhost"),
+                 port: str = os.environ.get("ES_PORT", 9200),
+                 user: str = os.environ.get("ES_USER", ""),
+                 password: str = os.environ.get("ES_PASSWORD", ""),
+                 enable_ssl: bool = os.environ.get("ES_ENABLE_SSL", "False").lower() == "true",
+                 enable_auth: bool = os.environ.get("ES_ENABLE_AUTH", "False").lower() == "true"):
+
+        self.host = host
+        self.port = port
+        self.user = user
+        self.password = password
+        self.enable_ssl = enable_ssl
+        self.enable_auth = enable_auth
+
+        self.auth_token = base64.b64encode(f"{self.user}:{self.password}".encode()).decode()
+
+    @property
+    def client(self) -> Elasticsearch:
+        if hasattr(self, '_es_args'):
+            return Elasticsearch(**self._es_args)
+
+        host_args = dict(
+            hosts=[{
+                'host': self.host,
+                'port': self.port,
+                'http_compress': True,
+                'timeout': 60
+            }]
+        )
+
+        auth_args = dict(
+            http_auth=(
+                self.user,
+                self.password
+            )
+        ) if self.enable_auth else {}
+
+        ssl_args = dict(
+            use_ssl=self.enable_ssl
+        )
+
+        es_args = dict(
+            **host_args,
+            **auth_args,
+            **ssl_args,
+        )
+
+        self._es_args: t.Dict[str, t.Any] = es_args
+        return Elasticsearch(**self._es_args)
+
+    @property
+    def auth_headers(self) -> t.Dict[str, str]:
+        return {
+            "Authorization": f"Basic {self.auth_token}"
+        }
+
+    @property
+    def content_headers(self) -> t.Dict[str, str]:
+        return {
+            "Content-Type": "application/json"
+        }
+
+    @property
+    def default_headers(self) -> t.Dict[str, str]:
+        return dict(
+            **self.auth_headers,
+            **self.content_headers
+        )
+
+    @property
+    def root_url(self) -> str:
+        return ("https" if self.enable_ssl else "http") + f"://{self.host}:{self.port}/"
+
+    def request(self, method: str, endpoint: str, **request_opts) -> requests.Response:
+        url = urljoin(self.root_url, endpoint.lstrip("/"))
+        return requests.request(method=method, url=url, headers=self.default_headers, **request_opts)
+
+    def post(self, endpoint: str, **request_opts) -> requests.Response:
+        return self.request(method='POST', endpoint=endpoint, **request_opts)
+
+    def put(self, endpoint: str, **request_opts) -> requests.Response:
+        return self.request(method='PUT', endpoint=endpoint, **request_opts)
+
+    def get(self, endpoint: str, **request_opts) -> requests.Response:
+        return self.request(method='GET', endpoint=endpoint, **request_opts)
+
+    def delete(self, endpoint: str, **request_opts) -> requests.Response:
+        return self.request(method='DELETE', endpoint=endpoint, **request_opts)
 
 
-ES_HOST = os.environ.get("ES_HOST", default="localhost")
-ES_INDEX = os.environ.get("ES_INDEX", default="gamechanger")
-client = Elasticsearch([ES_HOST], timeout=60)
 
 logger = logging.getLogger("gamechanger")
 LTR_MODEL_PATH = os.path.join(MODEL_PATH, "ltr")
@@ -32,7 +120,6 @@
 os.makedirs(LTR_MODEL_PATH, exist_ok=True)
 os.makedirs(LTR_DATA_PATH, exist_ok=True)
 
-
 class LTR:
     def __init__(
         self,
@@ -61,6 +148,7 @@ def __init__(
             "rmse",
             "error",
         ]
+        self.esu = ESUtils()
 
     def write_model(self, model):
         """write model: writes model to file
@@ -89,7 +177,7 @@ def read_xg_data(self, path=os.path.join(LTR_DATA_PATH, "xgboost.csv")):
         except Exception as e:
             logger.error("Could not read in data for training")
 
-    def read_mappings(self, path="gamechangerml/data/SearchPdfMapping.csv"):
+    def read_mappings(self, path=os.path.join(DATA_PATH, "SearchPdfMapping.csv")):
         """read mappings: reads search pdf mappings
         params: path to file
         returns:
@@ -132,15 +220,14 @@ def post_model(self, model, model_name):
         returns:
             r: results
         """
-        headers = {"Content-Type": "application/json"}
         query = {
             "model": {
                 "name": model_name,
                 "model": {"type": "model/xgboost+json", "definition": model},
             }
         }
-        endpoint = ES_HOST + "/_ltr/_featureset/doc_features/_createmodel"
-        r = requests.post(endpoint, data=json.dumps(query), headers=headers)
+        endpoint = "/_ltr/_featureset/doc_features/_createmodel"
+        r = self.esu.post(endpoint, data=json.dumps(query))
         return r.content
 
     def search(self, terms, rescore=True):
@@ -222,7 +309,7 @@ def search(self, terms, rescore=True):
                 }
             }
         }
-        r = client.search(index="gamechanger", body=dict(query))
+        r = self.esu.client.search(index=ES_INDEX, body=dict(query))
         return r
 
     def generate_judgement(self, mappings):
@@ -290,7 +377,7 @@ def query_es_fts(self, df):
             query_list.append(json.dumps({"index": ES_INDEX}))
             query_list.append(json.dumps(q))
         query = "\n".join(query_list)
-        res = client.msearch(body=query)
+        res = self.esu.client.msearch(body=query)
         ltr_log = [x["hits"]["hits"] for x in res["responses"]]
         return ltr_log
 
@@ -501,19 +588,18 @@ def post_features(self):
                 ],
             }
         }
-        headers = {"Content-Type": "application/json"}
-        endpoint = ES_HOST + "/_ltr/_featureset/doc_features"
-        r = requests.post(endpoint, data=json.dumps(query), headers=headers)
+        endpoint = "/_ltr/_featureset/doc_features"
+        r = self.esu.post(endpoint, data=json.dumps(query))
        return r.content
 
     def post_init_ltr(self):
-        endpoint = ES_HOST + "/_ltr"
-        r = requests.put(endpoint)
+        endpoint = "/_ltr"
+        r = self.esu.put(endpoint)
         return r.content
 
     def delete_ltr(self, model_name="ltr_model"):
-        endpoint = ES_HOST + f"/_ltr/_model/{model_name}"
-        r = requests.delete(endpoint)
+        endpoint = f"/_ltr/_model/{model_name}"
+        r = self.esu.delete(endpoint)
         return r.content
 
     def normalize(self, arr, start=0, end=4):
diff --git a/setup.py b/setup.py
index db0eba90..e6bd89ff 100644
--- a/setup.py
+++ b/setup.py
@@ -64,7 +64,7 @@ def parse_readme(readme: Path) -> str:
     python_requires=">=3.8.0",
     install_requires=[
         p for p in parse_requirements(REQUIREMENTS_PATH)
-        if re.split(r'\s*[@=]\s*')[0].lower()
+        if re.split(r'\s*[@=]\s*', p)[0].lower()
         not in EXCLUDE_PACKAGES
     ] + SUBSTITUTE_PACKAGES,
     include_package_data=True,