Skip to content

Commit

Permalink
Introduce directory feature module
Browse files Browse the repository at this point in the history
Allow payjoin and payjoin-directory to share ShortId code.
  • Loading branch information
DanGould committed Jan 21, 2025
1 parent 1fd9748 commit 2e4848d
Show file tree
Hide file tree
Showing 8 changed files with 93 additions and 95 deletions.
1 change: 1 addition & 0 deletions payjoin-directory/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,7 @@ hyper = { version = "1", features = ["http1", "server"] }
hyper-rustls = { version = "0.26", optional = true }
hyper-util = { version = "0.1", features = ["tokio"] }
ohttp = { package = "bitcoin-ohttp", version = "0.6.0"}
payjoin = { path = "../payjoin", default-features = false, features = ["directory"] }
redis = { version = "0.23.3", features = ["aio", "tokio-comp"] }
rustls = { version = "0.22.4", optional = true }
tokio = { version = "1.12.0", features = ["full"] }
Expand Down
24 changes: 15 additions & 9 deletions payjoin-directory/src/db.rs
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
use std::time::Duration;

use futures::StreamExt;
use payjoin::directory::ShortId;
use redis::{AsyncCommands, Client, ErrorKind, RedisError, RedisResult};
use tracing::debug;

Expand Down Expand Up @@ -53,24 +54,29 @@ impl DbPool {
}

/// Peek using [`DEFAULT_COLUMN`] as the channel type.
pub async fn push_default(&self, subdirectory_id: &str, data: Vec<u8>) -> Result<()> {
pub async fn push_default(&self, subdirectory_id: &ShortId, data: Vec<u8>) -> Result<()> {
self.push(subdirectory_id, DEFAULT_COLUMN, data).await
}

pub async fn peek_default(&self, subdirectory_id: &str) -> Result<Vec<u8>> {
pub async fn peek_default(&self, subdirectory_id: &ShortId) -> Result<Vec<u8>> {
self.peek_with_timeout(subdirectory_id, DEFAULT_COLUMN).await
}

pub async fn push_v1(&self, subdirectory_id: &str, data: Vec<u8>) -> Result<()> {
pub async fn push_v1(&self, subdirectory_id: &ShortId, data: Vec<u8>) -> Result<()> {
self.push(subdirectory_id, PJ_V1_COLUMN, data).await
}

/// Peek using [`PJ_V1_COLUMN`] as the channel type.
pub async fn peek_v1(&self, subdirectory_id: &str) -> Result<Vec<u8>> {
pub async fn peek_v1(&self, subdirectory_id: &ShortId) -> Result<Vec<u8>> {
self.peek_with_timeout(subdirectory_id, PJ_V1_COLUMN).await
}

async fn push(&self, subdirectory_id: &str, channel_type: &str, data: Vec<u8>) -> Result<()> {
async fn push(
&self,
subdirectory_id: &ShortId,
channel_type: &str,
data: Vec<u8>,
) -> Result<()> {
let mut conn = self.client.get_async_connection().await?;
let key = channel_name(subdirectory_id, channel_type);
() = conn.set(&key, data.clone()).await?;
Expand All @@ -80,7 +86,7 @@ impl DbPool {

async fn peek_with_timeout(
&self,
subdirectory_id: &str,
subdirectory_id: &ShortId,
channel_type: &str,
) -> Result<Vec<u8>> {
match tokio::time::timeout(self.timeout, self.peek(subdirectory_id, channel_type)).await {
Expand All @@ -92,7 +98,7 @@ impl DbPool {
}
}

async fn peek(&self, subdirectory_id: &str, channel_type: &str) -> RedisResult<Vec<u8>> {
async fn peek(&self, subdirectory_id: &ShortId, channel_type: &str) -> RedisResult<Vec<u8>> {
let mut conn = self.client.get_async_connection().await?;
let key = channel_name(subdirectory_id, channel_type);

Expand Down Expand Up @@ -140,6 +146,6 @@ impl DbPool {
}
}

fn channel_name(subdirectory_id: &str, channel_type: &str) -> Vec<u8> {
(subdirectory_id.to_owned() + channel_type).into_bytes()
fn channel_name(subdirectory_id: &ShortId, channel_type: &str) -> Vec<u8> {
(subdirectory_id.to_string() + channel_type).into_bytes()
}
39 changes: 17 additions & 22 deletions payjoin-directory/src/lib.rs
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
use std::net::{IpAddr, Ipv6Addr, SocketAddr};
use std::str::FromStr;
use std::sync::Arc;
use std::time::Duration;

Expand All @@ -11,6 +12,7 @@ use hyper::server::conn::http1;
use hyper::service::service_fn;
use hyper::{Method, Request, Response, StatusCode, Uri};
use hyper_util::rt::TokioIo;
use payjoin::directory::{ShortId, ShortIdError};
use tokio::net::TcpListener;
use tokio::sync::Mutex;
use tracing::{debug, error, info, trace};
Expand All @@ -32,9 +34,6 @@ const V1_REJECT_RES_JSON: &str =
r#"{{"errorCode": "original-psbt-rejected ", "message": "Body is not a string"}}"#;
const V1_UNAVAILABLE_RES_JSON: &str = r#"{{"errorCode": "unavailable", "message": "V2 receiver offline. V1 sends require synchronous communications."}}"#;

// 8 bytes as bech32 is 12.8 characters
const ID_LENGTH: usize = 13;

mod db;

#[cfg(feature = "_danger-local-https")]
Expand Down Expand Up @@ -313,6 +312,12 @@ impl From<hyper::http::Error> for HandlerError {
fn from(e: hyper::http::Error) -> Self { HandlerError::InternalServerError(e.into()) }
}

impl From<ShortIdError> for HandlerError {
fn from(_: ShortIdError) -> Self {
HandlerError::BadRequest(anyhow::anyhow!("subdirectory ID must be 13 bech32 characters"))
}
}

fn handle_peek(
result: db::Result<Vec<u8>>,
timeout_response: Response<BoxBody<Bytes, hyper::Error>>,
Expand Down Expand Up @@ -353,11 +358,11 @@ async fn post_fallback_v1(
};

let v2_compat_body = format!("{}\n{}", body_str, query);
let id = check_id_length(id)?;
pool.push_default(id, v2_compat_body.into())
let id = ShortId::from_str(id)?;
pool.push_default(&id, v2_compat_body.into())
.await
.map_err(|e| HandlerError::BadRequest(e.into()))?;
handle_peek(pool.peek_v1(id).await, none_response)
handle_peek(pool.peek_v1(&id).await, none_response)
}

async fn put_payjoin_v1(
Expand All @@ -368,29 +373,19 @@ async fn put_payjoin_v1(
trace!("Put_payjoin_v1");
let ok_response = Response::builder().status(StatusCode::OK).body(empty())?;

let id = check_id_length(id)?;
let id = ShortId::from_str(id)?;
let req =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
if req.len() > V1_MAX_BUFFER_SIZE {
return Err(HandlerError::PayloadTooLarge);
}

match pool.push_v1(id, req.into()).await {
match pool.push_v1(&id, req.into()).await {
Ok(_) => Ok(ok_response),
Err(e) => Err(HandlerError::BadRequest(e.into())),
}
}

fn check_id_length(id: &str) -> Result<&str, HandlerError> {
if id.len() != ID_LENGTH {
return Err(HandlerError::BadRequest(anyhow::anyhow!(
"subdirectory ID must be 13 bech32 characters",
)));
}

Ok(id)
}

async fn post_subdir(
id: &str,
body: BoxBody<Bytes, hyper::Error>,
Expand All @@ -399,15 +394,15 @@ async fn post_subdir(
let none_response = Response::builder().status(StatusCode::OK).body(empty())?;
trace!("post_subdir");

let id = check_id_length(id)?;
let id = ShortId::from_str(id)?;

let req =
body.collect().await.map_err(|e| HandlerError::InternalServerError(e.into()))?.to_bytes();
if req.len() > V1_MAX_BUFFER_SIZE {
return Err(HandlerError::PayloadTooLarge);
}

match pool.push_default(id, req.into()).await {
match pool.push_default(&id, req.into()).await {
Ok(_) => Ok(none_response),
Err(e) => Err(HandlerError::BadRequest(e.into())),
}
Expand All @@ -418,9 +413,9 @@ async fn get_subdir(
pool: DbPool,
) -> Result<Response<BoxBody<Bytes, hyper::Error>>, HandlerError> {
trace!("get_subdir");
let id = check_id_length(id)?;
let id = ShortId::from_str(id)?;
let timeout_response = Response::builder().status(StatusCode::ACCEPTED).body(empty())?;
handle_peek(pool.peek_default(id).await, timeout_response)
handle_peek(pool.peek_default(&id).await, timeout_response)
}

fn not_found() -> Response<BoxBody<Bytes, hyper::Error>> {
Expand Down
3 changes: 2 additions & 1 deletion payjoin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -20,8 +20,9 @@ default = ["v2"]
base64 = ["bitcoin/base64"]
#[doc = "Core features for payjoin state machines"]
_core = ["bitcoin/rand", "serde_json", "url", "bitcoin_uri"]
directory = []
v1 = ["_core"]
v2 = ["_core","bitcoin/serde", "hpke", "dep:http", "bhttp", "ohttp", "serde", "url/serde",]
v2 = ["_core","bitcoin/serde", "hpke", "dep:http", "bhttp", "ohttp", "serde", "url/serde", "directory"]
#[doc = "Functions to fetch OHTTP keys via CONNECT proxy using reqwest. Enables `v2` since only `v2` uses OHTTP."]
io = ["v2", "reqwest/rustls-tls"]
_danger-local-https = ["reqwest/rustls-tls", "rustls"]
Expand Down
7 changes: 3 additions & 4 deletions payjoin/src/bech32.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,3 @@
use std::fmt;

use bitcoin::bech32::primitives::decode::{CheckedHrpstring, CheckedHrpstringError};
use bitcoin::bech32::{self, EncodeError, Hrp, NoChecksum};

Expand All @@ -15,8 +13,9 @@ pub mod nochecksum {
bech32::encode_upper::<NoChecksum>(hrp, data)
}

pub fn encode_to_fmt(f: &mut fmt::Formatter, hrp: Hrp, data: &[u8]) -> Result<(), EncodeError> {
bech32::encode_upper_to_fmt::<NoChecksum, fmt::Formatter>(f, hrp, data)
#[cfg(feature = "v2")]
pub fn encode_to_fmt(f: &mut core::fmt::Formatter, hrp: Hrp, data: &[u8]) -> Result<(), EncodeError> {
bech32::encode_upper_to_fmt::<NoChecksum, core::fmt::Formatter>(f, hrp, data)
}
}

Expand Down
50 changes: 50 additions & 0 deletions payjoin/src/directory.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,50 @@
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShortId(pub [u8; 8]);

impl ShortId {
pub fn as_bytes(&self) -> &[u8] { &self.0 }
pub fn as_slice(&self) -> &[u8] { &self.0 }
}

impl std::fmt::Display for ShortId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let id_hrp = bitcoin::bech32::Hrp::parse("ID").unwrap();
f.write_str(
crate::bech32::nochecksum::encode(id_hrp, &self.0)
.expect("bech32 encoding of short ID must succeed")
.strip_prefix("ID1")
.expect("human readable part must be ID1"),
)
}
}

#[derive(Debug)]
pub enum ShortIdError {
DecodeBech32(bitcoin::bech32::primitives::decode::CheckedHrpstringError),
IncorrectLength(std::array::TryFromSliceError),
}

impl std::convert::From<bitcoin::hashes::sha256::Hash> for ShortId {
fn from(h: bitcoin::hashes::sha256::Hash) -> Self {
bitcoin::hashes::Hash::as_byte_array(&h)[..8]
.try_into()
.expect("truncating SHA256 to 8 bytes should always succeed")
}
}

impl std::convert::TryFrom<&[u8]> for ShortId {
type Error = ShortIdError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
let bytes: [u8; 8] = bytes.try_into().map_err(ShortIdError::IncorrectLength)?;
Ok(Self(bytes))
}
}

impl std::str::FromStr for ShortId {
type Err = ShortIdError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (_, bytes) = crate::bech32::nochecksum::decode(&("ID1".to_string() + s))
.map_err(ShortIdError::DecodeBech32)?;
(&bytes[..]).try_into()
}
}
4 changes: 3 additions & 1 deletion payjoin/src/lib.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,8 +33,10 @@ pub use crate::hpke::{HpkeKeyPair, HpkePublicKey};
pub(crate) mod ohttp;
#[cfg(feature = "v2")]
pub use crate::ohttp::OhttpKeys;
#[cfg(feature = "v2")]
#[cfg(any(feature = "v2", feature = "directory"))]
pub(crate) mod bech32;
#[cfg(feature = "directory")]
pub mod directory;

#[cfg(feature = "io")]
pub mod io;
Expand Down
60 changes: 2 additions & 58 deletions payjoin/src/uri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,8 @@ use bitcoin::address::NetworkChecked;
pub use error::PjParseError;
use url::Url;

#[cfg(feature = "v2")]
pub(crate) use crate::directory::ShortId;
use crate::uri::error::InternalPjParseError;
#[cfg(feature = "v2")]
pub(crate) use crate::uri::url_ext::UrlExt;
Expand All @@ -12,64 +14,6 @@ pub mod error;
#[cfg(feature = "v2")]
pub(crate) mod url_ext;

#[cfg(feature = "v2")]
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct ShortId(pub [u8; 8]);

#[cfg(feature = "v2")]
impl ShortId {
pub fn as_bytes(&self) -> &[u8] { &self.0 }
pub fn as_slice(&self) -> &[u8] { &self.0 }
}

#[cfg(feature = "v2")]
impl std::fmt::Display for ShortId {
fn fmt(&self, f: &mut std::fmt::Formatter) -> std::fmt::Result {
let id_hrp = bitcoin::bech32::Hrp::parse("ID").unwrap();
f.write_str(
crate::bech32::nochecksum::encode(id_hrp, &self.0)
.expect("bech32 encoding of short ID must succeed")
.strip_prefix("ID1")
.expect("human readable part must be ID1"),
)
}
}

#[cfg(feature = "v2")]
#[derive(Debug)]
pub enum ShortIdError {
DecodeBech32(bitcoin::bech32::primitives::decode::CheckedHrpstringError),
IncorrectLength(std::array::TryFromSliceError),
}

#[cfg(feature = "v2")]
impl std::convert::From<bitcoin::hashes::sha256::Hash> for ShortId {
fn from(h: bitcoin::hashes::sha256::Hash) -> Self {
bitcoin::hashes::Hash::as_byte_array(&h)[..8]
.try_into()
.expect("truncating SHA256 to 8 bytes should always succeed")
}
}

#[cfg(feature = "v2")]
impl std::convert::TryFrom<&[u8]> for ShortId {
type Error = ShortIdError;
fn try_from(bytes: &[u8]) -> Result<Self, Self::Error> {
let bytes: [u8; 8] = bytes.try_into().map_err(ShortIdError::IncorrectLength)?;
Ok(Self(bytes))
}
}

#[cfg(feature = "v2")]
impl std::str::FromStr for ShortId {
type Err = ShortIdError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
let (_, bytes) = crate::bech32::nochecksum::decode(&("ID1".to_string() + s))
.map_err(ShortIdError::DecodeBech32)?;
(&bytes[..]).try_into()
}
}

#[derive(Debug, Clone)]
pub enum MaybePayjoinExtras {
Supported(PayjoinExtras),
Expand Down

0 comments on commit 2e4848d

Please sign in to comment.