Skip to content

Commit

Permalink
Protect metadata with Oblivious HTTP
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould committed Sep 24, 2023
1 parent 515e8e1 commit d742829
Show file tree
Hide file tree
Showing 9 changed files with 687 additions and 62 deletions.
457 changes: 421 additions & 36 deletions Cargo.lock

Large diffs are not rendered by default.

2 changes: 2 additions & 0 deletions payjoin-cli/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -18,6 +18,7 @@ v2 = ["payjoin/v2", "tokio/full", "tokio-tungstenite", "futures-util/sink", "fut
[dependencies]
anyhow = "1.0.70"
base64 = "0.13.0"
bhttp = { version = "0.4.0", features = ["http", "bhttp"] }
bip21 = "0.3.1"
bitcoincore-rpc = "0.17.0"
clap = "4.1.4"
Expand All @@ -26,6 +27,7 @@ env_logger = "0.9.0"
futures = "0.3.28"
futures-util = { version = "0.3.28", default-features = false }
log = "0.4.7"
ohttp = "0.4.0"
payjoin = { path = "../payjoin", features = ["send", "receive"] }
reqwest = { version = "0.11.4", features = ["blocking"] }
rcgen = { version = "0.11.1", optional = true }
Expand Down
122 changes: 107 additions & 15 deletions payjoin-cli/src/app.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,11 @@ use bitcoincore_rpc::jsonrpc::serde_json;
use bitcoincore_rpc::RpcApi;
use clap::ArgMatches;
use config::{Config, File, FileFormat};
use ohttp::ClientResponse;
use payjoin::bitcoin::psbt::Psbt;
use payjoin::receive::{Error, PayjoinProposal, UncheckedProposal};
use payjoin::{bitcoin, PjUriExt, UriExt};
use reqwest::Request;
#[cfg(not(feature = "v2"))]
use rouille::{Request, Response};
use serde::{Deserialize, Serialize};
Expand Down Expand Up @@ -50,30 +52,44 @@ impl App {
.danger_accept_invalid_certs(self.config.danger_accept_invalid_certs)
.build()
.with_context(|| "Failed to build reqwest http client")?;
let _ = client
dbg!(&req.body);
let pre_req = client
.post(req.url.as_str())
.body(req.body)
.header("Content-Type", "text/plain")
.send()
.await
.with_context(|| "HTTP request failed")?;
.build()
.with_context(|| "Failed to build HTTP request")?;
let (ohttp_req, _) = self.ohttp_encapsulate_req(pre_req);
dbg!(&ohttp_req);
let _ = client.post(&self.config.ohttp_proxy).body(ohttp_req).send().await?;

log::debug!("Awaiting response");
let res = Self::long_poll(&client, req.url.as_str()).await?;
let res = self.long_poll(&client, req.url.as_str()).await?;
let mut res = std::io::Cursor::new(&res);
self.process_pj_response(ctx, &mut res)?;
Ok(())
}

#[cfg(feature = "v2")]
async fn long_poll(client: &reqwest::Client, url: &str) -> Result<Vec<u8>, reqwest::Error> {
async fn long_poll(
&self,
client: &reqwest::Client,
url: &str,
) -> Result<Vec<u8>, reqwest::Error> {
loop {
let response = client.get(url).send().await?;

if response.status().is_success() {
let body = response.bytes().await?;
let req = client.get(url).build()?;
let (ohttp_req, ctx) = self.ohttp_encapsulate_req(req);

let ohttp_response =
client.post(&self.config.ohttp_proxy).body(ohttp_req).send().await?;
log::debug!("Response: {:?}", ohttp_response);
if ohttp_response.status().is_success() {
let body = ohttp_response.bytes().await?;
if !body.is_empty() {
return Ok(body.to_vec());
let bhttp_response = ctx.decapsulate(&body).unwrap();
let mut r = std::io::Cursor::new(bhttp_response);
let response = bhttp::Message::read_bhttp(&mut r).unwrap();
return Ok(response.content().to_vec());
} else {
log::info!("No response yet for payjoin request, retrying in 5 seconds");
}
Expand All @@ -83,6 +99,33 @@ impl App {
}
}

fn ohttp_encapsulate_req(&self, req: Request) -> (Vec<u8>, ClientResponse) {
let ohttp_config = payjoin::bitcoin::base64::decode_config(
&self.config.ohttp_config,
payjoin::bitcoin::base64::URL_SAFE,
)
.unwrap();
let ctx = ohttp::ClientRequest::from_encoded_config(&ohttp_config).unwrap();

let mut bhttp_message = bhttp::Message::request(
req.method().as_str().as_bytes().to_vec(),
req.url().scheme().as_bytes().to_vec(),
req.url().authority().as_bytes().to_vec(),
req.url().path().as_bytes().to_vec(),
);
match req.body() {
Some(body) => {
bhttp_message.write_content(body.as_bytes().unwrap());
}
None => (),
}
// let req = serialize_request_to_bytes(req);
// let http_message = bhttp::Message::read_http(&mut std::io::Cursor::new(&req)).unwrap();
let mut bhttp_req = Vec::new();
let _ = bhttp_message.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_req);
ctx.encapsulate(&bhttp_req).unwrap()
}

#[cfg(not(feature = "v2"))]
pub fn send_payjoin(&self, bip21: &str) -> Result<()> {
let (req, ctx) = self.create_pj_request(bip21)?;
Expand All @@ -102,6 +145,23 @@ impl App {
Ok(())
}

// fn create_v2_pj_request(&self, bip21: &str,
// ) -> Result<(payjoin::send::Request, payjoin::send::Context)> {
// let (req, ctx) = self.create_pj_request(bip21)?;
// let config = base64::decode(&self.config.ohttp_config)?;
// let req_ctx = ohttp::ClientRequest::from_encoded_config(&config)
// .with_context(|| "Failed to decode ohttp config")?;
// let (enc_req, req_ctx) = req_ctx.encapsulate(&req.body).with_context(|| "Failed to encapsulate request")?;

// Ok((payjoin::send::Request {
// url: req.url,
// body: enc_req,
// }, payjoin::send::Context {
// ohttp_ctx: req_ctx,
// ..ctx
// }))
// }

fn create_pj_request(
&self,
bip21: &str,
Expand Down Expand Up @@ -221,7 +281,7 @@ impl App {
.with_context(|| "Failed to build reqwest http client")?;
log::debug!("Awaiting request");
let receive_endpoint = format!("{}/{}", self.config.pj_endpoint, context.receive_subdir());
let mut buffer = Self::long_poll(&client, &receive_endpoint).await?;
let mut buffer = self.long_poll(&client, &receive_endpoint).await?;

log::debug!("Received request");
let (proposal, e) = context
Expand All @@ -232,9 +292,15 @@ impl App {
.map_err(|e| anyhow!("Failed to process UncheckedProposal {}", e))?;
let mut payjoin_bytes = payjoin_psbt.serialize();
let payjoin = payjoin::v2::encrypt_message_b(&mut payjoin_bytes, e);
let _ = client
let req = client
.post(receive_endpoint)
.body(payjoin)
.build()
.with_context(|| "Failed to build HTTP request")?;
let (req, _) = self.ohttp_encapsulate_req(req);
let _ = client
.post(&self.config.ohttp_proxy)
.body(req)
.send()
.await
.with_context(|| "HTTP request failed")?;
Expand Down Expand Up @@ -274,14 +340,15 @@ impl App {
let amount = Amount::from_sat(amount_arg.parse()?);
//let subdir = self.config.pj_endpoint + pubkey.map_or(&String::from(""), |s| &format!("/{}", s));
let pj_uri_string = format!(
"{}?amount={}&pj={}",
"{}?amount={}&pj={}&ohttp={}",
pj_receiver_address.to_qr_uri(),
amount.to_btc(),
format!(
"{}{}",
self.config.pj_endpoint,
pubkey.map_or(String::from(""), |s| format!("/{}", s))
)
),
self.config.ohttp_config,
);
let pj_uri = payjoin::Uri::from_str(&pj_uri_string)
.map_err(|e| anyhow!("Constructed a bad URI string from args: {}", e))?;
Expand Down Expand Up @@ -465,6 +532,19 @@ impl App {
}
}

fn serialize_request_to_bytes(req: reqwest::Request) -> Vec<u8> {
let mut serialized_request =
format!("{} {} HTTP/1.1\r\n", req.method(), req.url()).into_bytes();

for (name, value) in req.headers().iter() {
let header_line = format!("{}: {}\r\n", name.as_str(), value.to_str().unwrap());
serialized_request.extend(header_line.as_bytes());
}

serialized_request.extend(b"\r\n");
serialized_request
}

struct SeenInputs {
set: OutPointSet,
file: std::fs::File,
Expand Down Expand Up @@ -504,6 +584,8 @@ pub(crate) struct AppConfig {
pub bitcoind_cookie: Option<String>,
pub bitcoind_rpcuser: String,
pub bitcoind_rpcpass: String,
pub ohttp_config: String,
pub ohttp_proxy: String,

// send-only
pub danger_accept_invalid_certs: bool,
Expand Down Expand Up @@ -537,6 +619,16 @@ impl AppConfig {
"bitcoind_rpcpass",
matches.get_one::<String>("rpcpass").map(|s| s.as_str()),
)?
.set_default("ohttp_config", "")?
.set_override_option(
"ohttp_config",
matches.get_one::<String>("ohttp_config").map(|s| s.as_str()),
)?
.set_default("ohttp_proxy", "")?
.set_override_option(
"ohttp_proxy",
matches.get_one::<String>("ohttp_proxy").map(|s| s.as_str()),
)?
// Subcommand defaults without which file serialization fails.
.set_default("danger_accept_invalid_certs", false)?
.set_default("pj_host", "0.0.0.0:3000")?
Expand Down
6 changes: 6 additions & 0 deletions payjoin-cli/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,12 @@ fn cli() -> ArgMatches {
.long("rpcpass")
.help("The password for the bitcoin node"))
.subcommand_required(true)
.arg(Arg::new("ohttp_config")
.long("ohttp-config")
.help("The ohttp config file"))
.arg(Arg::new("ohttp_proxy")
.long("ohttp-proxy")
.help("The ohttp proxy url"))
.subcommand(
Command::new("send")
.arg_required_else_help(true)
Expand Down
8 changes: 8 additions & 0 deletions payjoin-relay/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -9,9 +9,17 @@ edition = "2021"
axum = "0.6.2"
anyhow = "1.0.71"
futures-util = { version = "0.3.28", default-features = false, features = ["sink", "std"] }
hex = "0.4.3"
hyper = "0.14.27"
http = "0.2.4"
# ohttp = "0.4.0"
httparse = "1.8.0"
ohttp = { path = "../../ohttp/ohttp" }
bhttp = { version = "0.4.0", features = ["http"] }
payjoin = { path = "../payjoin", features = ["v2"] }
sqlx = { version = "0.7.1", features = ["postgres", "runtime-tokio"] }
tokio = { version = "1.12.0", features = ["full"] }
tokio-tungstenite = "0.20.0"
tower-service = "0.3.2"
tracing = "0.1.37"
tracing-subscriber = { version = "0.3.17", features = ["env-filter"] }
89 changes: 83 additions & 6 deletions payjoin-relay/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,17 +2,16 @@ use std::error::Error;
use std::sync::Arc;
use std::time::Duration;

use anyhow::Result;
use anyhow::{Result, Context};
use axum::body::Bytes;
use axum::extract::Path;
use axum::http::StatusCode;
use axum::http::{StatusCode, Request};
use axum::routing::{get, post};
use axum::Router;
use futures_util::stream::{SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use payjoin::v2::{MAX_BUFFER_SIZE, RECEIVE};
use tokio::net::{TcpListener, TcpStream};
use tokio_tungstenite::tungstenite::handshake::server::Request;
use tokio_tungstenite::tungstenite::Message;
use tokio_tungstenite::{accept_hdr_async, WebSocketStream};
use tracing::{debug, error, info, info_span, Instrument};
Expand All @@ -26,8 +25,9 @@ use crate::db::DbPool;
async fn main() -> Result<(), Box<dyn std::error::Error>> {
init_logging();
let pool = DbPool::new(std::time::Duration::from_secs(30)).await?;

let app = Router::new()
let ohttp = Arc::new(init_ohttp()?);
let ohttp_config = ohttp_config(&*ohttp)?;
let target_resource = Router::new()
.route(
"/:id",
post({
Expand All @@ -50,9 +50,16 @@ async fn main() -> Result<(), Box<dyn std::error::Error>> {
move |id, body| post_payjoin(id, body, pool)
}),
);

let ohttp_gateway = Router::new()
.route("/", post( move |body| handle_ohttp(body, target_resource, ohttp)))
.route("/ohttp-keys", get({
move || get_ohttp_config(ohttp_config)}));


println!("Serverless payjoin relay awaiting HTTP connection on port 8080");
axum::Server::bind(&"0.0.0.0:8080".parse()?).serve(app.into_make_service()).await?;
axum::Server::bind(&"0.0.0.0:8080".parse()?).serve(ohttp_gateway.into_make_service()).await?;
//hyper::Server::bind(&"0.0.0.0:8080").serve()
Ok(())
}

Expand All @@ -65,6 +72,72 @@ fn init_logging() {
println!("Logging initialized");
}

fn init_ohttp() -> Result<ohttp::Server> {
use ohttp::hpke::{Aead, Kdf, Kem};
use ohttp::{KeyId, SymmetricSuite};

const KEY_ID: KeyId = 1;
const KEM: Kem = Kem::X25519Sha256;
const SYMMETRIC: &[SymmetricSuite] =
&[SymmetricSuite::new(Kdf::HkdfSha256, Aead::ChaCha20Poly1305)];

// create or read from file
let server_config = ohttp::KeyConfig::new(KEY_ID, KEM, Vec::from(SYMMETRIC)).unwrap();
let encoded_config = server_config.encode().unwrap();
let b64_config = payjoin::bitcoin::base64::encode_config(&encoded_config, payjoin::bitcoin::base64::Config::new(payjoin::bitcoin::base64::CharacterSet::UrlSafe, false));
info!("ohttp server config base64 UrlSafe: {:?}", b64_config);
ohttp::Server::new(server_config).with_context(|| "Failed to initialize ohttp server")
}

async fn handle_ohttp(enc_request: Bytes, mut target: Router, ohttp: Arc<ohttp::Server>) -> (StatusCode, Vec<u8>) {
use tower_service::Service;
use axum::body::Body;
use http::Uri;

// decapsulate
let (bhttp_req, res_ctx) = ohttp.decapsulate(&enc_request).unwrap();
let mut cursor = std::io::Cursor::new(bhttp_req);
let req = bhttp::Message::read_bhttp(&mut cursor).unwrap();
// let parsed_request: httparse::Request = httparse::Request::new(&mut vec![]).parse(cursor).unwrap();
// // handle request
// Request::new
let uri = Uri::builder()
.scheme(req.control().scheme().unwrap())
.authority(req.control().authority().unwrap())
.path_and_query(req.control().path().unwrap())
.build()
.unwrap();
let body = req.content().to_vec();
let mut request = Request::builder()
.uri(uri)
.method(req.control().method().unwrap());
for header in req.header().fields() {
request = request.header(header.name(), header.value())
}
let request = request
.body(Body::from(body))
.unwrap();

let response = target.call(request).await.unwrap();

let (parts, body) = response.into_parts();
let mut bhttp_res = bhttp::Message::response(parts.status.as_u16());
let full_body = hyper::body::to_bytes(body).await.unwrap();
bhttp_res.write_content(&full_body);
let mut bhttp_bytes = Vec::new();
bhttp_res.write_bhttp(bhttp::Mode::KnownLength, &mut bhttp_bytes).unwrap();
let ohttp_res = res_ctx.encapsulate(&bhttp_bytes).unwrap();
(StatusCode::OK, ohttp_res)
}

fn ohttp_config(server: &ohttp::Server) -> Result<String> {
use payjoin::bitcoin::base64;

let b64_config = base64::Config::new(base64::CharacterSet::UrlSafe, false);
let encoded_config = server.config().encode().with_context(|| "Failed to encode ohttp config")?;
Ok(base64::encode_config(&encoded_config, b64_config))
}

async fn post_fallback(Path(id): Path<String>, body: Bytes, pool: DbPool) -> (StatusCode, String) {
let id = shorten_string(&id);
let body = body.to_vec();
Expand All @@ -78,6 +151,10 @@ async fn post_fallback(Path(id): Path<String>, body: Bytes, pool: DbPool) -> (St
}
}

async fn get_ohttp_config(config: String) -> (StatusCode, String) {
(StatusCode::OK, config)
}

async fn get_request(Path(id): Path<String>, pool: DbPool) -> (StatusCode, Vec<u8>) {
let id = shorten_string(&id);
match pool.peek_req(&id).await {
Expand Down
Loading

0 comments on commit d742829

Please sign in to comment.