diff --git a/README.md b/README.md
index 6d0a8b2..016f5a1 100644
--- a/README.md
+++ b/README.md
@@ -41,31 +41,85 @@ By default, climate-data will search all possible given mirrors for reliability
 
 `/search/esgf`
 
 Required Parameters:
-  * `query`: Natural language string with search terms to retrieve datasets for.
+  * `query`: Natural language string with search terms to retrieve datasets for (or a keyword list / raw Lucene query; see optional parameters).
 
-Example: `/search/esgf?query=historical eastward wind 100 km cesm2 r11i1p1f1 cfday`
+Optional Parameters:
+  * `keywords`: When true, passes a keyword-oriented search directly to ESGF; keyword searches are not routed through the LLM. Either a list of keywords or a raw Lucene query is supported.
+
+##### Natural Language Search Example:
+
+Search: "find me datasets about max air temperature monthly with a community earth model and ssp3 7.0"
+
+URL: `/search/esgf?query=find me datasets about max air temperature monthly with a community earth model and ssp3 7.0`
 
 Output:
 ```json
 {
-  "results": [
-    {
-      "metadata": {
-        "id": "CMIP6.CMIP.NCAR.CESM2.historical.r11i1p1f1.CFday.ua.gn.v20190514|aims3.llnl.gov",
-        "version": "20190514"...
-      }
-    }, ...
-  ]
+  "query": {
+    "raw": "(Daily Maximum Near-Surface Air Temperature OR Near-Surface Air Temperature) AND (tasmax OR tas) AND CESM2 AND ssp370 AND NCAR AND mon",
+    "search_terms": {
+      "variable_descriptions": [
+        "Daily Maximum Near-Surface Air Temperature",
+        "Near-Surface Air Temperature",
+        ""
+      ],
+      "variable": [
+        "tasmax",
+        "tas",
+        ""
+      ],
+      "source_id": "CESM2",
+      "experiment_id": "ssp370",
+      "nominal_resolution": "",
+      "institution_id": "NCAR",
+      "variant_label": "",
+      "frequency": "mon"
+    }
+  },
+  "results": [
+    {
+      "metadata": {
+        "id": "CMIP6.ScenarioMIP.NCAR.CESM2-WACCM.ssp370.r1i1p1f1.Amon.tas.gn.v20190815|esgf-data04.diasjp.net",
+        "version": "20190815"
+      }, ...
+    }
+  ]
+}
+```
+
+##### Keyword Search Example:
+
+Search: "historical eastward wind 100 km cesm2 r11i1p1f1 cfday"
+
+URL: `/search/esgf?keywords=True&query=historical eastward wind 100 km cesm2 r11i1p1f1 cfday`
+
+Output:
+```json
+{
+  "query": {
+    "original": "historical eastward wind 100 km cesm2 r11i1p1f1 cfday",
+    "raw": "historical AND eastward AND wind AND 100 AND km AND cesm2 AND r11i1p1f1 AND cfday"
+  },
+  "results": [
+    {
+      "metadata": {
+        "id": "CMIP6.CMIP.NCAR.CESM2.historical.r11i1p1f1.CFday.ua.gn.v20190514|aims3.llnl.gov",
+        "version": "20190514"...
+      }
+    }, ...
+  ]
 }
 ```
 
 `results` is a list of datasets, sorted by relevance.
-Each dataset contains a `metadata` field.
+Each dataset contains a `metadata` field. The response also carries a top-level `query` field, alongside `results`, describing how the search was processed.
 
 `metadata` contains all of the stored metadata for the data set, provided by ESGF, such as experiment name, title, variables, geospatial coordinates, time, frequency, resolution, and more.
 
-The `metadata` field contains an `id` field that is used for subsequent processing and lookups, containing the full dataset ID with revision and node information, such as: `CMIP6.CMIP.NCAR.CESM2.historical.r11i1p1f1.CFday.ua.gn.v20190514|esgf-data.ucar.edu`
+The `metadata` field contains an `id` field that is used for subsequent processing and lookups, containing the full dataset ID with revision and node information, such as: `CMIP6.CMIP.NCAR.CESM2.historical.r11i1p1f1.CFday.ua.gn.v20190514|esgf-data.ucar.edu`
+
+`query` contains information about the search processing itself. One subfield is always present: `raw`, containing the query string passed directly to the ESGF node.
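As an illustration of both response shapes end to end, here is a minimal client sketch (illustrative only, not part of this diff; the base URL is a hypothetical local deployment of this API):

```python
import requests

BASE = "http://localhost:8000"  # hypothetical deployment of this API

# natural-language mode: the LLM extracts facets and builds the Lucene query
nl = requests.get(f"{BASE}/search/esgf", params={
    "query": "max air temperature monthly with a community earth model and ssp3 7.0",
}).json()
print(nl["query"]["raw"])           # Lucene string sent to the ESGF node
print(nl["query"]["search_terms"])  # facet -> keywords chosen by the LLM

# keyword mode: bypasses the LLM entirely
kw = requests.get(f"{BASE}/search/esgf", params={
    "keywords": True,
    "query": "historical eastward wind 100 km cesm2 r11i1p1f1 cfday",
}).json()
print(kw["query"]["original"], "->", kw["query"]["raw"])
for dataset in kw["results"]:
    print(dataset["metadata"]["id"])
```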
+`search_terms` is an object mapping facet keys to LLM-extracted keywords; it appears for natural language searches. `original` appears for keyword searches and holds the input as given, before its conversion to a Lucene query.
 
 #### Preview
 
diff --git a/api/search/providers/esgf.py b/api/search/providers/esgf.py
index 21a0d2d..b98f53f 100644
--- a/api/search/providers/esgf.py
+++ b/api/search/providers/esgf.py
@@ -1,4 +1,5 @@
 import re
+
 from api.settings import default_settings
 from api.search.provider import (
     AccessURLs,
@@ -12,134 +13,134 @@ import itertools
 import dask
 from openai import OpenAI
-from numpy import dot
 import json
-import numpy as np
-import pandas as pd
-from pathlib import Path
-import pickle
-
-NATURAL_LANGUAGE_PROCESSING_CONTEXT = """
-You are a tool to extract keyword search terms by category from a given search request.
-
-The given keyword fields are: frequency, nominal_resolution, lower_time_bound, upper_time_bound, and description.
-
-The definitions of the keyword fields are as follows.
-
-frequency is a duration.
-Possible example values for frequency values are: 6 hours, 6hrs, 3hr, daily, day, yearly, 12 hours, 12 hr
-
-nominal_resolution is a measure of distance.
-Possible example values for resolution are: 100 km, 100km, 200km, 200 km, 2x2 degrees, 1x1 degrees, 1x1, 2x2, 20000 km
-
-lower_time_bound and upper_time_bound are measures of time.
-When extracted from user input, convert them to UTC ISO 8601 format.
-Possible example values include:
-
-"after 2022" = lower_time_bound: 2022-00-00T00:00:00Z
-"between march 2021 and april 2023" = lower_time_bound: 2021-03-00T00:00:00Z ; upper_time_bound: 2023-04-00T00:00:00Z
-"before september 1995" = upper_time_bound: 1995-09-00T00:00:00Z
-
-description is a text field and contains all other unprocessed information.
-
-Return the fields as a JSON object and include no other information.
-
-Examples of full processing are as follows.
-
-Input:
-100km before 2023 daily air temperature
-Output:
-{
-    "frequency": "daily",
-    "nominal_resolution": "100km",
-    "upper_time_bound": "2023-00-00T00:00:00Z",
-    "description": "air temperature"
-}
-
-Input: 2x2 degree relative humidity between june 1997 and july 1999 6hr
-Output:
-{
-    "frequency": "6hr",
-    "nominal_resolution": "2x2 degree",
-    "lower_time_bound": "1997-06-00T00:00:00Z",
-    "upper_time_bound": "1999-07-00T00:00:00Z",
-    "description": "relative humidity"
-}
-
-Input: ts
-Output: {
-    "description": "ts"
-}
-
-Input: Find me datasets with the variable relative humidity
-Output: {
-    "description": "relative humidity"
-}
-
-Input: datasets before june 1995 the variable surface temperature model BCC-ESM1
-Output: {
-    "upper_time_bound": "1995-06-00T00:00:00Z",
-    "description": "surface temperature BCC-ESM1"
-}
-Only return JSON.
-"""
+
+
+def generate_natural_language_system_prompt(facets: dict[str, list[str]]) -> str:
+    return f"""
+You are an assistant trying to help a user determine which variables, sources, experiments, resolutions,
+variants, institutions, and frequencies from ESGF's CMIP6 are being referenced in their natural language query.
+
+Here is a list of variable_descriptions: {facets['variable_long_name']}
+Here is a list of variables: {facets['variable_id']}
+Here is a list of source_ids: {facets['source_id']}
+Here is a list of experiment_ids: {facets['experiment_id']}
+Here is a list of nominal_resolutions: {facets['nominal_resolution']}
+Here is a list of institution_ids: {facets['institution_id']}
+Here is a list of variant_labels: {facets['variant_label']}
+Here is a list of frequencies: {facets['frequency']}
+
+You should respond by building a dictionary that has the following keys:
+  [variable_descriptions, variable, source_id, experiment_id, nominal_resolution, institution_id, variant_label, frequency]
+
+Please select up to three variable_descriptions from the variable_descriptions list that most closely match the user's query and assign them to the variable_descriptions key.
+If none clearly and obviously match, assign an empty string ''.
+
+Please select up to three variables from the variables list that most closely match the user's query and assign them to the variable key.
+If none clearly and obviously match, assign an empty string ''.
+
+Please select one and ONLY ONE source_id from the source_ids list that most closely matches the user's query and assign ONLY that source_id to the source_id key.
+If none clearly and obviously match, assign an empty string ''.
+
+Please select one and ONLY ONE experiment_id from the experiment_ids list that most closely matches the user's query and assign ONLY that experiment_id to the experiment_id key.
+If none clearly and obviously match, assign an empty string ''.
+
+Please select one and ONLY ONE nominal_resolution from the nominal_resolutions list that most closely matches the user's query and assign ONLY that nominal_resolution to the nominal_resolution key.
+If none clearly and obviously match, assign an empty string ''.
+
+Please select one and ONLY ONE institution_id from the institution_ids list that most closely matches the user's query and assign ONLY that institution_id to the institution_id key.
+If none clearly and obviously match, assign an empty string ''.
+
+Please select one and ONLY ONE variant_label from the variant_labels list that most closely matches the user's query and assign ONLY that variant_label to the variant_label key.
+If none clearly and obviously match, assign an empty string ''.
+
+Please select one and ONLY ONE frequency from the frequencies list that most closely matches the user's query and assign ONLY that frequency to the frequency key.
+If none clearly and obviously match, assign an empty string ''.
+
+Ensure that your response is properly formatted JSON.
+
+Also, when you are selecting variable, source_id, experiment_id, nominal_resolution, institution_id, variant_label, and frequency, make sure to select
+the most simple and obvious choice -- no fancy footwork here please.
""" -# cosine matching threshold to greedily take term -GREEDY_EXTRACTION_THRESHOLD = 0.93 - -SEARCH_FACETS = { - # match by cosine similarity - "similar": [ - "experiment_title", - "cf_standard_name", - "variable_long_name", - "variable_id", - "table_id", - "source_type", - "source_id", - "activity_id", - "nominal_resolution", - "frequency", - "realm", - ], - # only take exact matches - "exact": [ - "institution_id", - "variant_label", - "experiment_id", - "grid_label", - ], - # create embeddings, but handle manually elsewhere - optimization on a specific field - "other": ["nominal_resolution", "frequency"], -} + +SEARCH_FACETS = [ + "experiment_title", + "cf_standard_name", + "variable_long_name", + "variable_id", + "table_id", + "source_type", + "source_id", + "activity_id", + "nominal_resolution", + "frequency", + "realm", + "institution_id", + "variant_label", + "experiment_id", + "grid_label", + "nominal_resolution", + "frequency", +] class ESGFProvider(BaseSearchProvider): def __init__(self, openai_client): print("initializing esgf search provider") self.client: OpenAI = openai_client - self.embeddings = {} - - def initialize_embeddings(self, force_refresh=False): - """ - creates string embeddings if needed, otherwise reloads from cache. - force_refresh is needed if the list of facets changes. - """ - cache = Path("./embedding_cache") - if cache.exists() and not force_refresh: - print("embedding cache exists", flush=True) - with cache.open("rb") as f: - self.embeddings = pickle.load(f) - else: - print("no embedding cache, generating new", flush=True) - with cache.open(mode="wb") as f: - try: - self.embeddings = self.extract_embedding_strings() - except Exception as e: - raise IOError( - f"failed to access OpenAI: is OPENAI_API_KEY set in env?: {e}" - ) - pickle.dump(self.embeddings, f) + self.search_mirrors = [ + default_settings.esgf_url, + *default_settings.esgf_fallbacks.split(","), + ] + self.current_mirror_index = 0 + self.retries = 0 + self.max_retries = len(self.search_mirrors) + + self.with_all_available_mirrors(self.get_facet_possiblities) + + def increment_mirror(self): + self.current_mirror_index += 1 + self.current_mirror_index = self.current_mirror_index % len(self.search_mirrors) + + def with_all_available_mirrors(self, func, *args, **kwargs) -> Any: + self.retries = 0 + return_value = None + while self.retries < self.max_retries: + try: + return_value = func(*args, **kwargs) + break + except Exception as e: + print( + f"failed to run: retry {self.retries}, mirror: {self.search_mirrors[self.current_mirror_index]}", + flush=True, + ) + self.increment_mirror() + self.retries += 1 + if self.retries >= self.max_retries: + raise Exception(f"failed after {self.retries} retries: {e}") + return return_value + + def get_esgf_url_with_current_mirror(self) -> str: + mirror = self.search_mirrors[self.current_mirror_index] + return f"{mirror}/search" + + def get_facet_possiblities(self): + query = { + "project": "CMIP6", + "facets": ",".join(SEARCH_FACETS), + "limit": "0", + "format": "application/solr+json", + } + base_url = self.get_esgf_url_with_current_mirror() + response = requests.get(base_url, params=query) + if response.status_code >= 300: + msg = f"failed to fetch available facets: {response.status_code}, {response.content}" + raise Exception(msg) + facets = response.json() + self.facet_possibilities = facets["facet_counts"]["facet_fields"] + for facet, terms in self.facet_possibilities.items(): + self.facet_possibilities[facet] = terms[0::2] def 
 
     def is_terarium_hmi_dataset(self, dataset_id: str) -> bool:
         """
@@ -148,28 +149,31 @@ def is_terarium_hmi_dataset(self, dataset_id: str) -> bool:
         p = re.compile(r"^[0-9a-f]{8}-([0-9a-f]{4}-){3}[0-9a-f]{12}$")
         return bool(p.match(dataset_id.lower()))
 
-    def search(
-        self, query: str, page: int, force_refresh_cache: bool = False
-    ) -> DatasetSearchResults:
+    def search(self, query: str, page: int, keywords: bool) -> dict[str, Any]:
         """
         converts a natural language query to a list of ESGF dataset metadata dictionaries
         by running a lucene query against the given ESGF node in settings.
+
+        keywords: pass keywords directly to ESGF with no LLM in the middle
         """
-        if len(self.embeddings.keys()) == 0 or force_refresh_cache:
-            self.initialize_embeddings(force_refresh_cache)
+        if keywords:
+            print(f"keyword searching for {query}", flush=True)
+            return self.keyword_search(query, page)
         return self.natural_language_search(query, page)
 
     def get_all_access_paths_by_id(self, dataset_id: str) -> AccessURLs:
         return [
-            self.get_access_paths_by_id(id)
-            for id in self.get_mirrors_for_dataset(dataset_id)
+            self.with_all_available_mirrors(self.get_access_paths_by_id, id)
+            for id in self.with_all_available_mirrors(
+                self.get_mirrors_for_dataset, dataset_id
+            )
         ]
 
     def get_mirrors_for_dataset(self, dataset_id: str) -> List[str]:
         # strip vert bar if provided with example mirror attached
         dataset_id = dataset_id.split("|")[0]
-        response = self.run_esgf_query(f"id:{dataset_id}*", 1, {})
+        response = self.run_esgf_dataset_query(f"id:{dataset_id}*", 1, {})
         full_ids = [d.metadata["id"] for d in response]
         return full_ids
 
@@ -178,7 +182,7 @@ def get_datasets_from_id(self, dataset_id: str) -> List[Dict[str, Any]]:
         returns a list of datasets for a given ID. includes mirrors.
         """
         if dataset_id == "":
-            return {}
+            return []
         params = urlencode(
             {
                 "type": "File",
@@ -187,7 +191,8 @@ def get_datasets_from_id(self, dataset_id: str) -> List[Dict[str, Any]]:
                 "limit": 200,
             }
         )
-        full_url = f"{default_settings.esgf_url}/search?{params}"
+        base_url = self.get_esgf_url_with_current_mirror()
+        full_url = f"{base_url}?{params}"
         r = requests.get(full_url)
         response = r.json()
         if r.status_code != 200:
@@ -236,14 +241,33 @@ def get_metadata_for_dataset(self, dataset_id: str) -> Dict[str, Any]:
     def get_access_paths(self, dataset: Dataset) -> AccessURLs:
         return self.get_all_access_paths_by_id(dataset.metadata["id"])
 
+    def keyword_search(self, query: str, page: int) -> dict[str, Any]:
+        """
+        converts a list of keywords to an ESGF query and runs it against the node.
+        """
+        lucene_query_statements = ["AND", "OR", "(", ")"]
+        if any(substring in query for substring in lucene_query_statements):
+            datasets = self.run_esgf_dataset_query(query, page, options={})
+            return {"query": {"raw": query}, "results": datasets}
+        else:
+            stripped_query = re.sub(r"[^A-Za-z0-9 ]+", "", query)
+            lucene_query = " AND ".join(stripped_query.split(" "))
+            datasets = self.run_esgf_dataset_query(lucene_query, page, options={})
+            return {
+                "query": {
+                    "original": query,
+                    "raw": lucene_query,
+                },
+                "results": datasets,
+            }
+
     def natural_language_search(
         self, search_query: str, page: int, retries=0
-    ) -> DatasetSearchResults:
+    ) -> dict[str, Any]:
         """
         converts to natural language and runs the result against the ESGF node,
         returning a list of datasets.
""" search_terms_json = self.process_natural_language(search_query) - print(search_terms_json, flush=True) try: search_terms = json.loads(search_terms_json) except ValueError as e: @@ -252,15 +276,33 @@ def natural_language_search( ) if retries >= 3: print("openAI returned non-json in multiple retries, exiting") - return [] + return { + "error": f"openAI returned non-json in multiple retries. raw text: {search_terms_json}" + } return self.natural_language_search(search_query, page, retries + 1) - query = self.generate_query_string(search_terms) - options = self.generate_temporal_coverage_query(search_terms) - - print(query, flush=True) - if query == "": - return [] - return self.run_esgf_query(query, page, options) + query = " AND ".join( + [ + ( + search_term.strip() + if isinstance(search_term, str) + else "({})".format( + " OR ".join( + filter(lambda term: term.strip() != "", search_term) + ) + ) + ) + for search_term in filter( + lambda element: element != "", search_terms.values() + ) + ] + ) + datasets = self.with_all_available_mirrors( + self.run_esgf_dataset_query, query, page, options={} + ) + return { + "query": {"raw": query, "search_terms": search_terms}, + "results": datasets, + } def build_natural_language_prompt(self, search_query: str) -> str: """ @@ -273,9 +315,14 @@ def process_natural_language(self, search_query: str) -> str: runs query against LLM and returns the result string. """ response = self.client.chat.completions.create( - model="gpt-4", + model="gpt-4-0125-preview", messages=[ - {"role": "system", "content": NATURAL_LANGUAGE_PROCESSING_CONTEXT}, + { + "role": "system", + "content": generate_natural_language_system_prompt( + self.facet_possibilities + ), + }, { "role": "user", "content": self.build_natural_language_prompt(search_query), @@ -283,12 +330,10 @@ def process_natural_language(self, search_query: str) -> str: ], temperature=0.7, ) - query = response.choices[0].message.content or "" - print(query) - query = query[query.find("{") :] - return query + keywords_json = response.choices[0].message.content + return keywords_json - def run_esgf_query( + def run_esgf_dataset_query( self, query_string: str, page: int, options: Dict[str, str] ) -> DatasetSearchResults: """ @@ -309,7 +354,8 @@ def run_esgf_query( | options ) - full_url = f"{default_settings.esgf_url}/search?{encoded_string}" + base_url = self.get_esgf_url_with_current_mirror() + full_url = f"{base_url}?{encoded_string}" r = requests.get(full_url) if r.status_code != 200: error = str(r.content) @@ -327,250 +373,3 @@ def run_esgf_query( for metadata in response["response"]["docs"] ] )[0] - - def get_embedding(self, text): - """returns an embedding for a single string.""" - return ( - self.client.embeddings.create(input=[text], model="text-embedding-ada-002") - .data[0] - .embedding - ) - - def get_embeddings(self, text): - """returns a list of embeddings for a list of strings.""" - return [ - e.embedding - for e in self.client.embeddings.create( - input=text, model="text-embedding-ada-002" - ).data - ] - - def cosine_similarity(self, a, b): - return dot(a, b) / (np.linalg.norm(a) * np.linalg.norm(b)) - - def extract_embedding_strings(self) -> Dict[str, pd.DataFrame]: - """ - builds embeddings dictionary given the desired SEARCH_FACETS. finds possible - values for given facets from ESGF node and then gets string embeddings of those - enumerated values. 
- """ - desired_facets = [item for inner in SEARCH_FACETS.values() for item in inner] - print(desired_facets) - encoded_string = urlencode( - { - "project": "CMIP6", - "facets": ",".join(desired_facets), - "limit": "0", - "format": "application/solr+json", - } - ) - facet_possibilities = ( - f"https://esgf-node.llnl.gov/esg-search/search?{encoded_string}" - ) - - print("querying fields", flush=True) - r = requests.get(facet_possibilities) - if r.status_code != 200: - raise ConnectionError( - f"Failed to get facet potential values from ESGF node: {facet_possibilities} {r.status_code}" - ) - response = r.json() - fields = response["facet_counts"]["facet_fields"] - print("aggregating fields", flush=True) - embeddings = { - f: pd.DataFrame({"string": fields[f][::2]}) for f in desired_facets - } - print("creating embeddings...", flush=True) - for k in embeddings.keys(): - print(f" embeddings for: {k}", flush=True) - # drop '' and other falsy strings - embeddings[k] = embeddings[k][embeddings[k].string.astype(bool)] - embeddings[k]["embed"] = self.get_embeddings(embeddings[k].string.to_list()) - - return embeddings - - def get_single_best_match(self, text, similar_fields): - @dask.delayed - def get_best_match_from_field(self, text, field): - embedding = ( - self.get_embedding(text) - if field != "source_id" - else self.get_embedding(text.upper()) - ) - self.embeddings[field]["similarities"] = self.embeddings[field][ - "embed" - ].apply(lambda e: self.cosine_similarity(e, embedding)) - best_match = ( - self.embeddings[field] - .sort_values("similarities", ascending=False) - .head(3) - ) - string = best_match.string.values[0] - similarity = best_match.similarities.values[0] - print(f" {string} => {similarity}") - return (string, similarity) - - results = map( - lambda x: get_best_match_from_field(self, text, x), similar_fields - ) - computed: list = list(dask.compute(results))[0] - # sort by similarity value, descending - computed.sort(key=lambda x: x[1], reverse=True) - print(computed) - return computed[0] or ("", 0.00) - - def extract_relevant_description(self, description: str) -> List[str]: - """ - takes the LLM-extracted description field and parses it into meaningful - terms to build into the formatted apache lucene query. - """ - # experiment id and variant id are best taken as exact match rather than assumed by cosine - # general idea: - # break on word boundary and take... - # exact matches to experiment id and variant_label - # anything that's over GREEDY_EXTRACTION_THRESHOLD - # otherwise... - # take non-matching inputs and conjoin them back into a phrase to take highest match across all categories - # take the most relevant between averaged individual token similarities and the whole phrase - tokens = description.replace(",", " ").split() - - # looking for exact match on an ESGF dataset full ID would be a dict of 10+M entries - # so we can leverage breaking apart the longform id into each component period-separated - # as individual tokens. much faster and cleaner. - tokens = [ - t for exploded in [token.split(".") for token in tokens] for t in exploded - ] - # after the above, date stamps aren't in the same format in the version field, - # so we strip according to the format if it perfectly matches, then use it as free-text - # rather than a field to check. 
-        # rather than a field to check. this happens below, during exact match
-
-        matched = []
-        exact_match_values = [
-            match
-            for nested_list in [
-                self.embeddings[field].string.values for field in SEARCH_FACETS["exact"]
-            ]
-            for match in nested_list
-        ]
-
-        fallback_similarities = []
-
-        print(f"finding best terms for {tokens}")
-
-        # first check all as a phrase, kick back to individual tokens if nothing fits well (greedy threshold)
-        conjoined_phrase, conjoined_similarity = self.get_single_best_match(
-            " ".join(tokens), SEARCH_FACETS["similar"]
-        )
-        if conjoined_similarity > 0.935:
-            print(
-                f"  greedily taking full phrase: {conjoined_phrase} at {conjoined_similarity}"
-            )
-            matched.append(conjoined_phrase)
-            tokens = []
-
-        # parallel inner iterator for tokens - refactor of "remove from leftover tokens,
-        # append to matched" workflow. returns (matched, fallback) to be zipped over;
-        # if matched, return (token, None), if fallback, return (None, (phrase, similarity))
-        @dask.delayed
-        def inner_iterator(t):
-            if len(t) == 9 and t[0] == "v" and t[1:].isdigit():
-                print(f"  date match: {t}")
-                return (t[1:], None)
-            if t in exact_match_values:
-                print(f"  exact match: {t}")
-                return (t, None)
-            else:
-                print(f"  approximate matching for {t}")
-                phrase, similarity = self.get_single_best_match(
-                    t, SEARCH_FACETS["similar"]
-                )
-                if similarity >= GREEDY_EXTRACTION_THRESHOLD:
-                    print(
-                        f"    matched word {t} -> {phrase} over threshold {GREEDY_EXTRACTION_THRESHOLD}: {similarity}"
-                    )
-                    return (phrase, (phrase, similarity))
-                else:
-                    print(
-                        f"    closest match {t} -> {phrase} is under threshold {GREEDY_EXTRACTION_THRESHOLD}: {similarity}"
-                    )
-                    return (None, (phrase, similarity))
-
-        # zip(*x) is inverse to zip(x) - filter nones, split the two lists that were done in parallel
-        results = list(list(dask.compute(map(inner_iterator, tokens[:])))[0])
-        if len(results) == 0:
-            return matched
-        matched, fallback_similarities = list(
-            map(lambda x: list(filter(lambda y: y is not None, x)), zip(*results))
-        )
-        # removed matched tokens. some require transformations, e.g. upper()
-        tokens = [t for t in tokens if t not in matched and t.upper() not in matched]
-
-        if len(tokens) == 0:
-            print(f"finalized search terms are {matched}")
-            return matched
-
-        print(f"  leftover tokens: {tokens}\nmatching for whole phrase")
-
-        conjoined_phrase, conjoined_similarity = self.get_single_best_match(
-            " ".join(tokens), SEARCH_FACETS["similar"]
-        )
-
-        fallback_similarities = [f for f in fallback_similarities if f[0] in tokens]
-        if len(fallback_similarities) == 0:
-            if conjoined_similarity >= 0.90:
-                matched.append(conjoined_phrase)
-            return matched
-
-        avg_sim = sum((map(lambda f: f[1], fallback_similarities))) / len(
-            fallback_similarities
-        )
-        print(f"  conjoined similarity {conjoined_similarity} - avg by parts {avg_sim}")
-        if conjoined_similarity >= avg_sim:
-            print(f"  using conjoined phrase {conjoined_phrase}")
-            matched.append(conjoined_phrase)
-        else:
-            for part in fallback_similarities:
-                matched += part[0]
-
-        print(f"finalized search terms are {matched}")
-
-        return matched
-
-    def generate_query_string(self, search_terms: Dict[str, str]) -> str:
-        """
-        handles LLM-extracted fields separately as needed and returns the formatted lucene query.
- """ - desired_terms = ["nominal_resolution", "frequency"] - best_matches = [] - - for desired in desired_terms: - if desired in search_terms: - print(f"{search_terms[desired], desired}") - phrase, sim = self.get_single_best_match( - search_terms[desired], [desired] - ) - if sim >= GREEDY_EXTRACTION_THRESHOLD: - best_matches.append(phrase) - else: - print( - f" failed to find good candidate for {desired}: '{search_terms[desired]}'" - ) - search_terms["description"] += f" {search_terms[desired]}" - - description = [] - if "description" in search_terms: - description = self.extract_relevant_description(search_terms["description"]) - query_string = " AND ".join(map(lambda t: f'"{t}"', best_matches + description)) - print(f"lucene query: {query_string}") - return query_string - - def generate_temporal_coverage_query(self, terms: Dict[str, str]) -> Dict[str, str]: - """ - creates ESGF search time bound arguments. - """ - query = {} - if "upper_time_bound" in terms: - query["end"] = terms["upper_time_bound"] - if "lower_time_bound" in terms: - query["start"] = terms["lower_time_bound"] - return query diff --git a/api/server.py b/api/server.py index 2ff4038..ae34793 100644 --- a/api/server.py +++ b/api/server.py @@ -15,8 +15,6 @@ client = OpenAI() esgf = ESGFProvider(client) -esgf.initialize_embeddings() - era5 = ERA5Provider(client) @@ -31,12 +29,12 @@ async def job_status(job_id: str, redis=Depends(get_redis)): @app.get("/search/esgf") -async def esgf_search(query: str = "", page: int = 1, refresh_cache=False): +async def esgf_search(query: str = "", page: int = 1, keywords: bool = False): try: - datasets = esgf.search(query, page, refresh_cache) + datasets = esgf.search(query, page, keywords) except Exception as e: return {"error": f"failed to fetch datasets: {e}"} - return {"results": datasets} + return datasets @app.get("/search/era5") diff --git a/api/settings.py b/api/settings.py index b97d1dd..16fd521 100644 --- a/api/settings.py +++ b/api/settings.py @@ -3,11 +3,28 @@ from pydantic_settings import BaseSettings import os +DEFAULT_ESGF_FALLBACKS = [ + "https://esgf-node.ornl.gov/esg-search", + "https://ds.nccs.nasa.gov/esg-search", + "https://dpesgf03.nccs.nasa.gov/esg-search", + "https://esg-dn1.nsc.liu.se/esg-search", + "https://esg-dn2.nsc.liu.se/esg-search", + "https://esg-dn3.nsc.liu.se/esg-search", + "https://cmip.bcc.cma.cn/esg-search", + "http://cmip.fio.org.cn/esg-search", + "http://cordexesg.dmi.dk/esg-search", + "http://data.meteo.unican.es/esg-search", + "http://esg-cccr.tropmet.res.in/esg-search", +] + class Settings(BaseSettings): esgf_url: str = Field( os.environ.get("ESGF_URL", "https://esgf-node.llnl.gov/esg-search") ) + esgf_fallbacks: str = Field( + os.environ.get("ESGF_FALLBACKS", ",".join(DEFAULT_ESGF_FALLBACKS)) + ) esgf_openid: Tuple[str, str] = Field( (os.environ.get("ESGF_OPENID_USER", ""), os.environ.get("ESGF_OPENID_PASS", "")) ) diff --git a/env.example b/env.example index 9a88888..3d319ae 100644 --- a/env.example +++ b/env.example @@ -4,6 +4,7 @@ TERARIUM_USER={TERARIUM_USER} TERARIUM_PASS={TERARIUM_PASS} ESGF_URL="https://esgf-node.llnl.gov/esg-search" 
+ESGF_FALLBACKS='https://esgf-node.ornl.gov/esg-search,https://ds.nccs.nasa.gov/esg-search,https://dpesgf03.nccs.nasa.gov/esg-search,https://esg-dn1.nsc.liu.se/esg-search,https://esg-dn2.nsc.liu.se/esg-search,https://esg-dn3.nsc.liu.se/esg-search,https://cmip.bcc.cma.cn/esg-search,http://cmip.fio.org.cn/esg-search,http://cordexesg.dmi.dk/esg-search,http://data.meteo.unican.es/esg-search,http://esg-cccr.tropmet.res.in/esg-search'
 ESGF_OPENID_USER={ESGF_OPENID_USER}
 ESGF_OPENID_PASS={ESGF_OPENID_PASS}
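To round out the review, the keyword-to-Lucene conversion introduced in `keyword_search` can be exercised standalone. A sketch of the same logic (illustrative only; it mirrors the PR's regex and join, and the helper name is hypothetical):

```python
import re

LUCENE_TOKENS = ["AND", "OR", "(", ")"]

def to_lucene(query: str) -> dict:
    """Mirror of keyword_search's conversion: pass raw Lucene queries through,
    otherwise strip punctuation and AND-join the remaining keywords."""
    if any(token in query for token in LUCENE_TOKENS):
        return {"raw": query}  # already a Lucene query; forwarded untouched
    stripped = re.sub(r"[^A-Za-z0-9 ]+", "", query)
    return {"original": query, "raw": " AND ".join(stripped.split(" "))}

print(to_lucene("historical eastward wind 100 km cesm2"))
# {'original': 'historical eastward wind 100 km cesm2',
#  'raw': 'historical AND eastward AND wind AND 100 AND km AND cesm2'}
print(to_lucene("tas AND (CESM2 OR CESM2-WACCM)"))
# {'raw': 'tas AND (CESM2 OR CESM2-WACCM)'}
```

One edge worth flagging: punctuation is deleted rather than replaced with spaces, so an input token like `7.0` becomes the single term `70` before the AND-join.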