Skip to content

Commit

Permalink
Merge pull request #13 from tadascience/check_model
Browse files Browse the repository at this point in the history
+ `check_model()`
  • Loading branch information
romainfrancois authored Mar 8, 2024
2 parents 7c03cde + 1abcd42 commit 81fa2eb
Show file tree
Hide file tree
Showing 6 changed files with 54 additions and 32 deletions.
6 changes: 3 additions & 3 deletions R/authenticate.R
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
authenticate <- function(request, .call = caller_env()){
authenticate <- function(request, error_call = caller_env()){
key <- Sys.getenv("MISTRAL_API_KEY")
if (identical(key, "")) {
cli_abort(c(
cli_abort(call = error_call, c(
"Please set the {.code MISTRAL_API_KEY} environment variable",
i = "Get an API key from {.url https://console.mistral.ai/api-keys/}",
i = "Use {.code usethis::edit_r_environ()} to set the {.code MISTRAL_API_KEY} environment variable"
), call = .call)
))
}
req_auth_bearer_token(request, key)
}
Expand Down
27 changes: 10 additions & 17 deletions R/chat.R
Original file line number Diff line number Diff line change
@@ -1,8 +1,8 @@
req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, .call = caller_env()) {

req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", stream = FALSE, error_call = caller_env()) {
check_model(model, error_call = error_call)
request(mistral_base_url) |>
req_url_path_append("v1", "chat", "completions") |>
authenticate(.call = .call) |>
authenticate(error_call = error_call) |>
req_body_json(
list(
model = model,
Expand All @@ -17,7 +17,7 @@ req_chat <- function(text = "What are the top 5 R packages ?", model = "mistral-
)
}

resp_chat <- function(response) {
resp_chat <- function(response, error_call = current_env()) {
data <- resp_body_json(response)

tib <- map_dfr(data$choices, \(choice) {
Expand All @@ -43,24 +43,17 @@ print.chat_tibble <- function(x, ...) {
#'
#' @param text some text
#' @param which model to use. See [models()] for more information about which models are available
#' @param ... ignored
#' @inheritParams httr2::req_perform
#'
#' @return Result text from Mistral
#'
#' @examples
#' chat("Top 5 R packages")
#'
#' @export
chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny") {

available_models <- models(.call = current_env())
if (!(model %in% available_models)) {
cli::cli_abort(c(
glue::glue("The model {model} is not available."),
"i" = "Please use the {.code models()} function to see the available models."
))
}

req <- req_chat(text, model)
resp <- req_perform(req)
resp_chat(resp)
chat <- function(text = "What are the top 5 R packages ?", model = "mistral-tiny", ..., error_call = current_env()) {
req <- req_chat(text, model, error_call = error_call)
resp <- req_perform(req, error_call = error_call)
resp_chat(resp, error_call = error_call)
}
21 changes: 18 additions & 3 deletions R/models.R
Original file line number Diff line number Diff line change
@@ -1,22 +1,37 @@
check_model <- function(model, error_call = caller_env()) {
available_models <- models(error_call = error_call)

if (!(model %in% available_models)) {
cli_abort(call = error_call, c(
"The model {model} is not available.",
"i" = "Please use the {.code models()} function to see the available models."
))
}

invisible(model)
}

#' Retrieve all models available in the Mistral API
#'
#' @inheritParams httr2::req_perform
#'
#' @return A character vector with the models available in the Mistral API
#'
#' @examples
#' models()
#'
#' @export
models <- function(.call = caller_env()) {
models <- function(error_call = caller_env()) {

req <- request(mistral_base_url) |>
req_url_path_append("v1", "models") |>
authenticate(.call = .call) |>
authenticate(call = call) |>
req_cache(tempdir(),
use_on_error = TRUE,
max_age = 2 * 60 * 60) # 2 hours

resp <- req_perform(req) |>
resp_body_json(simplifyVector = T)
resp_body_json(simplifyVector = TRUE)

resp |>
purrr::pluck("data","id")
Expand Down
10 changes: 3 additions & 7 deletions R/stream.R
Original file line number Diff line number Diff line change
@@ -1,12 +1,8 @@
#' @export
stream <- function(text, model = "mistral-tiny") {
if (!(model %in% models())) {
cli::cli_abort("The model ", model, " is not available.",
"i" = "Please use the {.code models()} function to see the available models."
)
}
stream <- function(text, model = "mistral-tiny", ..., error_call = current_env()) {
check_model(error_call = error_call)

req <- req_chat(text, model, stream = TRUE)
req <- req_chat(text, model, stream = TRUE, error_call = error_call)
resp <- req_perform_stream(req,
callback = stream_callback,
round = "line",
Expand Down
14 changes: 13 additions & 1 deletion man/chat.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

8 changes: 7 additions & 1 deletion man/models.Rd

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

0 comments on commit 81fa2eb

Please sign in to comment.