Skip to content

Commit

Permalink
refactor(starknet_batcher): delete the proposal manager
Browse files Browse the repository at this point in the history
  • Loading branch information
dafnamatsry committed Dec 10, 2024
1 parent 3bc9a7f commit cfc0dcc
Show file tree
Hide file tree
Showing 6 changed files with 236 additions and 667 deletions.
215 changes: 145 additions & 70 deletions crates/starknet_batcher/src/batcher.rs
Original file line number Diff line number Diff line change
Expand Up @@ -28,16 +28,17 @@ use starknet_batcher_types::errors::BatcherError;
use starknet_mempool_types::communication::SharedMempoolClient;
use starknet_mempool_types::mempool_types::CommitBlockArgs;
use starknet_sequencer_infra::component_definitions::ComponentStarter;
use tracing::{debug, error, info, instrument, trace};
use tokio::sync::Mutex;
use tracing::{debug, error, info, instrument, trace, Instrument};

use crate::block_builder::{
BlockBuilderExecutionParams,
BlockBuilderFactory,
BlockBuilderFactoryTrait,
BlockBuilderTrait,
BlockMetadata,
};
use crate::config::BatcherConfig;
use crate::proposal_manager::{GenerateProposalError, ProposalManager, ProposalManagerTrait};
use crate::transaction_provider::{
DummyL1ProviderClient,
ProposeTransactionProvider,
Expand All @@ -47,8 +48,10 @@ use crate::utils::{
deadline_as_instant,
proposal_status_from,
verify_block_input,
ProposalError,
ProposalOutput,
ProposalResult,
ProposalTask,
};

type OutputStreamReceiver = tokio::sync::mpsc::UnboundedReceiver<Transaction>;
Expand All @@ -60,11 +63,30 @@ pub struct Batcher {
pub storage_writer: Box<dyn BatcherStorageWriterTrait>,
pub mempool_client: SharedMempoolClient,

// Used to create block builders.
// Using the factory pattern to allow for easier testing.
block_builder_factory: Box<dyn BlockBuilderFactoryTrait>,

// The height that the batcher is currently working on.
// All proposals are considered to be at this height.
active_height: Option<BlockNumber>,
proposal_manager: Box<dyn ProposalManagerTrait>,

block_builder_factory: Box<dyn BlockBuilderFactoryTrait>,
// The block proposal that is currently being built, if any.
// At any given time, there can be only one proposal being actively executed (either proposed
// or validated).
active_proposal: Arc<Mutex<Option<ProposalId>>>,
active_proposal_task: Option<ProposalTask>,

// Holds all the proposals that completed execution in the current height.
executed_proposals: Arc<Mutex<HashMap<ProposalId, ProposalResult<ProposalOutput>>>>,

// The propose blocks transaction streams, used to stream out the proposal transactions.
// Each stream is kept until all the transactions are streamed out, or a new height is started.
propose_tx_streams: HashMap<ProposalId, OutputStreamReceiver>,

// The validate blocks transaction streams, used to stream in the transactions to validate.
// Each stream is kept until SendProposalContent::Finish/Abort is received, or a new height is
// started.
validate_tx_streams: HashMap<ProposalId, InputStreamSender>,
}

Expand All @@ -75,16 +97,17 @@ impl Batcher {
storage_writer: Box<dyn BatcherStorageWriterTrait>,
mempool_client: SharedMempoolClient,
block_builder_factory: Box<dyn BlockBuilderFactoryTrait>,
proposal_manager: Box<dyn ProposalManagerTrait>,
) -> Self {
Self {
config: config.clone(),
storage_reader,
storage_writer,
mempool_client,
active_height: None,
block_builder_factory,
proposal_manager,
active_height: None,
active_proposal: Arc::new(Mutex::new(None)),
active_proposal_task: None,
executed_proposals: Arc::new(Mutex::new(HashMap::new())),
propose_tx_streams: HashMap::new(),
validate_tx_streams: HashMap::new(),
}
Expand Down Expand Up @@ -112,7 +135,8 @@ impl Batcher {
}

// Clear all the proposals from the previous height.
self.proposal_manager.reset().await;
self.abort_active_proposal().await;
self.executed_proposals.lock().await.clear();
self.propose_tx_streams.clear();
self.validate_tx_streams.clear();

Expand All @@ -134,6 +158,8 @@ impl Batcher {
propose_block_input.retrospective_block_hash,
)?;

self.set_active_proposal(propose_block_input.proposal_id).await?;

let tx_provider = ProposeTransactionProvider::new(
self.mempool_client.clone(),
// TODO: use a real L1 provider client.
Expand All @@ -160,8 +186,7 @@ impl Batcher {
)
.map_err(|_| BatcherError::InternalError)?;

self.proposal_manager
.spawn_proposal(propose_block_input.proposal_id, block_builder, abort_signal_sender)
self.spawn_proposal(propose_block_input.proposal_id, block_builder, abort_signal_sender)
.await?;

self.propose_tx_streams.insert(propose_block_input.proposal_id, output_tx_receiver);
Expand All @@ -180,6 +205,8 @@ impl Batcher {
validate_block_input.retrospective_block_hash,
)?;

self.set_active_proposal(validate_block_input.proposal_id).await?;

// A channel to send the transactions to include in the block being validated.
let (input_tx_sender, input_tx_receiver) =
tokio::sync::mpsc::channel(self.config.input_stream_content_buffer_size);
Expand All @@ -206,8 +233,7 @@ impl Batcher {
)
.map_err(|_| BatcherError::InternalError)?;

self.proposal_manager
.spawn_proposal(validate_block_input.proposal_id, block_builder, abort_signal_sender)
self.spawn_proposal(validate_block_input.proposal_id, block_builder, abort_signal_sender)
.await?;

self.validate_tx_streams.insert(validate_block_input.proposal_id, input_tx_sender);
Expand All @@ -229,11 +255,7 @@ impl Batcher {
match send_proposal_content_input.content {
SendProposalContent::Txs(txs) => self.handle_send_txs_request(proposal_id, txs).await,
SendProposalContent::Finish => self.handle_finish_proposal_request(proposal_id).await,
SendProposalContent::Abort => {
self.proposal_manager.abort_proposal(proposal_id).await;
self.close_input_transaction_stream(proposal_id)?;
Ok(SendProposalContentResponse { response: ProposalStatus::Aborted })
}
SendProposalContent::Abort => self.handle_abort_proposal_request(proposal_id).await,
}
}

Expand Down Expand Up @@ -261,6 +283,8 @@ impl Batcher {
let proposal_result =
self.get_completed_proposal_result(proposal_id).await.expect("Proposal should exist.");
match proposal_result {
// TODO(dafna): at this point the proposal result must be an error, since it finished
// earlier than expected. Consider panicking instead of returning an error.
Ok(_) => Err(BatcherError::ProposalAlreadyFinished { proposal_id }),
Err(err) => Ok(SendProposalContentResponse { response: proposal_status_from(err)? }),
}
Expand All @@ -272,9 +296,11 @@ impl Batcher {
) -> BatcherResult<SendProposalContentResponse> {
debug!("Send proposal content done for {}", proposal_id);

self.close_input_transaction_stream(proposal_id)?;
self.validate_tx_streams.remove(&proposal_id);
if self.is_active(proposal_id).await {
self.proposal_manager.await_active_proposal().await;
if let Some(proposal_task) = self.active_proposal_task.take() {
proposal_task.join_handle.await.ok();
}
}

let proposal_result =
Expand All @@ -286,11 +312,16 @@ impl Batcher {
Ok(SendProposalContentResponse { response: proposal_status })
}

fn close_input_transaction_stream(&mut self, proposal_id: ProposalId) -> BatcherResult<()> {
self.validate_tx_streams
.remove(&proposal_id)
.ok_or(BatcherError::ProposalNotFound { proposal_id })?;
Ok(())
// Handles an abort request for the given proposal.
//
// If the proposal is the one currently being executed, its task is signaled to stop and an
// `Aborted` error is recorded as its result. In all cases the proposal's validation input
// stream (if any) is dropped, and an `Aborted` status is reported back to the caller.
async fn handle_abort_proposal_request(
    &mut self,
    proposal_id: ProposalId,
) -> BatcherResult<SendProposalContentResponse> {
    let currently_active = self.is_active(proposal_id).await;
    if currently_active {
        self.abort_active_proposal().await;
        let abort_result = Err(ProposalError::Aborted);
        self.executed_proposals.lock().await.insert(proposal_id, abort_result);
    }
    // Drop the incoming transaction stream; no further content is accepted for this proposal.
    self.validate_tx_streams.remove(&proposal_id);
    Ok(SendProposalContentResponse { response: ProposalStatus::Aborted })
}

#[instrument(skip(self), err)]
Expand Down Expand Up @@ -329,39 +360,95 @@ impl Batcher {

#[instrument(skip(self), err)]
pub async fn decision_reached(&mut self, input: DecisionReachedInput) -> BatcherResult<()> {
let height = self.active_height.ok_or(BatcherError::NoActiveHeight)?;

let proposal_id = input.proposal_id;
let proposal_output = self
.proposal_manager
.take_proposal_result(proposal_id)
.await
.ok_or(BatcherError::ExecutedProposalNotFound { proposal_id })??;
let proposal_result = self.executed_proposals.lock().await.remove(&proposal_id);
let ProposalOutput { state_diff, nonces: address_to_nonce, tx_hashes, .. } =
proposal_output;
// TODO: Keep the height from start_height or get it from the input.
let height = self.storage_reader.height().map_err(|err| {
error!("Failed to get height from storage: {}", err);
BatcherError::InternalError
})?;
proposal_result.ok_or(BatcherError::ExecutedProposalNotFound { proposal_id })??;

info!(
"Committing proposal {} at height {} and notifying mempool of the block.",
proposal_id, height
);
trace!("Transactions: {:#?}, State diff: {:#?}.", tx_hashes, state_diff);

// Commit the proposal to the storage and notify the mempool.
self.storage_writer.commit_proposal(height, state_diff).map_err(|err| {
error!("Failed to commit proposal to storage: {}", err);
BatcherError::InternalError
})?;
if let Err(mempool_err) =
self.mempool_client.commit_block(CommitBlockArgs { address_to_nonce, tx_hashes }).await
{
let mempool_result =
self.mempool_client.commit_block(CommitBlockArgs { address_to_nonce, tx_hashes }).await;

if let Err(mempool_err) = mempool_result {
error!("Failed to commit block to mempool: {}", mempool_err);
// TODO: Should we rollback the state diff and return an error?
}
};

Ok(())
}

async fn is_active(&self, proposal_id: ProposalId) -> bool {
self.proposal_manager.get_active_proposal().await == Some(proposal_id)
*self.active_proposal.lock().await == Some(proposal_id)
}

// Registers `proposal_id` as the proposal currently being generated.
//
// Fails with `ProposalAlreadyExists` if this ID already completed execution at this height,
// or with `ServerBusy` if another proposal is still being generated.
async fn set_active_proposal(&mut self, proposal_id: ProposalId) -> BatcherResult<()> {
    let already_executed = self.executed_proposals.lock().await.contains_key(&proposal_id);
    if already_executed {
        return Err(BatcherError::ProposalAlreadyExists { proposal_id });
    }

    let mut active = self.active_proposal.lock().await;
    match *active {
        Some(active_proposal_id) => {
            Err(BatcherError::ServerBusy { active_proposal_id, new_proposal_id: proposal_id })
        }
        None => {
            debug!("Set proposal {} as the one being generated.", proposal_id);
            *active = Some(proposal_id);
            Ok(())
        }
    }
}

// Starts a new block proposal generation task for the given proposal_id.
// Uses the given block_builder to generate the proposal.
//
// The caller is expected to have already registered `proposal_id` as the active proposal
// (via `set_active_proposal`) before invoking this — TODO confirm against call sites.
// The spawned task runs `build_block` to completion and stores the result in
// `executed_proposals`, unless the proposal was aborted in the meantime.
// `abort_signal_sender` is retained in the `ProposalTask` so `abort_active_proposal` can
// signal the block builder to stop early.
async fn spawn_proposal(
&mut self,
proposal_id: ProposalId,
mut block_builder: Box<dyn BlockBuilderTrait>,
abort_signal_sender: tokio::sync::oneshot::Sender<()>,
) -> BatcherResult<()> {
info!("Starting generation of a new proposal with id {}.", proposal_id);

// Clone the shared handles so the spawned task can own them independently of `self`.
let active_proposal = self.active_proposal.clone();
let executed_proposals = self.executed_proposals.clone();

let join_handle = tokio::spawn(
async move {
let result = block_builder
.build_block()
.await
.map(ProposalOutput::from)
.map_err(|e| ProposalError::BlockBuilderError(Arc::new(e)));

// The proposal is done, clear the active proposal.
// Keep the proposal result only if it is the same as the active proposal.
// The active proposal might have changed if this proposal was aborted.
let mut active_proposal = active_proposal.lock().await;
if *active_proposal == Some(proposal_id) {
active_proposal.take();
executed_proposals.lock().await.insert(proposal_id, result);
}
}
// Propagate the current tracing span into the spawned task so its logs keep context.
.in_current_span(),
);

self.active_proposal_task = Some(ProposalTask { abort_signal_sender, join_handle });
Ok(())
}

// Returns a completed proposal result, either its commitment or an error if the proposal
Expand All @@ -370,8 +457,7 @@ impl Batcher {
&self,
proposal_id: ProposalId,
) -> Option<ProposalResult<ProposalCommitment>> {
let completed_proposals = self.proposal_manager.get_completed_proposals().await;
let guard = completed_proposals.lock().await;
let guard = self.executed_proposals.lock().await;
let proposal_result = guard.get(&proposal_id);

match proposal_result {
Expand All @@ -380,6 +466,22 @@ impl Batcher {
None => None,
}
}

// Stops the currently running proposal task, if one exists.
// Non-blocking: the abort signal is fired without waiting for the task to finish.
async fn abort_active_proposal(&mut self) {
    // Clear the active-proposal marker first, so the running task discards its result
    // when it completes.
    self.active_proposal.lock().await.take();
    if let Some(task) = self.active_proposal_task.take() {
        // A send error means the task already finished and dropped the receiver; ignore it.
        let _ = task.abort_signal_sender.send(());
    }
}

// Test-only helper: blocks until the in-flight proposal task (if any) has completed.
#[cfg(test)]
pub async fn await_active_proposal(&mut self) {
    let maybe_task = self.active_proposal_task.take();
    if let Some(task) = maybe_task {
        // Join errors (e.g. a panicked task) are ignored; tests only need completion.
        let _ = task.join_handle.await;
    }
}
}

pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient) -> Batcher {
Expand All @@ -393,15 +495,7 @@ pub fn create_batcher(config: BatcherConfig, mempool_client: SharedMempoolClient
});
let storage_reader = Arc::new(storage_reader);
let storage_writer = Box::new(storage_writer);
let proposal_manager = Box::new(ProposalManager::new());
Batcher::new(
config,
storage_reader,
storage_writer,
mempool_client,
block_builder_factory,
proposal_manager,
)
Batcher::new(config, storage_reader, storage_writer, mempool_client, block_builder_factory)
}

#[cfg_attr(test, automock)]
Expand Down Expand Up @@ -436,23 +530,4 @@ impl BatcherStorageWriterTrait for papyrus_storage::StorageWriter {
}
}

impl From<GenerateProposalError> for BatcherError {
fn from(err: GenerateProposalError) -> Self {
match err {
GenerateProposalError::AlreadyGeneratingProposal {
current_generating_proposal_id,
new_proposal_id,
} => BatcherError::ServerBusy {
active_proposal_id: current_generating_proposal_id,
new_proposal_id,
},
GenerateProposalError::BlockBuilderError(..) => BatcherError::InternalError,
GenerateProposalError::NoActiveHeight => BatcherError::NoActiveHeight,
GenerateProposalError::ProposalAlreadyExists { proposal_id } => {
BatcherError::ProposalAlreadyExists { proposal_id }
}
}
}
}

impl ComponentStarter for Batcher {}
Loading

0 comments on commit cfc0dcc

Please sign in to comment.