Skip to content

Commit

Permalink
Cache safe head Header lookups
Browse files Browse the repository at this point in the history
  • Loading branch information
clabby committed Jan 19, 2025
1 parent eb4856d commit 55338f7
Show file tree
Hide file tree
Showing 3 changed files with 41 additions and 28 deletions.
1 change: 1 addition & 0 deletions Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions crates/proof-sdk/proof-interop/Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@ serde.workspace = true
tracing.workspace = true
serde_json.workspace = true
async-trait.workspace = true
spin.workspace = true

# Arbitrary
arbitrary = { version = "1.4", features = ["derive"], optional = true }
Expand Down
67 changes: 39 additions & 28 deletions crates/proof-sdk/proof-interop/src/provider.rs
Original file line number Diff line number Diff line change
Expand Up @@ -11,7 +11,9 @@ use kona_interop::InteropProvider;
use kona_mpt::{OrderedListWalker, TrieNode, TrieProvider};
use kona_preimage::{CommsClient, PreimageKey, PreimageKeyType};
use kona_proof::errors::OracleProviderError;
use maili_registry::HashMap;
use op_alloy_consensus::OpReceiptEnvelope;
use spin::RwLock;

/// A [CommsClient] backed [InteropProvider] implementation.
#[derive(Debug, Clone)]
Expand All @@ -20,15 +22,17 @@ pub struct OracleInteropProvider<T> {
oracle: Arc<T>,
/// The [PreState] for the current program execution.
pre_state: PreState,
/// The safe head block header cache, keyed by chain ID.
safe_head_cache: Arc<RwLock<HashMap<u64, Header>>>,
}

impl<T> OracleInteropProvider<T>
where
T: CommsClient + Send + Sync,
{
/// Creates a new [OracleInteropProvider] with the given oracle client and [PreState].
pub const fn new(oracle: Arc<T>, pre_state: PreState) -> Self {
Self { oracle, pre_state }
pub fn new(oracle: Arc<T>, pre_state: PreState) -> Self {
Self { oracle, pre_state, safe_head_cache: Arc::new(RwLock::new(HashMap::new())) }
}

/// Fetch the [Header] for the block with the given hash.
Expand Down Expand Up @@ -98,33 +102,40 @@ where
async fn header_by_number(&self, chain_id: u64, number: u64) -> Result<Header, Self::Error> {
// Find the safe head for the given chain ID.
//
// TODO: Deduplicate + cache safe head lookups.
let pre_state = match &self.pre_state {
PreState::SuperRoot(super_root) => super_root,
PreState::TransitionState(transition_state) => &transition_state.pre_state,
// If the safe head is not in the cache, we need to fetch it from the oracle.
let mut header = if let Some(header) = self.safe_head_cache.read().get(&chain_id) {
header.clone()
} else {
let pre_state = match &self.pre_state {
PreState::SuperRoot(super_root) => super_root,
PreState::TransitionState(transition_state) => &transition_state.pre_state,
};
let output = pre_state
.output_roots
.iter()
.find(|o| o.chain_id == chain_id)
.ok_or(OracleProviderError::UnknownChainId(chain_id))?;
self.oracle
.write(&HintType::L2OutputRoot.encode_with(&[
output.output_root.as_slice(),
output.chain_id.to_be_bytes().as_slice(),
]))
.await
.map_err(OracleProviderError::Preimage)?;
let output_preimage = self
.oracle
.get(PreimageKey::new(*output.output_root, PreimageKeyType::Keccak256))
.await
.map_err(OracleProviderError::Preimage)?;
let safe_head_hash = output_preimage[96..128]
.try_into()
.map_err(OracleProviderError::SliceConversion)?;

// Fetch the starting block header.
let header = self.header_by_hash(chain_id, safe_head_hash).await?;
self.safe_head_cache.write().insert(chain_id, header.clone());
header
};
let output = pre_state
.output_roots
.iter()
.find(|o| o.chain_id == chain_id)
.ok_or(OracleProviderError::UnknownChainId(chain_id))?;
self.oracle
.write(&HintType::L2OutputRoot.encode_with(&[
output.output_root.as_slice(),
output.chain_id.to_be_bytes().as_slice(),
]))
.await
.map_err(OracleProviderError::Preimage)?;
let output_preimage = self
.oracle
.get(PreimageKey::new(*output.output_root, PreimageKeyType::Keccak256))
.await
.map_err(OracleProviderError::Preimage)?;
let safe_head_hash =
output_preimage[96..128].try_into().map_err(OracleProviderError::SliceConversion)?;

// Fetch the starting block header.
let mut header = self.header_by_hash(chain_id, safe_head_hash).await?;

// Check if the block number is in range. If not, we can fail early.
if number > header.number {
Expand Down

0 comments on commit 55338f7

Please sign in to comment.