Skip to content

Commit

Permalink
Error if send or receive session expired (#299)
Browse files Browse the repository at this point in the history
  • Loading branch information
DanGould authored Jul 17, 2024
2 parents eb51b35 + 161f787 commit d9c76dd
Show file tree
Hide file tree
Showing 10 changed files with 254 additions and 79 deletions.
1 change: 0 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

5 changes: 2 additions & 3 deletions payjoin-cli/src/app/v2.rs
Original file line number Diff line number Diff line change
Expand Up @@ -84,7 +84,7 @@ impl AppTrait for App {
self.config.pj_directory.clone(),
ohttp_keys.clone(),
self.config.ohttp_relay.clone(),
std::time::Duration::from_secs(60 * 60),
None,
);
let (req, ctx) =
initializer.extract_req().map_err(|e| anyhow!("Failed to extract request {}", e))?;
Expand Down Expand Up @@ -246,8 +246,7 @@ impl App {
session: &mut payjoin::receive::v2::ActiveSession,
) -> Result<payjoin::receive::v2::UncheckedProposal> {
loop {
let (req, context) =
session.extract_req().map_err(|_| anyhow!("Failed to extract request"))?;
let (req, context) = session.extract_req()?;
println!("Polling receive request...");
let http = http_agent()?;
let ohttp_response = http
Expand Down
3 changes: 1 addition & 2 deletions payjoin/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -19,7 +19,7 @@ exclude = ["tests"]
send = []
receive = ["bitcoin/rand"]
base64 = ["bitcoin/base64"]
v2 = ["bitcoin/rand", "bitcoin/serde", "chacha20poly1305", "dep:http", "bhttp", "ohttp", "dep:percent-encoding", "serde", "url/serde"]
v2 = ["bitcoin/rand", "bitcoin/serde", "chacha20poly1305", "dep:http", "bhttp", "ohttp", "serde", "url/serde"]
io = ["reqwest/rustls-tls"]
danger-local-https = ["io", "reqwest/rustls-tls", "rustls"]

Expand All @@ -31,7 +31,6 @@ log = { version = "0.4.14"}
http = { version = "1", optional = true }
bhttp = { version = "=0.5.1", optional = true }
ohttp = { version = "0.5.1", optional = true }
percent-encoding = { version = "0.1.3", optional = true, package = "percent-encoding-rfc3986" }
serde = { version = "1.0.186", default-features = false, optional = true }
reqwest = { version = "0.12", default-features = false, optional = true }
rustls = { version = "0.22.2", optional = true }
Expand Down
44 changes: 44 additions & 0 deletions payjoin/src/receive/v2/error.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
use core::fmt;
use std::error;

use crate::v2::OhttpEncapsulationError;

#[derive(Debug)]
pub struct SessionError(InternalSessionError);

#[derive(Debug)]
pub(crate) enum InternalSessionError {
/// The session has expired
Expired(std::time::SystemTime),
/// OHTTP Encapsulation failed
OhttpEncapsulationError(OhttpEncapsulationError),
}

impl fmt::Display for SessionError {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
match &self.0 {
InternalSessionError::Expired(expiry) => write!(f, "Session expired at {:?}", expiry),
InternalSessionError::OhttpEncapsulationError(e) =>
write!(f, "OHTTP Encapsulation Error: {}", e),
}
}
}

impl error::Error for SessionError {
fn source(&self) -> Option<&(dyn error::Error + 'static)> {
match &self.0 {
InternalSessionError::Expired(_) => None,
InternalSessionError::OhttpEncapsulationError(e) => Some(e),
}
}
}

impl From<InternalSessionError> for SessionError {
fn from(e: InternalSessionError) -> Self { SessionError(e) }
}

impl From<OhttpEncapsulationError> for SessionError {
fn from(e: OhttpEncapsulationError) -> Self {
SessionError(InternalSessionError::OhttpEncapsulationError(e))
}
}
30 changes: 21 additions & 9 deletions payjoin/src/receive/v2.rs → payjoin/src/receive/v2/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,13 +12,17 @@ use serde::ser::SerializeStruct;
use serde::{Deserialize, Serialize, Serializer};
use url::Url;

use super::v2::error::{InternalSessionError, SessionError};
use super::{Error, InternalRequestError, RequestError, SelectionError};
use crate::psbt::PsbtExt;
use crate::receive::optional_parameters::Params;
use crate::v2::OhttpEncapsulationError;
use crate::{OhttpKeys, PjUriBuilder, Request};

/// The state for a payjoin V2 receive session, including necessary
/// information for communication and cryptographic operations.
pub(crate) mod error;

static TWENTY_FOUR_HOURS_DEFAULT_EXPIRY: Duration = Duration::from_secs(60 * 60 * 24);

#[derive(Debug, Clone, PartialEq, Eq)]
struct SessionContext {
address: Address,
Expand Down Expand Up @@ -58,7 +62,7 @@ impl SessionInitializer {
directory: Url,
ohttp_keys: OhttpKeys,
ohttp_relay: Url,
expire_after: Duration,
expire_after: Option<Duration>,
) -> Self {
let secp = bitcoin::secp256k1::Secp256k1::new();
let (sk, _) = secp.generate_keypair(&mut rand::rngs::OsRng);
Expand All @@ -68,7 +72,8 @@ impl SessionInitializer {
directory,
ohttp_keys,
ohttp_relay,
expiry: SystemTime::now() + expire_after,
expiry: SystemTime::now()
+ expire_after.unwrap_or(TWENTY_FOUR_HOURS_DEFAULT_EXPIRY),
s: bitcoin::secp256k1::KeyPair::from_secret_key(&secp, &sk),
e: None,
},
Expand Down Expand Up @@ -116,8 +121,12 @@ pub struct ActiveSession {
}

impl ActiveSession {
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), Error> {
let (body, ohttp_ctx) = self.fallback_req_body()?;
pub fn extract_req(&mut self) -> Result<(Request, ohttp::ClientResponse), SessionError> {
if SystemTime::now() > self.context.expiry {
return Err(InternalSessionError::Expired(self.context.expiry).into());
}
let (body, ohttp_ctx) =
self.fallback_req_body().map_err(InternalSessionError::OhttpEncapsulationError)?;
let url = self.context.ohttp_relay.clone();
let req = Request { url, body };
Ok((req, ohttp_ctx))
Expand Down Expand Up @@ -146,14 +155,16 @@ impl ActiveSession {
}
}

fn fallback_req_body(&mut self) -> Result<(Vec<u8>, ohttp::ClientResponse), Error> {
fn fallback_req_body(
&mut self,
) -> Result<(Vec<u8>, ohttp::ClientResponse), OhttpEncapsulationError> {
let fallback_target = self.pj_url();
Ok(crate::v2::ohttp_encapsulate(
crate::v2::ohttp_encapsulate(
&mut self.context.ohttp_keys,
"GET",
fallback_target.as_str(),
None,
)?)
)
}

fn extract_proposal_from_v1(&mut self, response: String) -> Result<UncheckedProposal, Error> {
Expand Down Expand Up @@ -203,6 +214,7 @@ impl ActiveSession {
self.context.address.clone(),
self.pj_url(),
Some(self.context.ohttp_keys.clone()),
Some(self.context.expiry),
)
}

Expand Down
6 changes: 3 additions & 3 deletions payjoin/src/send/error.rs
Original file line number Diff line number Diff line change
Expand Up @@ -196,7 +196,7 @@ pub(crate) enum InternalCreateRequestError {
#[cfg(feature = "v2")]
MissingOhttpConfig,
#[cfg(feature = "v2")]
PercentEncoding,
Expired(std::time::SystemTime),
}

impl fmt::Display for CreateRequestError {
Expand Down Expand Up @@ -227,7 +227,7 @@ impl fmt::Display for CreateRequestError {
#[cfg(feature = "v2")]
MissingOhttpConfig => write!(f, "no ohttp configuration with which to make a v2 request available"),
#[cfg(feature = "v2")]
PercentEncoding => write!(f, "fragment is not RFC 3986 percent-encoded"),
Expired(expiry) => write!(f, "session expired at {:?}", expiry),
}
}
}
Expand Down Expand Up @@ -260,7 +260,7 @@ impl std::error::Error for CreateRequestError {
#[cfg(feature = "v2")]
MissingOhttpConfig => None,
#[cfg(feature = "v2")]
PercentEncoding => None,
Expired(_) => None,
}
}
}
Expand Down
13 changes: 8 additions & 5 deletions payjoin/src/send/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -315,6 +315,12 @@ impl RequestContext {
ohttp_relay: Url,
) -> Result<(Request, ContextV2), CreateRequestError> {
use crate::uri::UrlExt;

if let Some(expiry) = self.endpoint.exp() {
if std::time::SystemTime::now() > expiry {
return Err(InternalCreateRequestError::Expired(expiry).into());
}
}
let rs = Self::rs_pubkey_from_dir_endpoint(&self.endpoint)?;
let url = self.endpoint.clone();
let body = serialize_v2_body(
Expand All @@ -325,11 +331,8 @@ impl RequestContext {
)?;
let body = crate::v2::encrypt_message_a(body, self.e, rs)
.map_err(InternalCreateRequestError::Hpke)?;
let mut ohttp = self
.endpoint
.ohttp()
.map_err(|_| InternalCreateRequestError::PercentEncoding)?
.ok_or(InternalCreateRequestError::MissingOhttpConfig)?;
let mut ohttp =
self.endpoint.ohttp().ok_or(InternalCreateRequestError::MissingOhttpConfig)?;
let (body, ohttp_res) =
crate::v2::ohttp_encapsulate(&mut ohttp, "POST", url.as_str(), Some(&body))
.map_err(InternalCreateRequestError::OhttpEncapsulation)?;
Expand Down
8 changes: 7 additions & 1 deletion payjoin/src/uri/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -107,15 +107,19 @@ impl PjUriBuilder {
/// - `address`: Represents a bitcoin address.
/// - `origin`: Represents either the payjoin endpoint in v1 or the directory in v2.
/// - `ohttp_keys`: Optional OHTTP keys for v2 (only available if the "v2" feature is enabled).
/// - `expiry`: Optional non-default expiry for the payjoin session (only available if the "v2" feature is enabled).
pub fn new(
address: Address,
origin: Url,
#[cfg(feature = "v2")] ohttp_keys: Option<OhttpKeys>,
#[cfg(feature = "v2")] expiry: Option<std::time::SystemTime>,
) -> Self {
#[allow(unused_mut)]
let mut pj = origin;
#[cfg(feature = "v2")]
let _ = pj.set_ohttp(ohttp_keys);
pj.set_ohttp(ohttp_keys);
#[cfg(feature = "v2")]
pj.set_exp(expiry);
Self { address, amount: None, message: None, label: None, pj, pjos: false }
}
/// Set the amount you want to receive.
Expand Down Expand Up @@ -352,6 +356,8 @@ mod tests {
Url::parse(pj).unwrap(),
#[cfg(feature = "v2")]
None,
#[cfg(feature = "v2")]
None,
)
.amount(amount)
.message("message".to_string())
Expand Down
Loading

0 comments on commit d9c76dd

Please sign in to comment.