Skip to content

Commit

Permalink
VertexAI support - currently supports chat-bison, hopefully gemini up…
Browse files Browse the repository at this point in the history
…on release (#34)
  • Loading branch information
jondurbin authored Dec 6, 2023
1 parent fc15c5d commit 29aa8c7
Show file tree
Hide file tree
Showing 3 changed files with 234 additions and 13 deletions.
238 changes: 226 additions & 12 deletions airoboros/self_instruct.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,10 @@
import sys
import yaml
from collections import defaultdict
from google.auth.transport import requests as google_requests # type: ignore
from google.oauth2 import service_account # type: ignore
from loguru import logger
from time import sleep
from time import sleep, time
from tqdm import tqdm
from typing import List, Dict, Any
from uuid import uuid4
Expand All @@ -40,6 +42,25 @@
OPENAI_API_BASE_URL = "https://api.openai.com"
READABILITY_HINT = "The output should be written in such a way as to have a Flesch-Kincaid readability score of 30 or lower - best understood by those with college education. Only output the story - don't add any notes or information about Flesch-Kincaid scores."

# List of OpenAI models we support (there are others, but skipping for now...)
OPENAI_MODELS = [
"gpt-3.5-turbo-0613",
"gpt-3.5-turbo-0301",
"gpt-4-0314",
"gpt-4-32k-0314",
"gpt-4-1106-preview",
"gpt-4",
"gpt-4-32k",
"gpt-4-0613",
"gpt-4-32k-0613",
"gpt-3.5-turbo-1106",
"gpt-3.5-turbo",
"gpt-3.5-turbo-16k",
]

# Base URL for vertexai.
VERTEXAI_BASE_URL = "https://{region}-aiplatform.googleapis.com/v1/projects/{project_id}/locations/{region}/publishers/{publisher}/models/{model}:predict"


class SelfInstructor:
"""Class and methods used to generate instructions, based on self-instruct paper/code."""
Expand Down Expand Up @@ -75,10 +96,22 @@ def load_config(self):
self.openai_api_key = raw_config.get("openai_api_key") or os.environ.get(
"OPENAI_API_KEY"
)
if not self.openai_api_key:
raise ValueError(
"OPENAI_API_KEY environment variable or openai_api_key must be provided"
if raw_config.get("vertexai_credentials_path"):
self._vertexai_token = None
self._vertexai_token_date = None
self._vertexai_credentials_path = raw_config[
"vertexai_credentials_path"
]
self._vertexai_region = raw_config.get("vertexai_region", "us-central1")
self._vertexai_project_id = raw_config["vertexai_project_id"]
self._vertexai_publisher = raw_config.get(
"vertexai_publisher", "google"
)
if not self.openai_api_key:
if not raw_config.get("vertexai_credentials_path"):
raise ValueError(
"OpenAI API key or vertexai_credentials_path must be provided!"
)
self.organization_id = raw_config.get("organization_id")
self.topics_path = raw_config.get("topics_path") or "topics.txt"
self.output_path = raw_config.get("output_path") or "instructions.jsonl"
Expand Down Expand Up @@ -180,7 +213,31 @@ def initialize_index(self):
)
)

def validate_model(self, model):
def validate_vertexai_model(self, model):
"""Ensure the specified model is available in vertexai."""
if "chat" not in model:
raise ValueError(
"Currently, only the chat models are supported for vertexai, sorry"
)
test_payload = {
"instances": [{"messages": [{"author": "user", "content": "hello"}]}],
"parameters": {"temperature": 0.1, "maxOutputTokens": 1},
}
try:
headers = {"Authorization": f"Bearer {self.get_vertexai_token()}"}
url = VERTEXAI_BASE_URL.format(
region=self._vertexai_region,
project_id=self._vertexai_project_id,
publisher=self._vertexai_publisher,
model=model,
)
result = requests.post(url, json=test_payload, headers=headers)
assert result.status_code == 200
logger.success(f"Successfully validated model: {model}")
except Exception:
raise ValueError(f"Error trying to validate vertexai model: {model}")

def validate_openai_model(self, model):
"""Ensure the specified model is available."""
headers = {"Authorization": f"Bearer {self.openai_api_key}"}
if self.organization_id:
Expand All @@ -195,6 +252,12 @@ def validate_model(self, model):
raise ValueError(f"Model is not available to your API key: {model}")
logger.success(f"Successfully validated model: {model}")

def validate_model(self, model):
"""Validate a model (openai or vertexai)."""
if model in OPENAI_MODELS:
return self.validate_openai_model(model)
return self.validate_vertexai_model(model)

async def initialize_topics(self) -> List[str]:
"""Ensure topics are initialized, i.e. topics already exist and are read,
or a new list of topics is generated.
Expand Down Expand Up @@ -275,6 +338,75 @@ def load_template(path: str) -> str:
with open(path) as infile:
return infile.read()

def get_vertexai_token(self):
if self._vertexai_token and self._vertexai_token_date > time() - 300:
return self._vertexai_token
scopes = [
"https://www.googleapis.com/auth/cloud-platform",
"https://www.googleapis.com/auth/cloud-platform.read-only",
]
path = self._vertexai_credentials_path
credentials = service_account.Credentials.from_service_account_file(
path, scopes=scopes
)
credentials.refresh(google_requests.Request())
self._vertexai_token = credentials.token
self._vertexai_token_date = time()
return credentials.token

@backoff.on_exception(
backoff.fibo,
(
requests.exceptions.ConnectionError,
requests.exceptions.Timeout,
RateLimitError,
),
max_value=19,
)
async def _post_vertexai(self, model: str, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Perform a post request to VertexAI (e.g., Bison/Gemini).
:param model: Model to use, e.g. "bison-text-32k"
:type model: str
:param payload: Dict containing request body/payload.
:type payload: Dict[str, Any]
:return: Response object.
:rtype: Dict[str, Any]
"""
headers = {"Authorization": f"Bearer {self.get_vertexai_token()}"}
request_id = str(uuid4())
logger.debug(f"POST [{request_id}] with payload {json.dumps(payload)}")
url = VERTEXAI_BASE_URL.format(
region=self._vertexai_region,
project_id=self._vertexai_project_id,
publisher=self._vertexai_publisher,
model=model,
)
async with aiohttp.ClientSession() as session:
async with session.post(url, json=payload, headers=headers) as result:
if result.status != 200:
logger.error(
f"Error querying Vertex AI: {result.status}: {await result.text()}"
)
code = None
try:
body = await result.json()
code = body["error"].get("code")
except Exception:
...
if code == 429:
await asyncio.sleep(3)
raise RateLimitError(await result.text())
raise Exception(
f"Error querying Vertex AI: [{code}]: {await result.text()}"
)
data = await result.json()
if data["predictions"][0].get("safetyAttributes", [{}])[0].get("blocked"):
raise Exception("Response blocked by vertex.")
return data

@backoff.on_exception(
backoff.fibo,
(
Expand All @@ -287,7 +419,7 @@ def load_template(path: str) -> str:
),
max_value=19,
)
async def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
async def _post_openai(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
"""Perform a post request to OpenAI API.
:param path: URL path to send request to.
Expand Down Expand Up @@ -342,16 +474,91 @@ async def _post(self, path: str, payload: Dict[str, Any]) -> Dict[str, Any]:
logger.debug(f"token usage: {self.used_tokens}")
return result

async def _post_no_exc(self, *a, **k):
"""Post, ignoring all exceptions."""
async def _post_no_exc_openai(self, *a, **k):
"""Post to OpenAI, ignoring all exceptions."""
try:
return await self._post(*a, **k)
return await self._post_openai(*a, **k)
except Exception as ex:
logger.error(f"Error performing post: {ex}")
return None

async def generate_response(self, instruction: str, **kwargs) -> str:
"""Call OpenAI with the specified instruction and return the text response.
async def _post_no_exc_vertexai(self, *a, **k):
"""Post to VertexAI, ignoring all exceptions."""
try:
return await self._post_vertexai(*a, **k)
except Exception as ex:
logger.error(f"Error performing post: {ex}")
return None

async def generate_response_vertexai(self, instruction: str, **kwargs) -> str:
"""Call the model endpoint with the specified instruction and return the text response.
:param instruction: The instruction to respond to.
:type instruction: str
:return: Response text.
:rtype: str
"""
messages = copy.deepcopy(kwargs.pop("messages", None) or [])
filter_response = kwargs.pop("filter_response", True)
model = kwargs.get("model", self.model)

# Make sure our parameters conform to VertexAI specs.
payload = {**kwargs}
params = {
"maxOutputTokens": payload.pop("max_tokens", payload.pop("maxDecodeSteps", None)) or 2048
}
if "temperature" in payload:
params["temperature"] = payload.pop("temperature")
if "top_p" in payload:
params["topP"] = payload.pop("top_p")
if "top_k" in payload:
params["topK"] = payload.pop("top_k")
if "presence_penalty" in payload:
params["presencePenalty"] = payload.pop("presence_penalty")
if "frequency_penalty" in payload:
params["frequencyPenalty"] = payload.pop("frequency_penalty")
payload.pop("model", None)
payload["parameters"] = params
payload["instances"] = [{"messages": []}]
if messages and messages[0]["role"] == "system":
payload["instances"][0]["context"] = messages[0]["content"]
for message in messages:
if message["role"] == "system":
payload["instances"][0]["context"] = message["content"]
else:
payload["instances"][0]["messages"].append(
{
"author": message["role"],
"content": message["content"],
}
)
if instruction:
payload["instances"][0]["messages"].append(
{"author": "user", "content": instruction}
)

response = await self._post_no_exc_vertexai(model, payload)
if (
not response
or not response.get("predictions")
or not response["predictions"][0].get("candidates")
or not response["predictions"][0]["candidates"][0]["content"].strip()
):
return None
text = response["predictions"][0]["candidates"][0]["content"]
if filter_response:
for banned in self.response_filters:
if banned.search(text, re.I):
logger.warning(f"Banned response [{banned}]: {text}")
return None
if text.startswith(("I'm sorry,", "Apologies,", "I can't", "I won't")):
logger.warning(f"Banned response [apology]: {text}")
return None
return text.strip()

async def generate_response_openai(self, instruction: str, **kwargs) -> str:
"""Call the model endpoint with the specified instruction and return the text response.
:param instruction: The instruction to respond to.
:type instruction: str
Expand All @@ -369,7 +576,7 @@ async def generate_response(self, instruction: str, **kwargs) -> str:
payload["messages"] = messages
if instruction:
payload["messages"].append({"role": "user", "content": instruction})
response = await self._post_no_exc(path, payload)
response = await self._post_no_exc_openai(path, payload)
if (
not response
or not response.get("choices")
Expand All @@ -388,6 +595,13 @@ async def generate_response(self, instruction: str, **kwargs) -> str:
return None
return text

async def generate_response(self, instruction: str, **kwargs) -> str:
"""Generate a response - wrapper around the openai/vertexai methods above."""
model = kwargs.pop("model", None) or self.model
if model in OPENAI_MODELS:
return await self.generate_response_openai(instruction, **kwargs)
return await self.generate_response_vertexai(instruction, **kwargs)

async def is_decent_response(self, item):
"""Filter the responses by having the LLM score based on a set of rules."""
config = self.raw_config.get("scoring", {})
Expand Down
6 changes: 6 additions & 0 deletions example-config.yaml
Original file line number Diff line number Diff line change
@@ -1,6 +1,12 @@
# The model to use in generation. Available models: https://platform.openai.com/docs/models/continuous-model-upgrades
model: "gpt-4"

# To use VertexAI (e.g. chat-bison or soon hopefully gemini):
# model: "chat-bison"
# vertexai_credentials_path: creds.json
# vertexai_project_id: replace-project-id
# vertexai_publisher: google

# OpenAI API key (if null, uses environment variable OPENAI_API_KEY)
openai_api_key:

Expand Down
3 changes: 2 additions & 1 deletion setup.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,7 @@

setup(
name="airoboros",
version="2.2.0",
version="2.2.1",
description="Updated and improved implementation of the self-instruct system.",
long_description=long_description,
long_description_content_type="text/markdown",
Expand Down Expand Up @@ -35,6 +35,7 @@
"uvicorn>=0.23.0",
"flash_attn==2.1.0",
"optimum==1.12.0",
"google-auth==2.25.1",
],
extras_require={
"dev": [
Expand Down

0 comments on commit 29aa8c7

Please sign in to comment.