From f7b8ebc7c86da43e59b5ba59734adbf1eb3ac25a Mon Sep 17 00:00:00 2001 From: David Cristofaro Date: Fri, 28 Apr 2023 22:09:12 +1000 Subject: [PATCH] Break into modules etc. --- Cargo.lock | 27 ------ Cargo.toml | 4 +- Dockerfile | 4 +- Justfile | 19 +++- src/env.rs | 26 ++++++ src/error.rs | 11 ++- src/main.rs | 217 +++----------------------------------------- src/signal.rs | 23 +++++ src/web/generate.rs | 109 ++++++++++++++++++++++ src/web/health.rs | 38 ++++++++ src/web/mod.rs | 35 +++++++ 11 files changed, 275 insertions(+), 238 deletions(-) create mode 100644 src/env.rs create mode 100644 src/signal.rs create mode 100644 src/web/generate.rs create mode 100644 src/web/health.rs create mode 100644 src/web/mod.rs diff --git a/Cargo.lock b/Cargo.lock index 0d195bf..fb44216 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -23,12 +23,6 @@ dependencies = [ "alloc-no-stdlib", ] -[[package]] -name = "anyhow" -version = "1.0.70" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "7de8ce5e0f9f8d88245311066a578d72b7af3e7088f32783804676302df237e4" - [[package]] name = "async-compression" version = "0.3.15" @@ -418,7 +412,6 @@ dependencies = [ name = "gpt-html" version = "0.1.0" dependencies = [ - "anyhow", "axum", "axum-extra", "eventsource-stream", @@ -426,7 +419,6 @@ dependencies = [ "reqwest", "serde", "serde_json", - "strum_macros", "tokio", "tower-http", ] @@ -456,12 +448,6 @@ version = "0.12.3" source = "registry+https://github.com/rust-lang/crates.io-index" checksum = "8a9ee70c43aaf417c914396645a0fa852624801b24ebb7ae78fe8272889ac888" -[[package]] -name = "heck" -version = "0.4.1" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "95505c38b4572b2d910cecb0281560f54b440a19336cbbcb27bf6ce6adc6f5a8" - [[package]] name = "hermit-abi" version = "0.2.6" @@ -1106,19 +1092,6 @@ dependencies = [ "winapi", ] -[[package]] -name = "strum_macros" -version = "0.24.3" -source = "registry+https://github.com/rust-lang/crates.io-index" -checksum = "1e385be0d24f186b4ce2f9982191e7101bb737312ad61c1f2f984f34bcf85d59" -dependencies = [ - "heck", - "proc-macro2", - "quote", - "rustversion", - "syn 1.0.109", -] - [[package]] name = "syn" version = "1.0.109" diff --git a/Cargo.toml b/Cargo.toml index bc885b4..54249af 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -13,10 +13,10 @@ futures = "0.3" reqwest = { version = "0.11", features = ["json", "stream"] } serde = { version = "1", features = ["derive"] } serde_json = "1" -strum_macros = "0.24" +# strum_macros = "0.24" tokio = { version = "1", features = ["full"] } # tower = { version = "0.4", features = ["full"] } tower-http = { version = "0.4", features = ["full"] } [dev-dependencies] -anyhow = "1" +# anyhow = "1" diff --git a/Dockerfile b/Dockerfile index 871c21c..a3761be 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,9 @@ COPY . . RUN cargo build --release --bin gpt-html FROM alpine AS runtime -ENV COMMIT_SHA=$COMMIT_SHA +ARG COMMIT_SHA +ENV COMMIT_SHA="$COMMIT_SHA" +ENV DOCKER="true" WORKDIR /app COPY --from=build /app/target/release/gpt-html . EXPOSE 9292 diff --git a/Justfile b/Justfile index f2e9632..e204085 100644 --- a/Justfile +++ b/Justfile @@ -1,2 +1,19 @@ +COMMIT_SHA := `git rev-parse HEAD` + +dev: + cargo watch -w src/ -x run + +docker-build: + docker build --build-arg COMMIT_SHA="{{COMMIT_SHA}}" . + +docker-build-quiet: + docker build --build-arg COMMIT_SHA="{{COMMIT_SHA}}" --quiet . + +docker-run: + docker run --env OPENAI_API_KEY --publish 8080:8080 "$(just docker-build-quiet)" + deploy: - fly deploy --env COMMIT_SHA="$(git rev-parse HEAD)" + fly deploy --build-arg + +logs: + fly logs diff --git a/src/env.rs b/src/env.rs new file mode 100644 index 0000000..71e7b04 --- /dev/null +++ b/src/env.rs @@ -0,0 +1,26 @@ +use std::env; + +use tokio::process::Command; + +pub fn print() { + Command::new("env") + .spawn() + .expect("env command failed to start"); +} + +pub fn commit_sha() -> String { + env::var("COMMIT_SHA").unwrap_or_else(|_| "unknown".to_string()) +} + +pub fn docker() -> bool { + let var = env::var("DOCKER"); + var.is_ok() && var.unwrap() == "true" +} + +pub fn http_basic_auth_password() -> Option { + env::var("HTTP_BASIC_AUTH_PASSWORD").ok() +} + +pub fn openai_api_key() -> Option { + env::var("OPENAI_API_KEY").ok() +} diff --git a/src/error.rs b/src/error.rs index fa7de51..2526d3d 100644 --- a/src/error.rs +++ b/src/error.rs @@ -1,13 +1,13 @@ +use std::fmt::Display; + use axum::{ http::StatusCode, response::{IntoResponse, Response}, }; -use std::fmt::Display; -use strum_macros::AsRefStr; pub type Result = core::result::Result; -#[derive(Debug, AsRefStr)] +#[derive(Debug)] pub enum Error { SystemTimeError, EnvironmentError, @@ -20,7 +20,8 @@ pub enum Error { impl IntoResponse for Error { fn into_response(self) -> Response { - println!("->> {:<12} - {self:?}", "INTO_RES"); + println!("\n----------"); + println!("Error: {self:?}"); // Create a placeholder Axum reponse. let mut response = StatusCode::INTERNAL_SERVER_ERROR.into_response(); @@ -34,7 +35,7 @@ impl IntoResponse for Error { impl Display for Error { fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result { - write!(f, "{}", self.as_ref()) + write!(f, "{:?}", self) } } diff --git a/src/main.rs b/src/main.rs index 16595b4..159be39 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,214 +1,27 @@ -pub use self::error::{Error, Result}; +use std::net::SocketAddr; + +use axum::Server; -use axum::{ - body::StreamBody, - http::{header, Uri}, - response::IntoResponse, - routing::{get, get_service}, - Json, Router, Server, -}; -use axum_extra::middleware::option_layer; -use eventsource_stream::Eventsource; -use futures::future; -use futures::{StreamExt, TryStreamExt}; -use serde::{Deserialize, Serialize}; -use std::{ - env, - io::{self, Write}, - net::SocketAddr, - time::SystemTime, -}; -use tokio::signal; -use tower_http::{services::ServeDir, validate_request::ValidateRequestHeaderLayer}; +pub use self::error::{Error, Result}; +mod env; mod error; +mod signal; +mod web; #[tokio::main] async fn main() { println!("Starting server..."); - let addr = SocketAddr::from(([0, 0, 0, 0], 8080)); - Server::bind(&addr) - .serve(app().into_make_service()) - .with_graceful_shutdown(shutdown_signal()) - .await - .expect("server should serve"); -} - -fn app() -> Router { - let auth = option_layer( - env::var("HTTP_BASIC_AUTH_PASSWORD") - .ok() - .map(|password| ValidateRequestHeaderLayer::basic("user", &password)), - ); - - Router::new().route("/health", get(health)).nest_service( - "/", - // Handle GET. - get_service( - // Serve static files from "public". - ServeDir::new("public") - // When static file missing use handler. - .fallback(get(handler)), - ) - // Other methods use handler. - .post(handler) - .patch(handler) - .put(handler) - .delete(handler) - // Apply HTTP basic auth. - .layer(auth), - ) -} - -async fn shutdown_signal() { - let ctrl_c = async { - signal::ctrl_c() - .await - .expect("failed to install Ctrl+C handler"); - }; - - let terminate = async { - signal::unix::signal(signal::unix::SignalKind::terminate()) - .expect("failed to install signal handler") - .recv() - .await; - }; - - tokio::select! { - _ = ctrl_c => {}, - _ = terminate => {}, - } - - println!("signal received, starting graceful shutdown"); -} - -#[derive(Debug, Serialize)] -struct HealthBody { - time: u64, - commit_sha: String, - basic_auth_enabled: bool, -} - -#[derive(Debug, Serialize)] -struct ChatCompletionsBody { - model: String, - stream: bool, - messages: Vec, -} - -#[derive(Debug, Serialize)] -struct Message { - role: String, - content: String, -} - -#[derive(Debug, Deserialize)] -struct Event { - choices: Vec, -} - -#[derive(Debug, Deserialize)] -struct Choice { - delta: Delta, -} - -#[derive(Debug, Deserialize)] -struct Delta { - content: String, -} - -async fn health() -> Result { - println!("\n----------"); - println!("Health"); - println!("----------"); - - env::var("OPENAI_API_KEY").map_err(|_| Error::EnvironmentError)?; - - let time = SystemTime::now() - .duration_since(SystemTime::UNIX_EPOCH) - .map_err(|_| Error::SystemTimeError)? - .as_secs(); - let commit_sha = env::var("COMMIT_SHA").unwrap_or_else(|_| "unknown".to_string()); - let basic_auth_enabled = env::var("HTTP_BASIC_AUTH_PASSWORD").is_ok(); - - Ok(Json(HealthBody { - time, - basic_auth_enabled, - commit_sha, - })) -} - -async fn handler(uri: Uri) -> Result { - println!("\n----------"); - println!("Fetching: {uri}"); - println!("----------"); - - let prompt = r#" -Output a valid HTML document for the webpage that could be located at the URL path provided by the user. Include general navigation anchor tags as well as relative anchor tags to other related pages. Include a minimal amount of inline styles to improve the look of the page. Make the text content quite long with a decent amount of interesting content. Do not use any dummy text on the page. - -Start the reponse with the following exact characters: - - -"#; - - let body = ChatCompletionsBody { - model: "gpt-3.5-turbo".to_string(), - stream: true, - messages: vec![ - Message { - role: "system".to_string(), - content: prompt.to_string(), - }, - Message { - role: "user".to_string(), - content: uri.to_string(), - }, - ], + let addr = if env::docker() { + SocketAddr::from(([0, 0, 0, 0], 8080)) + } else { + SocketAddr::from(([127, 0, 0, 1], 8080)) }; - let stream = reqwest::Client::new() - .post("https://api.openai.com/v1/chat/completions") - .header("content-type", "application/json") - .header( - "authorization", - &format!( - "Bearer {}", - env::var("OPENAI_API_KEY").map_err(|_| Error::EnvironmentError)? - ), - ) - .body(serde_json::to_string(&body).map_err(|_| Error::SerializationError)?) - .send() + Server::bind(&addr) + .serve(web::app().into_make_service()) + .with_graceful_shutdown(signal::shutdown()) .await - .map_err(|_| Error::RequestError)? - .bytes_stream() - .eventsource() - .map(|r| match r { - Ok(e) => { - serde_json::from_str::(&e.data).map_err(|_| Error::DeserializationError) - } - _ => Err(Error::StreamError), - }) - // Discard errors (will most likely be `Error::JsonError`). - .filter(|r| future::ready(r.is_ok())) - .map_ok(|event| { - let content = event - .choices - .into_iter() - .next() - .expect("event should have at least one choice") - .delta - .content; - - // Debug log. - print!("{}", content); - let _ = io::stdout().flush(); - - content - }); - - Ok(( - [(header::CONTENT_TYPE, "text/html")], - StreamBody::new(stream), - )) + .expect("server should serve"); } diff --git a/src/signal.rs b/src/signal.rs new file mode 100644 index 0000000..72a6466 --- /dev/null +++ b/src/signal.rs @@ -0,0 +1,23 @@ +use tokio::signal; + +pub async fn shutdown() { + let ctrl_c = async { + signal::ctrl_c() + .await + .expect("failed to install Ctrl+C handler"); + }; + + let terminate = async { + signal::unix::signal(signal::unix::SignalKind::terminate()) + .expect("failed to install signal handler") + .recv() + .await; + }; + + tokio::select! { + _ = ctrl_c => {}, + _ = terminate => {}, + } + + println!("Signal received, starting graceful shutdown..."); +} diff --git a/src/web/generate.rs b/src/web/generate.rs new file mode 100644 index 0000000..cb205d7 --- /dev/null +++ b/src/web/generate.rs @@ -0,0 +1,109 @@ +use std::{env, future}; + +use axum::{body::StreamBody, http::Uri, response::IntoResponse}; +use eventsource_stream::Eventsource; +use futures::{StreamExt, TryStreamExt}; +use reqwest::{header, Client}; +use serde::{Deserialize, Serialize}; + +use crate::error::{Error, Result}; + +#[derive(Debug, Serialize)] +struct ChatCompletionsBody { + model: String, + stream: bool, + messages: Vec, +} + +#[derive(Debug, Serialize)] +struct Message { + role: String, + content: String, +} + +#[derive(Debug, Deserialize)] +struct Event { + choices: Vec, +} + +#[derive(Debug, Deserialize)] +struct Choice { + delta: Delta, +} + +#[derive(Debug, Deserialize)] +struct Delta { + content: String, +} + +pub async fn handler(uri: Uri) -> Result { + println!("\n----------"); + println!("Fetching: {uri}"); + + let prompt = r#" +Output a valid HTML document for the webpage that could be located at the URL path provided by the user. Include general navigation anchor tags as well as relative anchor tags to other related pages. Include a minimal amount of inline styles to improve the look of the page. Make the text content quite long with a decent amount of interesting content. Do not use any dummy text on the page. + +Start the reponse with the following exact characters: + + +"#; + + let body = ChatCompletionsBody { + model: "gpt-3.5-turbo".to_string(), + stream: true, + messages: vec![ + Message { + role: "system".to_string(), + content: prompt.to_string(), + }, + Message { + role: "user".to_string(), + content: uri.to_string(), + }, + ], + }; + + let stream = Client::new() + .post("https://api.openai.com/v1/chat/completions") + .header("content-type", "application/json") + .header( + "authorization", + &format!( + "Bearer {}", + env::var("OPENAI_API_KEY").map_err(|_| Error::EnvironmentError)? + ), + ) + .body(serde_json::to_string(&body).map_err(|_| Error::SerializationError)?) + .send() + .await + .map_err(|_| Error::RequestError)? + .bytes_stream() + .eventsource() + .map(|r| match r { + Ok(e) => { + serde_json::from_str::(&e.data).map_err(|_| Error::DeserializationError) + } + _ => Err(Error::StreamError), + }) + // Discard errors (will most likely be `Error::JsonError`). + .filter(|r| future::ready(r.is_ok())) + .map_ok(|event| { + let content = event + .choices + .into_iter() + .next() + .expect("event should have at least one choice") + .delta + .content; + + // Debug log. + print!("{}", content); + + content + }); + + Ok(( + [(header::CONTENT_TYPE, "text/html")], + StreamBody::new(stream), + )) +} diff --git a/src/web/health.rs b/src/web/health.rs new file mode 100644 index 0000000..4a7bf80 --- /dev/null +++ b/src/web/health.rs @@ -0,0 +1,38 @@ +use std::time::SystemTime; + +use axum::{response::IntoResponse, Json}; +use serde::Serialize; + +use crate::{ + env, + error::{Error, Result}, +}; + +#[derive(Debug, Serialize)] +struct HealthBody { + time: u64, + commit_sha: String, + basic_auth_enabled: bool, +} + +pub async fn handler() -> Result { + println!("\n----------"); + println!("Health"); + + env::print(); + + env::openai_api_key().ok_or(Error::EnvironmentError)?; + + let time = SystemTime::now() + .duration_since(SystemTime::UNIX_EPOCH) + .map_err(|_| Error::SystemTimeError)? + .as_secs(); + let commit_sha = env::commit_sha(); + let basic_auth_enabled = env::http_basic_auth_password().is_some(); + + Ok(Json(HealthBody { + time, + basic_auth_enabled, + commit_sha, + })) +} diff --git a/src/web/mod.rs b/src/web/mod.rs new file mode 100644 index 0000000..d92d33a --- /dev/null +++ b/src/web/mod.rs @@ -0,0 +1,35 @@ +use axum::{ + routing::{get, get_service}, + Router, +}; +use axum_extra::middleware::option_layer; +use tower_http::{services::ServeDir, validate_request::ValidateRequestHeaderLayer}; + +use crate::env; + +pub mod generate; +pub mod health; + +pub fn app() -> Router { + Router::new() + .route("/health", get(health::handler)) + .nest_service( + "/", + // Handle GET. + get_service( + // Serve static files from "public". + ServeDir::new("public") + // When static file missing use handler. + .fallback(get(generate::handler)), + ) + // Other methods use handler. + .post(generate::handler) + .patch(generate::handler) + .put(generate::handler) + .delete(generate::handler) + // Optionally apply HTTP basic auth. + .layer(option_layer(env::http_basic_auth_password().map( + |password| ValidateRequestHeaderLayer::basic("user", &password), + ))), + ) +}