diff --git a/buf.gen.yaml b/buf.gen.yaml index 3ba58f14be..6608c4e185 100644 --- a/buf.gen.yaml +++ b/buf.gen.yaml @@ -7,4 +7,4 @@ plugins: inputs: - git_repo: https://github.com/starknet-io/starknet-p2p-specs.git branch: bcfa353a169c859e4d5d97757caccbe76f75bc06 # Latest commit as of 2025 May 6th - depth: 1 \ No newline at end of file + depth: 1 diff --git a/consensus/consensus.go b/consensus/consensus.go index 4365e101d6..6aabd49126 100644 --- a/consensus/consensus.go +++ b/consensus/consensus.go @@ -52,7 +52,9 @@ func Init( } currentHeight := types.Height(chainHeight + 1) - tendermintDB := consensusDB.NewTendermintDB[starknet.Value, starknet.Hash, starknet.Address](database) + tendermintDB := consensusDB.NewTendermintDB[ + starknet.Value, starknet.Hash, starknet.Address, + ](database) executor := builder.NewExecutor(blockchain, vm, logger, false, false) builder := builder.New(blockchain, executor) diff --git a/consensus/p2p/buffered/proto_broadcaster.go b/consensus/p2p/buffered/proto_broadcaster.go index 42e527b20e..dc4e417258 100644 --- a/consensus/p2p/buffered/proto_broadcaster.go +++ b/consensus/p2p/buffered/proto_broadcaster.go @@ -57,7 +57,8 @@ func (b ProtoBroadcaster[M]) Loop(ctx context.Context, topic *pubsub.Topic) { } for { - if err := topic.Publish(ctx, msgBytes); err != nil && !errors.Is(err, context.Canceled) { + err := topic.Publish(ctx, msgBytes) + if err != nil && !errors.Is(err, context.Canceled) { b.logger.Error("unable to send message", zap.Error(err)) time.Sleep(b.retryInterval) continue @@ -70,7 +71,8 @@ func (b ProtoBroadcaster[M]) Loop(ctx context.Context, topic *pubsub.Topic) { } case <-rebroadcasted.trigger: for msgBytes := range rebroadcasted.messages { - if err := topic.Publish(ctx, msgBytes); err != nil && !errors.Is(err, context.Canceled) { + err := topic.Publish(ctx, msgBytes) + if err != nil && !errors.Is(err, context.Canceled) { b.logger.Error("unable to rebroadcast message", zap.Error(err)) } } diff --git 
a/consensus/p2p/validator/proposal_stream.go b/consensus/p2p/validator/proposal_stream.go index b519294704..b39a21e85b 100644 --- a/consensus/p2p/validator/proposal_stream.go +++ b/consensus/p2p/validator/proposal_stream.go @@ -48,7 +48,9 @@ func newSingleProposalStream( } } -func (s *proposalStream) start(ctx context.Context, firstMessage *consensus.StreamMessage) (types.Height, error) { +func (s *proposalStream) start( + ctx context.Context, firstMessage *consensus.StreamMessage, +) (types.Height, error) { content := firstMessage.GetContent() if content == nil { return 0, fmt.Errorf("first message has empty content") diff --git a/consensus/p2p/vote/vote_broadcasters.go b/consensus/p2p/vote/vote_broadcasters.go index 603ae12944..853aa8a02c 100644 --- a/consensus/p2p/vote/vote_broadcasters.go +++ b/consensus/p2p/vote/vote_broadcasters.go @@ -34,7 +34,9 @@ func NewVoteBroadcaster[H types.Hash, A types.Addr]( } } -func (b *voteBroadcaster[H, A]) broadcast(ctx context.Context, message *types.Vote[H, A], voteType consensus.Vote_VoteType) { +func (b *voteBroadcaster[H, A]) broadcast( + ctx context.Context, message *types.Vote[H, A], voteType consensus.Vote_VoteType, +) { msg, err := b.voteAdapter.FromVote(message, voteType) if err != nil { b.logger.Error("unable to convert vote", zap.Error(err)) @@ -60,6 +62,10 @@ func (b *prevoteBroadcaster[H, A]) Broadcast(ctx context.Context, message *types type precommitBroadcaster[H types.Hash, A types.Addr] voteBroadcaster[H, A] -func (b *precommitBroadcaster[H, A]) Broadcast(ctx context.Context, message *types.Precommit[H, A]) { - (*voteBroadcaster[H, A])(b).broadcast(ctx, (*types.Vote[H, A])(message), consensus.Vote_Precommit) +func (b *precommitBroadcaster[H, A]) Broadcast( + ctx context.Context, message *types.Precommit[H, A], +) { + (*voteBroadcaster[H, A])(b).broadcast( + ctx, (*types.Vote[H, A])(message), consensus.Vote_Precommit, + ) } diff --git a/consensus/propeller/engine.go b/consensus/propeller/engine.go new file mode 
100644 index 0000000000..0ffcfff119 --- /dev/null +++ b/consensus/propeller/engine.go @@ -0,0 +1,371 @@ +package propeller + +import ( + "context" + "fmt" + "time" + + "github.com/NethermindEth/juno/utils/log" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" + "go.uber.org/zap" +) + +type broadcastResult struct { + units []Unit + errCh chan<- error +} + +// todo(rdr): using String until I find a better type +type StakerID struct { + peerID peer.ID //nolint:unused // populated once committee key wiring lands. + pubKey crypto.PubKey //nolint:unused // populated once committee key wiring lands. +} + +// Holds the state for a Committee ID: +// - The `scheduler` represents the Propeller Tree of peers +// - The `processor` stores the state of this committee: when the built or receive threshold +// have been reached +// - The `peerKeys` I am not sure yet todo(rdr): <- +type committeeState struct { + scheduler *Scheduler + // todo(rdr): A look at processor shows that it's lifetime is strictly coupled with the + // state of a current committee. They both should be created and closed at the + // same time. If it is like this then it stands to reason that it should be coupled + // here. Not 100% sure right now, so leaving a big todo for now. + + // todo(rdr): why do we need this + peerKeys map[peer.ID]crypto.PubKey +} + +// engineCommand is a tagged union of commands sent to the engine's Run() loop. 
type engineCommand interface {
	// isCommand is a marker method: every implementation below is empty, it only
	// restricts which types can travel through the command channel.
	isCommand()
}

// registerCommittee asks the Run loop to register a committee. The result of the
// registration is sent back on errCh by handleCommand.
type registerCommittee struct {
	committeeID CommitteeID
	peers       []PeerCommittee
	peersKeys   []*StakerID
	errCh       chan error
}

func (registerCommittee) isCommand() {}

// unregisterCommittee asks the Run loop to drop the state of a committee. No result
// is reported back to the sender.
type unregisterCommittee struct {
	committeeID CommitteeID
}

func (unregisterCommittee) isCommand() {}

// broadcast asks the Run loop to build Propeller units from msg and broadcast them to
// the given committee. The outcome is reported on errCh.
type broadcast struct {
	committeeID CommitteeID
	msg         []byte
	errCh       chan<- error
}

func (broadcast) isCommand() {}

// processUnit hands a unit received from sender to the Run loop for processing.
type processUnit struct {
	unit   *Unit
	sender peer.ID
}

func (processUnit) isCommand() {}

// Engine is the central orchestrator of the Propeller protocol. It:
//
//   - Manages committee registrations (each committee has its own peer set and scheduler).
//   - Processes all incoming messages and broadcasts them when expected.
//   - Handles broadcast requests from the service layer.
//   - Forwards all noteworthy events to the service layer.
type Engine struct {
	privKey   crypto.PrivKey
	localPeer peer.ID

	config Config
	logger log.StructuredLogger

	// processor validates and processes all the messages received from other peers
	processor *Processor

	// committees holds the Scheduler (i.e. Propeller Tree) and Staker IDs of
	// the peers of each registered committee
	// todo(rdr): can committeeState be stored by value instead of by reference?
	committees map[CommitteeID]*committeeState

	// todo(rdr): not sure of this one yet
	// connectedPeers holds all the peers connected to the engine
	connectedPeers map[peer.ID]struct{}

	// whenever a broadcast action is started, the units are prepared concurrently
	// and delivered to the Run loop through this channel
	unitsPrepared chan broadcastResult

	// eventCh is shared between all processors and the engine. The engine
	// reads from it and forwards events to the application via Events().
	// todo(rdr): events are currently sent directly from the processor to the service,
	// does the engine need to do any filtering?
	// eventCh chan Event

	// cmdCh receives commands from the propeller service; the Run loop acts on them
	cmdCh chan engineCommand
}

// NewEngine creates an engine instance. It returns the engine, the send-only channel
// used to submit engineCommands to it, and the receive-only event channel returned by
// NewProcessor.
// Call Run() to start processing.
//
// Parameters:
//   - privKey: this node's Ed25519 private key (for signing published messages). The
//     local peer ID is derived from it.
//   - config: protocol parameters. It is dereferenced and copied into the engine, so it
//     must not be nil.
//   - logger: structured logger.
//
// NewEngine panics if the peer ID cannot be derived from privKey.
//
// todo(rdr): Maybe in the future we don't want to expose the command channel and instead hide
// the interaction behind a public API. :think:
func NewEngine(
	privKey crypto.PrivKey,
	config *Config,
	logger log.StructuredLogger,
) (*Engine, chan<- engineCommand, <-chan Event) {
	localPeerID, err := peer.IDFromPrivateKey(privKey)
	if err != nil {
		// todo(rdr): panic for now, error handling for later
		panic(err)
	}

	processor, eventsCh := NewProcessor(localPeerID, config)

	// Both cmdCh and unitsPrepared are unbuffered: senders block until the Run loop
	// receives from them.
	cmdCh := make(chan engineCommand)

	return &Engine{
		localPeer:     localPeerID,
		privKey:       privKey,
		config:        *config,
		logger:        logger,
		processor:     processor,
		committees:    make(map[CommitteeID]*committeeState),
		cmdCh:         cmdCh,
		unitsPrepared: make(chan broadcastResult),
		// Unsure of the fields below
		connectedPeers: make(map[peer.ID]struct{}),
	}, cmdCh, eventsCh
}

// registerCommittee builds the Scheduler for a new committee and stores the committee
// state. Registering an already known committee is logged and ignored (nil is returned).
//
//nolint:unparam // peersKeys is part of the public registration API; wiring is still pending.
+func (e *Engine) registerCommittee( + committeeID *CommitteeID, + peers []PeerCommittee, + peersKeys []*StakerID, +) error { + // todo(rdr): Why re-registration should be ignored, + // as far as I understand, it shouldn't happen :think: + if _, ok := e.committees[*committeeID]; ok { + e.logger.Warn( + "committee already registered, will ignore re-registration attempt", + // todo(rdr): give a proper string repr + zap.Any("committee id", committeeID), + ) + return nil + } + + schedule, err := NewScheduler(e.localPeer, peers) + if err != nil { + return fmt.Errorf("couldn't register a new committee: %w", err) + } + + e.committees[*committeeID] = &committeeState{ + scheduler: schedule, + // todo(rdr): need to add the peer pub keys + peerKeys: nil, + } + + e.logger.Info( + "registered new committee", + // todo(rdr): give a proper string representation + zap.Any("committeeID", committeeID), + zap.Int("peers", len(peers)), + zap.Int("dataShards", schedule.NumDataShards()), + zap.Int("codingShards", schedule.NumCodingShards()), + ) + + return nil +} + +// unregisterCommittee removes a channel's state. Not new processors will be started but +// currently running ones will continue until the timeout / stop naturally +func (e *Engine) unregisterCommittee(committeeID *CommitteeID) { + delete(e.committees, *committeeID) + // todo(rdr): We have to clean the processors, right? + // or will they shut down on their own eventually + // better to pass a context with cancel? + + e.logger.Info( + "unregistered propeller committee", + // todo(rdr): give a proper string representation + zap.Any("committee id", committeeID), + ) +} + +// prepareUnitsForBroadcast creates Proppeller units asynchronously since it is a very expensive +// operation. 
+func (e *Engine) prepareUnitsForBroadcast( + committeeID *CommitteeID, + data []byte, + errCh chan<- error, +) error { + cs, ok := e.committees[*committeeID] + if !ok { + return fmt.Errorf("cannot broadcast to an unregistered committee: %v", committeeID) + } + + // todo(rdr): unsure if this approach of passing arguments to the go routine makes sense + // todo(rdr): consider having a maximum amount of working threads and a queue tasks for this + // This is an expensive operation, hence we need to do it separately + go func(e *Engine, scheduler *Scheduler, committeeID CommitteeID, data []byte) { + units, err := CreatePropellerUnits( + e.privKey, + &committeeID, + // todo(rdr): Find how nonce is set when creating propeller units + Nonce(time.Now().UnixNano()), + data, + scheduler.NumDataShards(), + scheduler.NumCodingShards(), + ) + if err != nil { + errCh <- err + return + } + + // todo(rdr): Why do we send this back to the engine.Run thread instead of processing + // it right here? + e.unitsPrepared <- broadcastResult{ + units: units, + errCh: errCh, + } + }(e, cs.scheduler, *committeeID, data) + + return nil +} + +// broacast receives Propeller units (built in `prepareBroadcast`) and sends them +// +//nolint:unparam // ctx will be used once the actual sending is wired up. 
+func (e *Engine) broadcast(ctx context.Context, units []Unit) error { + targetCommittee := units[0].CommitteeID + + cs, ok := e.committees[targetCommittee] + if !ok { + return fmt.Errorf("target committee ID not found: %d", targetCommittee) + } + + targetPeers := cs.scheduler.BroadcastTargets() + if len(targetPeers) != len(units) { + return fmt.Errorf( + "different amount of target peers and propeller units to broadcast: %d vs %d", + len(targetPeers), + len(units), + ) + } + + // todo(rdr): I need to do the actual sending + // I need to pass to the eventCh all the units that it should receive + + return nil +} + +// processUnit routes an incoming unit to the correct processor, creating +// one if needed. +func (e *Engine) processUnit(ctx context.Context, unit *Unit, sender peer.ID) { + cs, ok := e.committees[unit.CommitteeID] + if !ok { + // note(rdr): maybe debug? + e.logger.Warn( + "received key for unregistered committee, dropping", + // todo(rdr): give a proper string representation + zap.Any("committee id", unit.CommitteeID), + ) + return + } + + err := e.processor.ProcessMessage(ctx, unit, sender, cs.scheduler) + if err != nil { + e.logger.Error("cannot process incoming unit", zap.Error(err)) + } +} + +// handleCommand dispatches a command to the appropriate handler. 
+func (e *Engine) handleCommand(ctx context.Context, command engineCommand) { + switch cmd := command.(type) { + case *registerCommittee: + err := e.registerCommittee(&cmd.committeeID, cmd.peers, cmd.peersKeys) + cmd.errCh <- err + case *unregisterCommittee: + e.unregisterCommittee(&cmd.committeeID) + case *broadcast: + // we might need to pass the error channel here so that the internal go-routine + // can forward it correctly (assuming a per command error channel) + if err := e.prepareUnitsForBroadcast(&cmd.committeeID, cmd.msg, cmd.errCh); err != nil { + cmd.errCh <- err + } + case *processUnit: + e.processUnit(ctx, cmd.unit, cmd.sender) + } +} + +// Run starts the engine's main loop until context is cancelled. +// The loop processes three things concurrently: +// 1. Commands from external callers (register, broadcast, handle incoming unit). +// 2. Events from message processors (forward to application). +// 3. Context cancellation (graceful shutdown). +func (e *Engine) Run(ctx context.Context) error { + for { + select { + case <-ctx.Done(): + return ctx.Err() + + case cmd := <-e.cmdCh: + e.handleCommand(ctx, cmd) + + case broadcastResult := <-e.unitsPrepared: + err := e.broadcast(ctx, broadcastResult.units) + broadcastResult.errCh <- err + } + } +} + +func (e *Engine) Broadcast(committeeID *CommitteeID, msg []byte) error { + // todo(rdr): check how costly is this? Is there a better way than creating a channel + errCh := make(chan error) + e.cmdCh <- &broadcast{ + committeeID: *committeeID, + msg: msg, + errCh: errCh, + } + return <-errCh +} + +func (e *Engine) RegisterCommittee( + committeeID *CommitteeID, + peers []PeerCommittee, + // todo(rdr): peersKeys is something I don't know how to set correctly yet + peersKeys []*StakerID, +) error { + // todo(rdr): does creating an error channel per call is performant or + // should we have a pool of err channels or that is too crazy :3 + // Thinking on the GC cost... 
+ errCh := make(chan error) + e.cmdCh <- &registerCommittee{ + committeeID: *committeeID, + peers: peers, + peersKeys: peersKeys, + errCh: errCh, + } + return <-errCh +} + +func (e *Engine) UnregisterCommittee(committeeID *CommitteeID) { + e.cmdCh <- &unregisterCommittee{committeeID: *committeeID} +} diff --git a/consensus/propeller/engine_test.go b/consensus/propeller/engine_test.go new file mode 100644 index 0000000000..66e9781a2d --- /dev/null +++ b/consensus/propeller/engine_test.go @@ -0,0 +1,372 @@ +package propeller_test + +// import ( +// "bytes" +// "context" +// "crypto/ed25519" +// "fmt" +// "sync" +// "testing" +// "time" +// +// "github.com/NethermindEth/juno/utils" +// "github.com/libp2p/go-libp2p/core/crypto" +// "github.com/libp2p/go-libp2p/core/peer" +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/require" +// ) +// +// engineTestEnv provides the common setup for engine-level tests. +// type engineTestEnv struct { +// peers []peer.ID +// privKeys []crypto.PrivKey +// engines []*Engine +// sentUnits map[peer.ID][]*Unit +// sentMu sync.Mutex +// log utils.Logger +// } +// +// //nolint:unparam // n is always 4 in current tests but kept for flexibility +// func newEngineTestEnv(t *testing.T, n int) *engineTestEnv { +// t.Helper() +// +// peers := make([]peer.ID, n) +// privKeys := make([]crypto.PrivKey, n) +// for i := range n { +// seed := make([]byte, ed25519.SeedSize) +// seed[0] = byte(i) +// reader := bytes.NewReader(seed) +// priv, pub, err := crypto.GenerateEd25519Key(reader) +// require.NoError(t, err) +// id, err := peer.IDFromPublicKey(pub) +// require.NoError(t, err) +// privKeys[i] = priv +// peers[i] = id +// } +// +// log := utils.NewNopZapLogger() +// +// env := &engineTestEnv{ +// peers: peers, +// privKeys: privKeys, +// sentUnits: make(map[peer.ID][]*Unit), +// log: log, +// } +// +// config := Config{ +// StaleMessageTimeout: 5 * time.Second, +// StreamProtocol: "/propeller/test/0.1.0", +// MaxWireMessageSize:
1 << 20, +// } +// +// engines := make([]*Engine, n) +// for i := range n { +// engines[i] = NewEngine( +// peers[i], privKeys[i], config, +// env.makeSendFn(), +// log, +// ) +// } +// env.engines = engines +// +// return env +// } +// +// // makeSendFn creates a SendUnitFunc that records sent units. +// func (env *engineTestEnv) makeSendFn() SendUnitFunc { +// return func(_ context.Context, to peer.ID, unit *Unit) error { +// env.sentMu.Lock() +// env.sentUnits[to] = append(env.sentUnits[to], unit) +// env.sentMu.Unlock() +// return nil +// } +// } +// +// // getSentUnits returns all units sent to a given peer. +// func (env *engineTestEnv) getSentUnits(to peer.ID) []*Unit { +// env.sentMu.Lock() +// defer env.sentMu.Unlock() +// result := make([]*Unit, len(env.sentUnits[to])) +// copy(result, env.sentUnits[to]) +// return result +// } +// +// func TestEngine_RegisterAndBroadcast(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// // Run the engine in the background. +// done := make(chan error, 1) +// go func() { +// done <- engine.Run(ctx) +// }() +// +// // Register a channel with all peers. +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// // Broadcast a message. +// msg := []byte("hello from engine test") +// err = engine.Broadcast(ctx, 1, msg) +// require.NoError(t, err) +// +// // Verify that units were sent to the other 3 peers. +// // Give a moment for async processing. 
+// time.Sleep(100 * time.Millisecond) +// +// totalSent := 0 +// for _, p := range env.peers { +// if p == env.peers[0] { +// continue +// } +// units := env.getSentUnits(p) +// totalSent += len(units) +// } +// assert.Equal(t, 3, totalSent, "should send one unit to each non-publisher peer") +// +// cancel() +// <-done +// } +// +// func TestEngine_BroadcastUnregisteredChannel(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// err := engine.Broadcast(ctx, 99, []byte("should fail")) +// require.Error(t, err) +// +// var pubErr *ShardPublishError +// require.ErrorAs(t, err, &pubErr) +// assert.Equal(t, ReasonChannelNotRegistered, pubErr.Reason) +// +// cancel() +// } +// +// func TestEngine_HandleUnit_CreatesProcessor(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 10*time.Second) +// defer cancel() +// +// // Set up engine for peer 0. +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// // Register the channel. +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// // Simulate receiving a unit from peer 1 (as publisher). +// schedule := NewScheduler(env.peers) +// enc, err := NewEncoder(schedule.NumDataShards(), schedule.NumCodingShards()) +// require.NoError(t, err) +// +// msg := []byte("incoming message") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// publisher := env.peers[1] +// sig, err := SignMessage(root, env.privKeys[1]) +// require.NoError(t, err) +// +// for i := range units { +// units[i].Publisher = publisher +// units[i].Signature = sig +// units[i].CommitteeID = 1 +// } +// +// // Send units from their correct senders. 
+// for i, unit := range units { +// sender, err := schedule.PeerForShard(publisher, ShardIndex(i)) +// require.NoError(t, err) +// +// // Skip units "from ourselves" -- the validator rejects those. +// if sender == env.peers[0] { +// continue +// } +// +// unitCopy := unit +// engine.HandleUnit(&unitCopy, sender) +// } +// +// // Wait for the message to be processed and check events. +// var received *EventMessageReceived +// deadline := time.After(5 * time.Second) +// for received == nil { +// select { +// case ev := <-engine.Events(): +// if r, ok := ev.(EventMessageReceived); ok { +// received = &r +// } +// case <-deadline: +// t.Fatal("timed out waiting for EventMessageReceived") +// } +// } +// +// assert.Equal(t, msg, received.Message) +// assert.Equal(t, publisher, received.Publisher) +// assert.Equal(t, root, received.Root) +// +// cancel() +// } +// +// func TestEngine_UnregisterChannel(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// err = engine.UnregisterChannel(ctx, 1) +// require.NoError(t, err) +// +// // Allow command to be processed. +// time.Sleep(50 * time.Millisecond) +// +// // Broadcasting should fail now. +// err = engine.Broadcast(ctx, 1, []byte("after unregister")) +// require.Error(t, err) +// +// cancel() +// } +// +// func TestEngine_HandleUnit_UnregisteredChannel(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// // Send a unit for an unregistered channel. 
+// unit := &Unit{ +// CommitteeID: 99, +// Publisher: env.peers[1], +// MessageRoot: MessageRoot{0x01}, +// ShardIndex: 0, +// ShardData: []byte("data"), +// } +// engine.HandleUnit(unit, env.peers[1]) +// +// // Allow time for processing. +// time.Sleep(100 * time.Millisecond) +// +// // No crash, no panic -- the unit is silently dropped. +// cancel() +// } +// +// func TestEngine_GracefulShutdown(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithCancel(t.Context()) +// +// engine := env.engines[0] +// done := make(chan error, 1) +// go func() { +// done <- engine.Run(ctx) +// }() +// +// cancel() +// +// select { +// case err := <-done: +// assert.ErrorIs(t, err, context.Canceled) +// case <-time.After(2 * time.Second): +// t.Fatal("engine did not shut down in time") +// } +// } +// +// func TestEngine_SendFailureEmitsEvent(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// // Create an engine with a failing send function. +// engine := NewEngine( +// env.peers[0], env.privKeys[0], +// Config{ +// StaleMessageTimeout: 5 * time.Second, +// StreamProtocol: "/propeller/test/0.1.0", +// MaxWireMessageSize: 1 << 20, +// }, +// func(_ context.Context, _ peer.ID, _ *Unit) error { +// return fmt.Errorf("simulated network failure") +// }, +// utils.NewNopZapLogger(), +// ) +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// err := engine.RegisterChannel(ctx, 1, env.peers) +// require.NoError(t, err) +// +// err = engine.Broadcast(ctx, 1, []byte("will fail sending")) +// require.NoError(t, err) // Broadcast itself succeeds; send failures are events. +// +// // Collect send failure events. 
+// deadline := time.After(2 * time.Second) +// failures := 0 +// loop: +// for failures < 3 { +// select { +// case ev := <-engine.Events(): +// if _, ok := ev.(EventShardSendFailed); ok { +// failures++ +// } +// case <-deadline: +// break loop +// } +// } +// assert.Equal(t, 3, failures, "should have 3 send failures (one per non-publisher peer)") +// +// cancel() +// } +// +// func TestEngine_RegisterChannelTooFewPeers(t *testing.T) { +// env := newEngineTestEnv(t, 4) +// +// ctx, cancel := context.WithTimeout(t.Context(), 5*time.Second) +// defer cancel() +// +// engine := env.engines[0] +// +// go func() { +// engine.Run(ctx) //nolint:errcheck // test helper +// }() +// +// // A single peer cannot form a channel (0 shards). +// err := engine.RegisterChannel(ctx, 1, []peer.ID{env.peers[0]}) +// require.Error(t, err) +// +// cancel() +// } diff --git a/consensus/propeller/merkle/merkle.go b/consensus/propeller/merkle/merkle.go new file mode 100644 index 0000000000..eb708d024a --- /dev/null +++ b/consensus/propeller/merkle/merkle.go @@ -0,0 +1,173 @@ +// Package merkle implements Merkle tree construction and verification using a +// SHA-256 tagging scheme. Tags prevent second-preimage attacks by +// domain-separating leaf hashes from internal node hashes. The exact tag +// format matches the Propeller protocol specification so that all +// implementations produce identical trees. +// +// Tree layout: leaves are at the bottom, padded to the next power-of-two +// with the hash of empty data. The tree is built bottom-up by hashing pairs. +package merkle + +import ( + "crypto/sha256" + "math/bits" +) + +const ( + leafOpenTag = "" + leafCloseTag = "" + nodeOpenTag = "" + nodeMidTag = "" + nodeCloseTag = "" +) + +// Pre-computed domain-separator tags to avoid repeated []byte conversions. 
+var ( + leafOpen = []byte(leafOpenTag) + leafClose = []byte(leafCloseTag) + nodeOpen = []byte(nodeOpenTag) + nodeMid = []byte(nodeMidTag) + nodeClose = []byte(nodeCloseTag) +) + +// emptyLeafHash is the hash of a padding leaf (no data). We precompute it +// because the same value is used repeatedly when the leaf count is not a +// power of two. +var emptyLeafHash = merkleLeafHash(nil) + +type Hash [32]byte + +// Proof contains the sibling hashes needed to verify that a leaf +// belongs to a Merkle tree with a known root. Siblings are ordered from +// leaf level (index 0) up to the root. +type Proof struct { + Siblings []Hash +} + +// Verify checks that a leaf at the given index is included in a tree with +// the claimed root. The proof contains sibling hashes from the leaf level +// up to the root. +// +// The index determines the path through the tree: at each level, if the +// current bit of the index is 0 the current hash is the left child and the +// sibling is the right child, and vice versa. +func (p *Proof) Verify(root *Hash, leaf []byte, index uint32) bool { + current := merkleLeafHash(leaf) + + idx := index + for i := range p.Siblings { + if idx%2 == 0 { + current = merkleNodeHash(¤t, &p.Siblings[i]) + } else { + current = merkleNodeHash(&p.Siblings[i], ¤t) + } + idx /= 2 + } + + return current == *root +} + +// Tree is a set of inclusion proofs, one per original leaf. +type Tree []Proof + +// New constructs a binary Merkle tree from the given leaf data +// and returns the root hash plus one inclusion proof per original leaf. +// +// The tree is padded to the next power-of-two size with empty leaves. This +// simplifies the proof logic: every node at every level has a sibling, and +// the proof path length is always log2(paddedSize). +// +// Returns a zero root and nil Tree if leaves is empty. 
+func New(leaves [][]byte) (root Hash, tree Tree) { + n := len(leaves) + if n == 0 { + // todo(rdr): maybe here we return a default merkle tree + return Hash{}, nil + } + + size := nextPowerOfTwo(n) + + // Build the bottom layer: hash each leaf, pad to power-of-two. + layer := make([]Hash, size) + for i := range n { + //nolint: gosec // Everything is inbouds here + layer[i] = merkleLeafHash(leaves[i]) + } + for i := n; i < size; i++ { + layer[i] = emptyLeafHash + } + + // proofSiblings[i] accumulates the sibling hashes for leaf i's proof. + // ancestors[i] tracks leaf i's ancestor position in the current layer. + proofSiblings := make([][]Hash, n) + ancestors := make([]int, n) + for j := range n { + ancestors[j] = j + } + + // Build the tree bottom-up, one level at a time. + for len(layer) > 1 { + nextLayer := make([]Hash, len(layer)/2) + for i := 0; i < len(layer); i += 2 { + nextLayer[i/2] = merkleNodeHash(&layer[i], &layer[i+1]) + } + for i := range n { + proofSiblings[i] = append(proofSiblings[i], layer[ancestors[i]^1]) + ancestors[i] /= 2 + } + layer = nextLayer + } + + root = layer[0] + + tree = make([]Proof, n) + for i := range n { + tree[i] = Proof{Siblings: proofSiblings[i]} + } + + return root, tree +} + +// merkleLeafHash computes: SHA256("" || data || "") +// +// The XML-like tags are the domain separator specified by the Propeller +// protocol. They ensure a leaf hash can never collide with a node hash, +// even if an attacker controls the data. +func merkleLeafHash(data []byte) Hash { + buf := make([]byte, len(leafOpenTag)+len(data)+len(leafCloseTag)) + + n := copy(buf, leafOpen) + n += copy(buf[n:], data) + copy(buf[n:], leafClose) + + return sha256.Sum256(buf) +} + +// merkleNodeHash computes: +// +// SHA256("" || left || "" || right || "") +// +// The nested tags ensure node hashes are in a separate domain from leaf hashes. 
+func merkleNodeHash(left, right *Hash) Hash { + const size = len(nodeOpenTag) + 32 + len(nodeMidTag) + 32 + len(nodeCloseTag) + var buf [size]byte + + n := copy(buf[:], nodeOpen) + n += copy(buf[n:], left[:]) + n += copy(buf[n:], nodeMid) + n += copy(buf[n:], right[:]) + copy(buf[n:], nodeClose) + + return sha256.Sum256(buf[:]) +} + +// nextPowerOfTwo returns the smallest power of two >= n, with a minimum of 2. +// A minimum of 2 ensures even a single-leaf tree has a sibling for its proof. +func nextPowerOfTwo(n int) int { + if n <= 2 { + return 2 + } + // bits.Len returns the position of the highest set bit + 1. + // Subtracting 1 before Len handles exact powers-of-two correctly. + return 1 << bits.Len(uint(n-1)) +} diff --git a/consensus/propeller/merkle/merkle_test.go b/consensus/propeller/merkle/merkle_test.go new file mode 100644 index 0000000000..eaf870cc2e --- /dev/null +++ b/consensus/propeller/merkle/merkle_test.go @@ -0,0 +1,113 @@ +package merkle_test + +import ( + "fmt" + "testing" + + "github.com/NethermindEth/juno/consensus/propeller/merkle" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func makeLeaves(n int) [][]byte { + leaves := make([][]byte, n) + for i := range n { + leaves[i] = fmt.Appendf(nil, "leaf-%d", i) + } + return leaves +} + +func TestNew_Empty(t *testing.T) { + root, proofs := merkle.New(nil) + assert.Equal(t, merkle.Hash{}, root) + assert.Nil(t, proofs) +} + +func TestNew_ProofsVerify(t *testing.T) { + for _, n := range []int{1, 2, 3, 4, 5, 8, 16, 31} { + t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { + leaves := makeLeaves(n) + root, proofs := merkle.New(leaves) + + require.Len(t, proofs, n) + assert.NotEqual(t, merkle.Hash{}, root) + + for i, leaf := range leaves { + assert.True(t, + proofs[i].Verify(&root, leaf, uint32(i)), + "proof for leaf %d should verify", i, + ) + } + }) + } +} + +func TestNew_WrongDataDoesNotVerify(t *testing.T) { + for _, n := range []int{1, 2, 3, 4, 5, 31} { + 
t.Run(fmt.Sprintf("n=%d", n), func(t *testing.T) { + leaves := makeLeaves(n) + root, proofs := merkle.New(leaves) + + for i := range leaves { + assert.False(t, + proofs[i].Verify(&root, []byte("tampered"), uint32(i)), + "tampered data should not verify for leaf %d", i, + ) + } + }) + } +} + +func TestVerify_Rejects(t *testing.T) { + leaves := [][]byte{[]byte("A"), []byte("B"), []byte("C"), []byte("D")} + root, proofs := merkle.New(leaves) + + t.Run("wrong index", func(t *testing.T) { + assert.False(t, proofs[0].Verify(&root, leaves[0], 1)) + }) + + t.Run("wrong root", func(t *testing.T) { + fakeRoot := merkle.Hash{0xff} + assert.False(t, proofs[0].Verify(&fakeRoot, leaves[0], 0)) + }) + + t.Run("tampered sibling", func(t *testing.T) { + badProof := merkle.Proof{Siblings: []merkle.Hash{{0xde, 0xad}}} + assert.False(t, badProof.Verify(&root, leaves[0], 0)) + }) +} + +func TestNew_Deterministic(t *testing.T) { + leaves := makeLeaves(7) + + root1, proofs1 := merkle.New(leaves) + root2, proofs2 := merkle.New(leaves) + + assert.Equal(t, root1, root2) + require.Len(t, proofs1, len(proofs2)) + for i := range proofs1 { + assert.Equal(t, proofs1[i], proofs2[i], "proof %d should be identical", i) + } +} + +func TestNew_DifferentLeavesDifferentRoots(t *testing.T) { + rootA, _ := merkle.New([][]byte{[]byte("A"), []byte("B")}) + rootB, _ := merkle.New([][]byte{[]byte("X"), []byte("Y")}) + + assert.NotEqual(t, rootA, rootB) +} + +func TestNew_CrossTreeIsolation(t *testing.T) { + leavesA := [][]byte{[]byte("A"), []byte("B"), []byte("C"), []byte("D")} + leavesB := [][]byte{[]byte("W"), []byte("X"), []byte("Y"), []byte("Z")} + + rootB, _ := merkle.New(leavesB) + _, proofsA := merkle.New(leavesA) + + for i, leaf := range leavesA { + assert.False(t, + proofsA[i].Verify(&rootB, leaf, uint32(i)), + "proof from tree A should not verify against tree B root", + ) + } +} diff --git a/consensus/propeller/padding.go b/consensus/propeller/padding.go new file mode 100644 index 
0000000000..3dca643e93
--- /dev/null
+++ b/consensus/propeller/padding.go
@@ -0,0 +1,59 @@
+package propeller
+
+import (
+	"encoding/binary"
+	"fmt"
+)
+
+// PadMessage prepends an unsigned varint-encoded length to the message and
+// pads the result with zeros so the total length is divisible by
+// 2*numDataShards. numDataShards must be positive: it is not validated here,
+// and zero makes the rounding step below divide by zero.
+//
+// The varint prefix lets the receiver recover the exact original message
+// length after reconstruction. The zero padding lets the result be split into
+// numDataShards equal-length pieces, as Reed-Solomon encoding requires.
+//
+// Layout: [varint(len(msg))] [msg bytes] [zero padding]
+func PadMessage(msg []byte, numDataShards int) []byte {
+	// Compute the varint-encoded length prefix.
+	var varintBuf [binary.MaxVarintLen64]byte
+	varintLen := binary.PutUvarint(varintBuf[:], uint64(len(msg)))
+
+	unpaddedMsgLen := uint64(varintLen + len(msg))
+
+	// Round up to the next multiple of divisor.
+	divisor := uint64(2 * numDataShards) // NOTE(review): a negative shard count wraps to a huge divisor
+	paddedMsgLen := unpaddedMsgLen
+	if remainder := paddedMsgLen % divisor; remainder != 0 {
+		paddedMsgLen += divisor - remainder
+	}
+
+	result := make([]byte, paddedMsgLen) // make zero-fills, so the padding tail needs no explicit write
+	copy(result, varintBuf[:varintLen])
+	copy(result[varintLen:], msg)
+
+	return result
+}
+
+// UnpadMessage performs the reverse operation to PadMessage: it reads the varint length prefix and
+// extracts the original message bytes, discarding the zero padding. The returned slice aliases the
+// input's backing array; it is re-sliced to start and end on the original message.
+//
+// An error is returned if the varint is malformed or the encoded length exceeds the available data.
+func UnpadMessage(padded []byte) ([]byte, error) {
+	msgLen, varintLen := binary.Uvarint(padded)
+	if varintLen <= 0 { // 0: buffer too small, negative: varint overflows 64 bits
+		return nil, fmt.Errorf("invalid varint prefix in padded message: %d", varintLen)
+	}
+
+	// Compare lengths, not end offsets: varintLen+msgLen can wrap around uint64 on hostile input.
+	if available := uint64(len(padded) - varintLen); msgLen > available {
+		return nil, fmt.Errorf(
+			"varint length %d exceeds available data (have %d bytes after prefix)",
+			msgLen, available,
+		)
+	}
+
+	return padded[varintLen : varintLen+int(msgLen)], nil // msgLen <= available, so int() is safe
+}
diff --git a/consensus/propeller/padding_test.go b/consensus/propeller/padding_test.go
new file mode 100644
index 0000000000..bfb6d0cb2b
--- /dev/null
+++ b/consensus/propeller/padding_test.go
@@ -0,0 +1,121 @@
+package propeller_test
+
+import (
+	"testing"
+
+	"github.com/NethermindEth/juno/consensus/propeller"
+	"github.com/stretchr/testify/assert"
+	"github.com/stretchr/testify/require"
+)
+
+func TestPadMessage_RoundTrip(t *testing.T) {
+	tests := []struct {
+		name          string
+		msg           []byte
+		numDataShards int
+	}{
+		{
+			name:          "empty message, 1 shard",
+			msg:           []byte{},
+			numDataShards: 1,
+		},
+		{
+			name:          "small message, 1 shard",
+			msg:           []byte("hello"),
+			numDataShards: 1,
+		},
+		{
+			name:          "small message, 3 shards",
+			msg:           []byte("hello world"),
+			numDataShards: 3,
+		},
+		{
+			name:          "message exactly divisible",
+			msg:           make([]byte, 7), // varint(7)=1 byte, total=8, divisor=2*1=2 -> no padding
+			numDataShards: 1,
+		},
+		{
+			name:          "larger message, 10 shards",
+			msg:           make([]byte, 1000),
+			numDataShards: 10,
+		},
+		{
+			name:          "single byte",
+			msg:           []byte{0x42},
+			numDataShards: 5,
+		},
+		{
+			name:          "large message requiring multi-byte varint",
+			msg:           makeSequentialBytes(300),
+			numDataShards: 4,
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			padded := propeller.PadMessage(tc.msg, tc.numDataShards)
+
+			// Verify divisibility.
+			divisor := 2 * tc.numDataShards
+			assert.Equal(t, 0, len(padded)%divisor,
+				"padded length %d should be divisible by %d", len(padded), divisor)
+
+			// Verify round-trip.
+			recovered, err := propeller.UnpadMessage(padded)
+			require.NoError(t, err)
+			assert.Equal(t, tc.msg, recovered)
+		})
+	}
+}
+
+func TestUnpadMessage_Errors(t *testing.T) {
+	tests := []struct {
+		name    string
+		input   []byte
+		wantErr string
+	}{
+		{
+			name:    "empty buffer",
+			input:   []byte{},
+			wantErr: "invalid varint",
+		},
+		{
+			name:    "truncated varint",
+			input:   []byte{0x80}, // continuation bit set, no following byte
+			wantErr: "invalid varint",
+		},
+		{
+			name:    "length exceeds data",
+			input:   append([]byte{100}, []byte("short")...), // varint 100, only 5 bytes
+			wantErr: "exceeds available data",
+		},
+		{
+			// varint(MaxUint64): adding the prefix size to this length wraps around uint64.
+			name:    "length overflows uint64",
+			input:   []byte{0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0xff, 0x01},
+			wantErr: "exceeds available data",
+		},
+	}
+
+	for _, tc := range tests {
+		t.Run(tc.name, func(t *testing.T) {
+			_, err := propeller.UnpadMessage(tc.input)
+			require.ErrorContains(t, err, tc.wantErr)
+		})
+	}
+}
+
+func TestPadMessage_Size(t *testing.T) {
+	msg := []byte("ab")                    // 2 bytes
+	padded := propeller.PadMessage(msg, 3) // divisor = 6
+	// varint(2) = 1 byte, payload = 3 bytes, next multiple of 6 = 6
+	require.Len(t, padded, 6)
+}
+
+func makeSequentialBytes(n int) []byte {
+	b := make([]byte, n)
+	for i := range b {
+		b[i] = byte(i)
+	}
+	return b
+}
diff --git a/consensus/propeller/processor.go b/consensus/propeller/processor.go
new file mode 100644
index 0000000000..713d99237e
--- /dev/null
+++ b/consensus/propeller/processor.go
@@ -0,0 +1,515 @@
+package propeller
+
+import (
+	"context"
+	"errors"
+	"fmt"
+	"math/rand"
+	"sync"
+	"time"
+
+	"github.com/NethermindEth/juno/consensus/propeller/timecache"
+	"github.com/NethermindEth/juno/utils/log"
+	"github.com/libp2p/go-libp2p/core/peer"
+	"go.uber.org/zap"
+)
+
+type Event interface {
+	isEvent()
+}
+
+type messageFinalized struct {
+	message []byte
+}
+
+func (*messageFinalized) isEvent() {}
+
+type broadcastUnit struct {
+	unit  *Unit
+	peers []peer.ID
+}
+
+func
(*broadcastUnit) isEvent() {}
+
+type broadcastMessage struct {
+	unit []Unit
+}
+
+func (*broadcastMessage) isEvent() {}
+
+type unitWithSender struct {
+	unit   *Unit
+	sender peer.ID
+}
+
+type subprocessor struct {
+	scheduler       *Scheduler
+	localPeer       peer.ID
+	localShardIndex ShardIndex
+
+	unitsChan        <-chan unitWithSender
+	invalidUnitsChan chan<- invalidUnit
+	processingEvents chan<- Event // NOTE(review): not set by newSubprocessor; the caller must assign it
+
+	validator UnitValidator
+}
+
+func newSubprocessor(
+	publisher peer.ID,
+	scheduler *Scheduler,
+	localPeer peer.ID,
+	localShardIndex ShardIndex,
+	unitsChan <-chan unitWithSender,
+	invalidUnitsChan chan<- invalidUnit,
+) subprocessor {
+	return subprocessor{
+		scheduler:       scheduler,
+		localPeer:       localPeer,
+		localShardIndex: localShardIndex,
+
+		unitsChan:        unitsChan,
+		invalidUnitsChan: invalidUnitsChan,
+
+		validator: NewValidator(publisher, scheduler),
+	}
+}
+
+func (s *subprocessor) broadcastUnit(unit *Unit) {
+	committee := s.scheduler.Peers()
+	peers := make([]peer.ID, 0, len(committee))
+	for _, peerCommittee := range committee {
+		if peerCommittee.ID == unit.Publisher || peerCommittee.ID == s.localPeer {
+			continue
+		}
+		// Append instead of indexing into a len-2 slice: that size is wrong (or negative) when
+		// the committee is tiny, lacks one of the two peers, or publisher == local peer.
+		peers = append(peers, peerCommittee.ID)
+	}
+	rand.Shuffle(len(peers), func(i, j int) {
+		peers[i], peers[j] = peers[j], peers[i]
+	})
+
+	s.processingEvents <- &broadcastUnit{ // a send on a nil channel blocks forever, see field note
+		unit:  unit,
+		peers: peers,
+	}
+}
+
+func (s *subprocessor) beforeMessageBuiltStage(ctx context.Context) (
+	int, []byte, error,
+) {
+	// Keep track of the units received
+	unitsReceived := make([]*Unit, s.scheduler.NumTotalShards())
+	unitCount := 0
+
+	localShardWasBroadcast := false
+
+	// todo(rdr): message building (expensive) is triggered as soon as the build threshold is
+	// reached. Waiting a little longer for more units might make recovery cheaper, but every
+	// extra unit must be validated too. Open question: does validating the missing units cost
+	// less than recovering them? Cases to consider (write a benchmark):
+	// - Perfect network: is receiving and validating every missing unit cheaper than recovery?
+	// - Bad network: waiting yields no units and only delays a build that could already run.
+	// - Bad units: the awaited units fail validation, so we pay for validating them and still
+	//   have to recover everything.
+	for unitCount < s.scheduler.BuildThreshold() {
+		select {
+		case <-ctx.Done():
+			return 0, nil, ctx.Err()
+		case unitWithSender := <-s.unitsChan:
+			unit := unitWithSender.unit
+			sender := unitWithSender.sender
+			if err := s.validator.Validate(unit, sender); err != nil {
+				s.invalidUnitsChan <- invalidUnit{
+					// todo(rdr): not sure if we need message key.
+					// We just want to penalize the sender
+					messageKey: extractKey(unit),
+					sender:     sender,
+					error:      err,
+				}
+				// if this is the first unit we are receiving, finish abruptly since
+				// it can be a DOS attack.
+				if unitCount == 0 {
+					return 0, nil, fmt.Errorf("couldn't validate first unit received: %w", err)
+				}
+				continue
+			}
+
+			// Defensive: a shard delivered twice must count once towards the build threshold.
+			if unitsReceived[int(unit.ShardIndex)] != nil { // assumes Validate bounds ShardIndex — TODO confirm
+				continue
+			}
+			unitsReceived[int(unit.ShardIndex)] = unit
+			unitCount += 1
+
+			// broadcast as soon as I get my shard
+			if !localShardWasBroadcast && s.localShardIndex == unit.ShardIndex {
+				localShardWasBroadcast = true
+				s.broadcastUnit(unit)
+			}
+		}
+	}
+
+	fullMessage, localShardData, localProof, err := ConstructMessageFromUnits(
+		unitsReceived,
+		s.localShardIndex,
+		s.scheduler.NumDataShards(),
+		s.scheduler.NumCodingShards(),
+	)
+	if err != nil {
+		return 0, nil, err
+	}
+
+	if !localShardWasBroadcast {
+		// Copy the message-wide fields from any unit that did arrive; they were all validated
+		// above. unitsReceived holds nil for shards never seen, so index 0 is not safe to use.
+		var unit *Unit
+		for _, received := range unitsReceived {
+			if received != nil {
+				unit = received
+				break
+			}
+		}
+		localUnit := Unit{
+			CommitteeID: unit.CommitteeID,
+			Publisher:   unit.Publisher,
+			MessageRoot: unit.MessageRoot,
+			Nonce:       unit.Nonce,
+			Signature:   unit.Signature,
+			MerkleProof: localProof,
+			ShardIndex:  s.localShardIndex,
+			ShardData:   localShardData,
+		}
+		s.broadcastUnit(&localUnit)
+		unitCount += 1
+	}
+
+	return unitCount, fullMessage, nil
+}
+
+//nolint:unparam // message will be used once the receive-stage forwarding is wired up.
+func (s *subprocessor) beforeMessageReceivedStage(
+	ctx context.Context, unitCount int, message []byte,
+) error {
+	receiveThreshold := s.scheduler.ReceiveThreshold()
+	for unitCount < receiveThreshold { // "<": the build stage may already return more than this
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case unitWithSender := <-s.unitsChan:
+			unit := unitWithSender.unit
+			sender := unitWithSender.sender
+			if err := s.validator.Validate(unit, sender); err != nil {
+				s.invalidUnitsChan <- invalidUnit{
+					messageKey: extractKey(unit),
+					sender:     sender,
+					error:      err,
+				}
+				continue
+			}
+			if unit.ShardIndex == s.localShardIndex {
+				continue
+			}
+			unitCount += 1
+		}
+	}
+
+	// todo(rdr): the message has been received here; forward it to one of proc/engine/service.
+
+	return nil
+}
+
+// todo(rdr): we need to be sure to test both cases:
+// - when built threshold == received threshold
+// - when build threshold != received threshold
+func (s *subprocessor) Run(ctx context.Context) error {
+	// The Run function works in two main loops depending on the stage we are in.
+	// First stage is before we can build the message, in which we receive units
+	// until we have enough to build the full message. The local shard is broadcast
+	// during this stage.
+	// Second stage starts with the full message built and waits until we receive enough
+	// units to reach the receive threshold, which guarantees that at least 2/3 of the
+	// network is non faulty. Once there, the rebuilt message is to be forwarded and Run finishes.
+
+	unitCount, message, err := s.beforeMessageBuiltStage(ctx)
+	if err != nil {
+		return err
+	}
+
+	return s.beforeMessageReceivedStage(ctx, unitCount, message)
+}
+
+// messageKey is a copy of the fields of a propeller unit that identify its message:
+// all units carrying shards of the same message share the same "key" fields.
+type messageKey struct {
+	CommitteeID CommitteeID
+	Publisher   peer.ID
+	Root        MessageRoot
+	Nonce       Nonce
+}
+
+func extractKey(unit *Unit) messageKey { // copies the four key fields out of unit
+	return messageKey{
+		CommitteeID: unit.CommitteeID,
+		Publisher:   unit.Publisher,
+		Root:        unit.MessageRoot,
+		Nonce:       unit.Nonce,
+	}
+}
+
+func (mk *messageKey) String() string { // "%+v" rendering, used as a log field value
+	return fmt.Sprintf("%+v", *mk)
+}
+
+// invalidUnit is sent when a unit identified with `messageKey`, received from `sender`,
+// failed validation with error `error`.
+type invalidUnit struct {
+	messageKey messageKey
+	sender     peer.ID
+	error      error
+}
+
+// finalizedSubprocessor is sent once a subprocessor finishes processing the message
+// identified with `messageKey`. If it finished on error, the `error` field is non-nil.
+type finalizedSubprocessor struct {
+	messageKey messageKey
+	error      error
+}
+
+type concurrentTasksBounds struct {
+	maxWorkers             uint64
+	maxWorkersPerPublisher uint64
+}
+
+// Processor handles all concurrent work on message processing.
+type Processor struct {
+	// to avoid processing units already finalized
+	finalized *timecache.TimeCache[messageKey]
+
+	subProcessors map[messageKey]chan<- unitWithSender // NOTE(review): accessed without mu; race if ProcessMessage and Run use different goroutines — confirm
+	// channel through which subprocessors signal they have finalized execution
+	subProcessorsFinalized chan finalizedSubprocessor
+	// channel through which subprocessors share units that failed validation
+	invalidUnits chan invalidUnit
+	// channel through which important events are shared
+	processingEvents chan<- Event
+
+	// mu guards the task counters below, which bound open tasks to avoid resource starvation
+	mu             sync.Mutex
+	publisherTasks map[peer.ID]uint64
+	tasks          uint64
+	// config inherited from Engine
+	localPeer             peer.ID
+	timeout               time.Duration
+	concurrentTasksBounds concurrentTasksBounds
+	logger                log.StructuredLogger // NOTE(review): NewProcessor never assigns this field, yet Run logs through it
+}
+
+// finalizedCacheSize bounds the number of recently-finalized message keys retained
+// to avoid re-processing units belonging to messages already completed.
+const finalizedCacheSize = 2048
+
+func NewProcessor(localPeer peer.ID, config *Config) (*Processor, <-chan Event) {
+	timeout := config.StaleMessageTimeout
+	processingEvents := make(chan Event) // returned receive-only to the caller, kept send-only here
+
+	return &Processor{
+		finalized: timecache.New[messageKey](finalizedCacheSize, timeout),
+
+		subProcessors:          make(map[messageKey]chan<- unitWithSender),
+		subProcessorsFinalized: make(chan finalizedSubprocessor),
+		invalidUnits:           make(chan invalidUnit),
+		processingEvents:       processingEvents,
+
+		mu:             sync.Mutex{},
+		publisherTasks: make(map[peer.ID]uint64),
+		tasks:          0,
+
+		localPeer: localPeer,
+		timeout:   timeout,
+		// todo(rdr): set these based on the config (or some consts?)
+		concurrentTasksBounds: concurrentTasksBounds{
+			// dummy values for now
+			maxWorkers:             1000,
+			maxWorkersPerPublisher: 250,
+		},
+	}, processingEvents
+}
+
+func (p *Processor) Run(ctx context.Context) { // event loop; returns once ctx is cancelled
+	for {
+		select {
+		case <-ctx.Done():
+			return
+		case finalizedSubP := <-p.subProcessorsFinalized:
+			if finalizedSubP.error != nil {
+				p.logger.Error(
+					"subprocessor finalized with error",
+					zap.String("message key", finalizedSubP.messageKey.String()),
+					zap.Error(finalizedSubP.error),
+				)
+			} else {
+				p.logger.Info(
+					"subprocessor finalized",
+					zap.String("message key", finalizedSubP.messageKey.String()),
+				)
+			}
+			p.finalize(&finalizedSubP.messageKey)
+
+		case invalidUnit := <-p.invalidUnits:
+			p.logger.Error(
+				"unit validation failed",
+				zap.String("message key", invalidUnit.messageKey.String()),
+				zap.Error(invalidUnit.error),
+			)
+			// todo(rdr): should we mark sender to penalize?
+		}
+	}
+}
+
+// ProcessMessage hands the received `unit` to the subprocessor of its message without blocking,
+// creating that subprocessor on first use. It returns an error if the unit had to be dropped.
+func (p *Processor) ProcessMessage(
+	ctx context.Context,
+	unit *Unit,
+	sender peer.ID,
+	scheduler *Scheduler,
+) error {
+	key := extractKey(unit)
+	if p.finalized.Get(&key) { // message already finalized: silently ignore late units
+		return nil
+	}
+
+	// todo(rdr): currently validation and processing of a unit happen on a single goroutine.
+	// This could be divided into:
+	// - A validation task that performs validation (go routine A)
+	// - A processing task that processes the message (go routine B)
+	// - Then A will send the correct units to B
+	// This means that when many messages are received in quick succession, they can be validated
+	// without blocking. It also means two goroutines per subprocessor rather than just
+	// a single one.
+	unitChan, err := p.subprocessorChannel(ctx, &key, scheduler)
+	if err != nil {
+		return fmt.Errorf("couldn't get processor channel for key: %w", err)
+	}
+
+	select {
+	case unitChan <- unitWithSender{unit: unit, sender: sender}:
+		return nil
+	default:
+	}
+
+	return errors.New("dropping shard, processor channel full")
+}
+
+// createSubprocessor creates a go-routine (subprocessor) that handles all the processing of the
+// messages identified with the given `messageKey`.
+// It returns a channel through which this processor can be given units to process
+// todo(rdr): I would like not to create a channel for everytime we have a different messageKey
+// since that can be a bit rough to the GC, better to have a pool of them. Benchmarks will give
+// the final word
+func (p *Processor) createSubprocessor(
+	ctx context.Context,
+	key *messageKey,
+	scheduler *Scheduler,
+) (chan<- unitWithSender, error) {
+	localShardIndex, err := scheduler.ShardIndexForPublisher(key.Publisher)
+	if err != nil {
+		return nil, fmt.Errorf(
+			"cannot get local shard index for publisher %s: %w", key.Publisher, err,
+		)
+	}
+
+	err = p.increaseTasks(key.Publisher)
+	if err != nil {
+		return nil, err
+	}
+
+	// Buffer one slot per shard: ProcessMessage sends without blocking, so with an unbuffered
+	// channel any unit arriving before the subprocessor is parked on receive would be dropped.
+	unitChan := make(chan unitWithSender, scheduler.NumTotalShards())
+	p.subProcessors[*key] = unitChan
+
+	// launch subprocessor
+	ctxWithTimeout, cancel := context.WithTimeout(ctx, p.timeout)
+	// todo(rdr): arguments are passed explicitly to avoid closure captures. Check whether that
+	// makes any performance difference in Go.
+	// todo(rdr): should I pass p.chan as an argument?
+	go func(
+		ctx context.Context,
+		key messageKey,
+		scheduler *Scheduler,
+		localShardIndex ShardIndex,
+		unitChan <-chan unitWithSender,
+	) {
+		defer cancel()
+		subProcessor := newSubprocessor(
+			key.Publisher, scheduler, p.localPeer, localShardIndex, unitChan, p.invalidUnits,
+		)
+		subProcessor.processingEvents = p.processingEvents // otherwise broadcasts hit a nil channel
+		err := subProcessor.Run(ctx)
+		p.subProcessorsFinalized <- finalizedSubprocessor{
+			messageKey: key,
+			error:      err,
+		}
+	}(ctxWithTimeout, *key, scheduler, localShardIndex, unitChan)
+
+	return unitChan, nil
+}
+
+// Given a message key it returns a channel that communicates with the subprocessor
+// handling this specific message key, creating the subprocessor if none exists yet.
+func (p *Processor) subprocessorChannel(
+	ctx context.Context,
+	key *messageKey,
+	scheduler *Scheduler,
+) (chan<- unitWithSender, error) {
+	unitChan, ok := p.subProcessors[*key]
+	if ok {
+		return unitChan, nil
+	}
+
+	unitChan, err := p.createSubprocessor(ctx, key, scheduler)
+	if err != nil {
+		return nil, fmt.Errorf("creating new subprocessor: %w", err)
+	}
+	return unitChan, nil
+}
+
+func (p *Processor) finalize(key *messageKey) { // releases the task slot and remembers the key
+	p.decreaseTask(key.Publisher)
+	delete(p.subProcessors, *key)
+	p.finalized.Add(key)
+}
+
+func (p *Processor) increaseTasks(publisher peer.ID) error {
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	if p.publisherTasks[publisher] >= p.concurrentTasksBounds.maxWorkersPerPublisher {
+		return fmt.Errorf(
+			"tasks per publisher exceeded (max: %d): %s",
+			p.concurrentTasksBounds.maxWorkersPerPublisher, publisher,
+		)
+	}
+
+	if p.tasks >= p.concurrentTasksBounds.maxWorkers {
+		return fmt.Errorf(
+			"max tasks that the processor can handle has been reached (max: %d)",
+			p.concurrentTasksBounds.maxWorkers,
+		)
+	}
+
+	p.publisherTasks[publisher] += 1
+	p.tasks += 1
+
+	return nil
+}
+
+func (p *Processor) decreaseTask(publisher peer.ID) { // must pair with a successful increaseTasks
+	p.mu.Lock()
+	defer p.mu.Unlock()
+
+	p.publisherTasks[publisher] -= 1
+	p.tasks -= 1
+}
diff --git a/consensus/propeller/processor_test.go b/consensus/propeller/processor_test.go
new file mode 100644
index 0000000000..bb61865bd4
---
/dev/null
+++ b/consensus/propeller/processor_test.go
@@ -0,0 +1 @@
+package propeller_test
diff --git a/consensus/propeller/propeller.go b/consensus/propeller/propeller.go
new file mode 100644
index 0000000000..8e38c89dbf
--- /dev/null
+++ b/consensus/propeller/propeller.go
@@ -0,0 +1,212 @@
+package propeller
+
+import (
+	"bytes"
+	"context"
+	"io"
+
+	pb "github.com/NethermindEth/juno/consensus/propeller/proto"
+	"github.com/NethermindEth/juno/utils/log"
+	"github.com/libp2p/go-libp2p/core/crypto"
+	"github.com/libp2p/go-libp2p/core/host"
+	"github.com/libp2p/go-libp2p/core/network"
+	"github.com/libp2p/go-libp2p/core/peer"
+	"go.uber.org/zap"
+	"google.golang.org/protobuf/proto"
+)
+
+// This would represent the propeller service that glues the whole
+// thing to p2p. Thing is, I've no clue how to do that.
+type Service any
+
+type propellerService struct {
+	// P2P config
+	host host.Host
+	// Internal config
+	config Config
+	logger log.Logger
+	// Propeller communication
+	engine   *Engine
+	cmdCh    chan<- engineCommand
+	eventsCh <-chan Event
+	// External communication
+	messageRecv chan []byte
+}
+
+func New(
+	host host.Host,
+	privKey crypto.PrivKey,
+	config *Config,
+	logger log.Logger,
+) Service {
+	engine, cmdCh, eventsCh := NewEngine(
+		privKey,
+		config,
+		logger,
+	)
+
+	return &propellerService{
+		host:        host,
+		engine:      engine,
+		cmdCh:       cmdCh,
+		eventsCh:    eventsCh,
+		config:      *config,
+		logger:      logger,
+		messageRecv: make(chan []byte), // Recv hands this out; a nil channel would block forever
+	}
+}
+
+// receiveUnits is the stream handler for inbound propeller batches: it reads one size-limited
+// batch from the stream, decodes it and forwards every well-formed unit to the engine.
+func (s *propellerService) receiveUnits(stream network.Stream) {
+	defer stream.Close()
+
+	sender := stream.Conn().RemotePeer()
+
+	reader := io.LimitReader(stream, int64(s.config.MaxWireMessageSize))
+
+	var buf bytes.Buffer
+	_, err := buf.ReadFrom(reader)
+	if err != nil {
+		s.logger.Debug(
+			"error reading inbound propeller stream",
+			zap.Stringer("peer", sender),
+			zap.Error(err),
+		)
+		return // a partially read batch cannot be decoded reliably
+	}
+
+	var batch pb.PropellerUnitBatch
+	err = proto.Unmarshal(buf.Bytes(), &batch)
+	if err != nil {
+		s.logger.Debug(
+			"error unmarshalling propeller batch",
+			zap.Stringer("peer", sender),
+			zap.Error(err),
+		)
+		return // never act on a batch that failed to decode
+	}
+
+	for _, protoUnit := range batch.GetBatch() {
+		unit, err := UnitFromProto(protoUnit)
+		if err != nil {
+			s.logger.Warn("received invalid unit", zap.Error(err))
+			// todo(rdr): penalize sender?
+			// If we do it here then it means it shouldn't be handled at
+			// subP or Processor level, and all should be handled here,
+			// which means, every invalid unit should be handled at Service
+			// level. To be determined yet.
+			continue
+		}
+		// send unit to engine
+		s.cmdCh <- processUnit{ // NOTE(review): blocks this handler until the engine accepts the unit
+			&unit,
+			sender,
+		}
+	}
+}
+
+// broadcastUnit sends `unit`, wrapped in a single-element batch, to each of `peers` in turn.
+// Failures are logged and skipped so one unreachable peer cannot stop the broadcast.
+func (s *propellerService) broadcastUnit(ctx context.Context, unit *Unit, peers []peer.ID) {
+	batch := &pb.PropellerUnitBatch{
+		Batch: []*pb.PropellerUnit{unit.ToProto()},
+	}
+	data, err := proto.Marshal(batch)
+	if err != nil {
+		// Marshalling only fails on an internal bug; drop the unit instead of crashing the node.
+		s.logger.Error("unable to marshal propeller unit batch", zap.Error(err))
+		return
+	}
+
+	for _, p := range peers {
+		if err := s.sendToPeer(ctx, p, data); err != nil {
+			// An unreachable peer is routine on a p2p network: report it and keep going.
+			s.logger.Warn(
+				"unable to send propeller unit", zap.Stringer("peer", p), zap.Error(err),
+			)
+		}
+	}
+}
+
+// sendToPeer opens a new stream to `p` on the propeller protocol and writes `data` to it.
+func (s *propellerService) sendToPeer(ctx context.Context, p peer.ID, data []byte) error {
+	stream, err := s.host.NewStream(ctx, p, s.config.StreamProtocol)
+	if err != nil {
+		return err
+	}
+	defer stream.Close()
+
+	_, err = stream.Write(data)
+	return err
+}
+
+func (s *propellerService) broadcastMessage(ctx context.Context, units []Unit) {
+}
+
+// handleEvent dispatches a single engine event to its handler.
+func (s *propellerService) handleEvent(ctx context.Context, event Event) {
+	switch event := event.(type) {
+	case *messageFinalized:
+		// Deliver to the Recv consumer, but give up once the service is shutting down.
+		select {
+		case s.messageRecv <- event.message:
+		case <-ctx.Done():
+		}
+	case *broadcastUnit:
+		s.broadcastUnit(ctx, event.unit, event.peers)
+	case *broadcastMessage:
+		s.broadcastMessage(ctx, event.unit)
+	}
+}
+
+func (s *propellerService) Run(ctx context.Context) error {
+	// Start engine service in the background
+	go func() {
+		err
:= s.engine.Run(ctx)
+		if err != nil {
+			s.logger.Error("shutting down propeller engine", zap.Error(err))
+			return // NOTE(review): the error is only logged; the event loop below keeps running
+		}
+		s.logger.Info("shutting down propeller engine")
+	}()
+
+	// Register the inbound stream handler for the propeller protocol; removed again on return
+	s.host.SetStreamHandler(s.config.StreamProtocol, s.receiveUnits)
+	defer s.host.RemoveStreamHandler(s.config.StreamProtocol)
+
+	// Handle Engine outputs
+	for {
+		// todo(rdr): handle the engine's output such as units to broadcast
+		select {
+		case <-ctx.Done():
+			return ctx.Err()
+		case event := <-s.eventsCh:
+			s.handleEvent(ctx, event)
+		}
+	}
+}
+
+// todo(rdr): I am not sure of the Propeller <-> Engine separation...
+// TBD what it looks like or if there should be any in the future
+
+func (s *propellerService) Broadcast(committeeID *CommitteeID, msg []byte) error {
+	return s.engine.Broadcast(committeeID, msg) // thin pass-through to the engine
+}
+
+func (s *propellerService) Recv() <-chan []byte { // receive-only view of messageRecv
+	return s.messageRecv
+}
+
+func (s *propellerService) RegisterCommittee(
+	committeeID *CommitteeID,
+	peers []PeerCommittee,
+	// todo(rdr): peersKeys is something I don't know how to set correctly yet
+	peersKeys []*StakerID,
+) error {
+	return s.engine.RegisterCommittee(committeeID, peers, peersKeys)
+}
+
+func (s *propellerService) UnregisterCommittee(comitteeID *CommitteeID) {
+	s.engine.UnregisterCommittee(comitteeID)
+}
diff --git a/consensus/propeller/propeller_test.go b/consensus/propeller/propeller_test.go
new file mode 100644
index 0000000000..bb61865bd4
--- /dev/null
+++ b/consensus/propeller/propeller_test.go
@@ -0,0 +1 @@
+package propeller_test
diff --git a/consensus/propeller/proto/README.md b/consensus/propeller/proto/README.md
new file mode 100644
index 0000000000..8daac27a4c
--- /dev/null
+++ b/consensus/propeller/proto/README.md
@@ -0,0 +1,25 @@
+# Generating Go code from propeller.proto
+
+The `propeller.proto` file imports `p2p/proto/common.proto` from the upstream
+[starknet-p2p-specs](https://github.com/starknet-io/starknet-p2p-specs) repository.
+Since the upstream module is not on the Buf Schema Registry, we use `buf export` + `protoc` directly. + +From the project root: + +```bash +# 1. Export upstream .proto sources (needed for import resolution) +buf export \ + "https://github.com/starknet-io/starknet-p2p-specs.git#branch=bcfa353a169c859e4d5d97757caccbe76f75bc06,depth=1" \ + -o /tmp/starknet-p2p-specs-proto + +# 2. Generate Go code +protoc \ + --go_out=. \ + --go_opt=paths=source_relative \ + --go_opt=Mp2p/proto/common.proto=github.com/starknet-io/starknet-p2p-specs/p2p/proto/common \ + -I /tmp/starknet-p2p-specs-proto \ + -I . \ + consensus/propeller/proto/propeller.proto +``` + +This produces `propeller.pb.go` in this directory. diff --git a/consensus/propeller/proto/propeller.pb.go b/consensus/propeller/proto/propeller.pb.go new file mode 100644 index 0000000000..fdf371adb1 --- /dev/null +++ b/consensus/propeller/proto/propeller.pb.go @@ -0,0 +1,406 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// versions: +// protoc-gen-go v1.36.9 +// protoc v7.34.1 +// source: consensus/propeller/proto/propeller.proto + +package proto + +import ( + common "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" + reflect "reflect" + sync "sync" + unsafe "unsafe" +) + +const ( + // Verify that this generated code is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(20 - protoimpl.MinVersion) + // Verify that runtime/protoimpl is sufficiently up-to-date. + _ = protoimpl.EnforceVersion(protoimpl.MaxVersion - 20) +) + +// A Merkle proof consisting of sibling hashes used to verify that a leaf belongs to a Merkle tree. +// Each sibling hash is 32 bytes (SHA-256). The siblings are ordered from leaf level to root level. +type MerkleProof struct { + state protoimpl.MessageState `protogen:"open.v1"` + // The sibling hashes needed to reconstruct the path from the leaf to the root. 
+ // Each hash is 32 bytes. + Siblings []*common.Hash256 `protobuf:"bytes,1,rep,name=siblings,proto3" json:"siblings,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *MerkleProof) Reset() { + *x = MerkleProof{} + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[0] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *MerkleProof) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*MerkleProof) ProtoMessage() {} + +func (x *MerkleProof) ProtoReflect() protoreflect.Message { + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[0] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use MerkleProof.ProtoReflect.Descriptor instead. +func (*MerkleProof) Descriptor() ([]byte, []int) { + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{0} +} + +func (x *MerkleProof) GetSiblings() []*common.Hash256 { + if x != nil { + return x.Siblings + } + return nil +} + +// A single erasure-coded fragment of the original message. 
+type Shard struct { + state protoimpl.MessageState `protogen:"open.v1"` + Data []byte `protobuf:"bytes,1,opt,name=data,proto3" json:"data,omitempty"` + unknownFields protoimpl.UnknownFields + sizeCache protoimpl.SizeCache +} + +func (x *Shard) Reset() { + *x = Shard{} + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + ms.StoreMessageInfo(mi) +} + +func (x *Shard) String() string { + return protoimpl.X.MessageStringOf(x) +} + +func (*Shard) ProtoMessage() {} + +func (x *Shard) ProtoReflect() protoreflect.Message { + mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[1] + if x != nil { + ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x)) + if ms.LoadMessageInfo() == nil { + ms.StoreMessageInfo(mi) + } + return ms + } + return mi.MessageOf(x) +} + +// Deprecated: Use Shard.ProtoReflect.Descriptor instead. +func (*Shard) Descriptor() ([]byte, []int) { + return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{1} +} + +func (x *Shard) GetData() []byte { + if x != nil { + return x.Data + } + return nil +} + +// A collection of shards assigned to a single peer. +// The proto-encoded bytes of this message are used as Merkle tree leaf data, +// ensuring cross-language determinism. 
type ShardsOfPeer struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Shards        []*Shard               `protobuf:"bytes,1,rep,name=shards,proto3" json:"shards,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset overwrites x with a zero ShardsOfPeer and stores the message info of
// message type index 2 in its message state.
func (x *ShardsOfPeer) Reset() {
	*x = ShardsOfPeer{}
	mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the string form produced by protoimpl.X.MessageStringOf.
func (x *ShardsOfPeer) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method.
func (*ShardsOfPeer) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored on first use; for a nil x the view comes from mi.MessageOf.
func (x *ShardsOfPeer) ProtoReflect() protoreflect.Message {
	mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[2]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use ShardsOfPeer.ProtoReflect.Descriptor instead.
func (*ShardsOfPeer) Descriptor() ([]byte, []int) {
	return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{2}
}

// GetShards returns x.Shards, or nil when x is nil.
func (x *ShardsOfPeer) GetShards() []*Shard {
	if x != nil {
		return x.Shards
	}
	return nil
}

// A single unit in the Propeller protocol containing shards of erasure-coded data
// along with cryptographic proofs for verification.
type PropellerUnit struct {
	state protoimpl.MessageState `protogen:"open.v1"`
	// The shards assigned to this unit's peer.
	Shards *ShardsOfPeer `protobuf:"bytes,1,opt,name=shards,proto3" json:"shards,omitempty"`
	// The position of this shard in the erasure coding scheme.
	Index uint64 `protobuf:"varint,2,opt,name=index,proto3" json:"index,omitempty"`
	// The Merkle root of all shards, used to verify shard integrity.
	MerkleRoot *common.Hash256 `protobuf:"bytes,3,opt,name=merkle_root,json=merkleRoot,proto3" json:"merkle_root,omitempty"`
	// The Merkle proof that this shard belongs to the tree with the given root.
	MerkleProof *MerkleProof `protobuf:"bytes,4,opt,name=merkle_proof,json=merkleProof,proto3" json:"merkle_proof,omitempty"`
	// The peer ID of the original publisher who created and signed this unit.
	Publisher *common.PeerID `protobuf:"bytes,5,opt,name=publisher,proto3" json:"publisher,omitempty"`
	// Cryptographic signature from the publisher over the merkle_root.
	Signature []byte `protobuf:"bytes,6,opt,name=signature,proto3" json:"signature,omitempty"`
	// Committee identifier for multiplexing different message streams.
	CommitteeId *common.Hash256 `protobuf:"bytes,7,opt,name=committee_id,json=committeeId,proto3" json:"committee_id,omitempty"`
	// A strictly increasing number, to avoid replays.
	// Current implementation is nanoseconds since UNIX_EPOCH.
	// TODO(guyn): CRITICAL: protect against replay attacks using this timestamp
	Nonce         uint64 `protobuf:"varint,8,opt,name=nonce,proto3" json:"nonce,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset overwrites x with a zero PropellerUnit and stores the message info of
// message type index 3 in its message state.
func (x *PropellerUnit) Reset() {
	*x = PropellerUnit{}
	mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[3]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the string form produced by protoimpl.X.MessageStringOf.
func (x *PropellerUnit) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method.
func (*PropellerUnit) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored on first use; for a nil x the view comes from mi.MessageOf.
func (x *PropellerUnit) ProtoReflect() protoreflect.Message {
	mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[3]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PropellerUnit.ProtoReflect.Descriptor instead.
func (*PropellerUnit) Descriptor() ([]byte, []int) {
	return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{3}
}

// GetShards returns x.Shards, or nil when x is nil.
func (x *PropellerUnit) GetShards() *ShardsOfPeer {
	if x != nil {
		return x.Shards
	}
	return nil
}

// GetIndex returns x.Index, or 0 when x is nil.
func (x *PropellerUnit) GetIndex() uint64 {
	if x != nil {
		return x.Index
	}
	return 0
}

// GetMerkleRoot returns x.MerkleRoot, or nil when x is nil.
func (x *PropellerUnit) GetMerkleRoot() *common.Hash256 {
	if x != nil {
		return x.MerkleRoot
	}
	return nil
}

// GetMerkleProof returns x.MerkleProof, or nil when x is nil.
func (x *PropellerUnit) GetMerkleProof() *MerkleProof {
	if x != nil {
		return x.MerkleProof
	}
	return nil
}

// GetPublisher returns x.Publisher, or nil when x is nil.
func (x *PropellerUnit) GetPublisher() *common.PeerID {
	if x != nil {
		return x.Publisher
	}
	return nil
}

// GetSignature returns x.Signature, or nil when x is nil.
func (x *PropellerUnit) GetSignature() []byte {
	if x != nil {
		return x.Signature
	}
	return nil
}

// GetCommitteeId returns x.CommitteeId, or nil when x is nil.
func (x *PropellerUnit) GetCommitteeId() *common.Hash256 {
	if x != nil {
		return x.CommitteeId
	}
	return nil
}

// GetNonce returns x.Nonce, or 0 when x is nil.
func (x *PropellerUnit) GetNonce() uint64 {
	if x != nil {
		return x.Nonce
	}
	return 0
}

// A batch of PropellerUnits for efficient transmission.
type PropellerUnitBatch struct {
	state         protoimpl.MessageState `protogen:"open.v1"`
	Batch         []*PropellerUnit       `protobuf:"bytes,1,rep,name=batch,proto3" json:"batch,omitempty"`
	unknownFields protoimpl.UnknownFields
	sizeCache     protoimpl.SizeCache
}

// Reset overwrites x with a zero PropellerUnitBatch and stores the message
// info of message type index 4 in its message state.
func (x *PropellerUnitBatch) Reset() {
	*x = PropellerUnitBatch{}
	mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[4]
	ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
	ms.StoreMessageInfo(mi)
}

// String returns the string form produced by protoimpl.X.MessageStringOf.
func (x *PropellerUnitBatch) String() string {
	return protoimpl.X.MessageStringOf(x)
}

// ProtoMessage is an empty marker method.
func (*PropellerUnitBatch) ProtoMessage() {}

// ProtoReflect returns the reflective view of x. For a non-nil x the message
// info is stored on first use; for a nil x the view comes from mi.MessageOf.
func (x *PropellerUnitBatch) ProtoReflect() protoreflect.Message {
	mi := &file_consensus_propeller_proto_propeller_proto_msgTypes[4]
	if x != nil {
		ms := protoimpl.X.MessageStateOf(protoimpl.Pointer(x))
		if ms.LoadMessageInfo() == nil {
			ms.StoreMessageInfo(mi)
		}
		return ms
	}
	return mi.MessageOf(x)
}

// Deprecated: Use PropellerUnitBatch.ProtoReflect.Descriptor instead.
func (*PropellerUnitBatch) Descriptor() ([]byte, []int) {
	return file_consensus_propeller_proto_propeller_proto_rawDescGZIP(), []int{4}
}

// GetBatch returns x.Batch, or nil when x is nil.
func (x *PropellerUnitBatch) GetBatch() []*PropellerUnit {
	if x != nil {
		return x.Batch
	}
	return nil
}

// File_consensus_propeller_proto_propeller_proto is assigned once by
// file_consensus_propeller_proto_propeller_proto_init.
var File_consensus_propeller_proto_propeller_proto protoreflect.FileDescriptor

// Raw descriptor bytes passed to the type builder and to the GZIP helper below.
const file_consensus_propeller_proto_propeller_proto_rawDesc = "" +
	"\n" +
	")consensus/propeller/proto/propeller.proto\x1a\x16p2p/proto/common.proto\"3\n" +
	"\vMerkleProof\x12$\n" +
	"\bsiblings\x18\x01 \x03(\v2\b.Hash256R\bsiblings\"\x1b\n" +
	"\x05Shard\x12\x12\n" +
	"\x04data\x18\x01 \x01(\fR\x04data\".\n" +
	"\fShardsOfPeer\x12\x1e\n" +
	"\x06shards\x18\x01 \x03(\v2\x06.ShardR\x06shards\"\xb0\x02\n" +
	"\rPropellerUnit\x12%\n" +
	"\x06shards\x18\x01 \x01(\v2\r.ShardsOfPeerR\x06shards\x12\x14\n" +
	"\x05index\x18\x02 \x01(\x04R\x05index\x12)\n" +
	"\vmerkle_root\x18\x03 \x01(\v2\b.Hash256R\n" +
	"merkleRoot\x12/\n" +
	"\fmerkle_proof\x18\x04 \x01(\v2\f.MerkleProofR\vmerkleProof\x12%\n" +
	"\tpublisher\x18\x05 \x01(\v2\a.PeerIDR\tpublisher\x12\x1c\n" +
	"\tsignature\x18\x06 \x01(\fR\tsignature\x12+\n" +
	"\fcommittee_id\x18\a \x01(\v2\b.Hash256R\vcommitteeId\x12\x14\n" +
	"\x05nonce\x18\b \x01(\x04R\x05nonce\":\n" +
	"\x12PropellerUnitBatch\x12$\n" +
	"\x05batch\x18\x01 \x03(\v2\x0e.PropellerUnitR\x05batchB9Z7github.com/NethermindEth/juno/consensus/propeller/protob\x06proto3"

var (
	file_consensus_propeller_proto_propeller_proto_rawDescOnce sync.Once
	file_consensus_propeller_proto_propeller_proto_rawDescData []byte
)

// file_consensus_propeller_proto_propeller_proto_rawDescGZIP returns the
// GZIP-compressed raw descriptor. The compression runs at most once (guarded
// by rawDescOnce) and the result is cached in rawDescData.
func file_consensus_propeller_proto_propeller_proto_rawDescGZIP() []byte {
	file_consensus_propeller_proto_propeller_proto_rawDescOnce.Do(func() {
		file_consensus_propeller_proto_propeller_proto_rawDescData = protoimpl.X.CompressGZIP(unsafe.Slice(unsafe.StringData(file_consensus_propeller_proto_propeller_proto_rawDesc), len(file_consensus_propeller_proto_propeller_proto_rawDesc)))
	})
	return file_consensus_propeller_proto_propeller_proto_rawDescData
}

// One MessageInfo slot per message type; the generated methods index this
// slice with 1..4 for Shard, ShardsOfPeer, PropellerUnit and PropellerUnitBatch.
var file_consensus_propeller_proto_propeller_proto_msgTypes = make([]protoimpl.MessageInfo, 5)
var file_consensus_propeller_proto_propeller_proto_goTypes = []any{
	(*MerkleProof)(nil),        // 0: MerkleProof
	(*Shard)(nil),              // 1: Shard
	(*ShardsOfPeer)(nil),       // 2: ShardsOfPeer
	(*PropellerUnit)(nil),      // 3: PropellerUnit
	(*PropellerUnitBatch)(nil), // 4: PropellerUnitBatch
	(*common.Hash256)(nil),     // 5: Hash256
	(*common.PeerID)(nil),      // 6: PeerID
}
var file_consensus_propeller_proto_propeller_proto_depIdxs = []int32{
	5, // 0: MerkleProof.siblings:type_name -> Hash256
	1, // 1: ShardsOfPeer.shards:type_name -> Shard
	2, // 2: PropellerUnit.shards:type_name -> ShardsOfPeer
	5, // 3: PropellerUnit.merkle_root:type_name -> Hash256
	0, // 4: PropellerUnit.merkle_proof:type_name -> MerkleProof
	6, // 5: PropellerUnit.publisher:type_name -> PeerID
	5, // 6: PropellerUnit.committee_id:type_name -> Hash256
	3, // 7: PropellerUnitBatch.batch:type_name -> PropellerUnit
	8, // [8:8] is the sub-list for method output_type
	8, // [8:8] is the sub-list for method input_type
	8, // [8:8] is the sub-list for extension type_name
	8, // [8:8] is the sub-list for extension extendee
	0, // [0:8] is the sub-list for field type_name
}

func init() { file_consensus_propeller_proto_propeller_proto_init() }

// file_consensus_propeller_proto_propeller_proto_init builds the file
// descriptor and message infos from the raw descriptor and the tables above.
// It returns immediately when the file descriptor is already set, and clears
// the goTypes/depIdxs tables once the build has consumed them.
func file_consensus_propeller_proto_propeller_proto_init() {
	if File_consensus_propeller_proto_propeller_proto != nil {
		return
	}
	type x struct{}
	out := protoimpl.TypeBuilder{
		File: protoimpl.DescBuilder{
			GoPackagePath: reflect.TypeOf(x{}).PkgPath(),
			RawDescriptor: unsafe.Slice(unsafe.StringData(file_consensus_propeller_proto_propeller_proto_rawDesc), len(file_consensus_propeller_proto_propeller_proto_rawDesc)),
			NumEnums:      0,
			NumMessages:   5,
			NumExtensions: 0,
			NumServices:   0,
		},
		GoTypes:           file_consensus_propeller_proto_propeller_proto_goTypes,
		DependencyIndexes: file_consensus_propeller_proto_propeller_proto_depIdxs,
		MessageInfos:      file_consensus_propeller_proto_propeller_proto_msgTypes,
	}.Build()
	File_consensus_propeller_proto_propeller_proto = out.File
	file_consensus_propeller_proto_propeller_proto_goTypes = nil
	file_consensus_propeller_proto_propeller_proto_depIdxs = nil
}
syntax = "proto3";

import "p2p/proto/common.proto";

option go_package = "github.com/NethermindEth/juno/consensus/propeller/proto";

// A Merkle proof consisting of sibling hashes used to verify that a leaf belongs to a Merkle tree.
// Each sibling hash is 32 bytes (SHA-256). The siblings are ordered from leaf level to root level.
message MerkleProof {
  // The sibling hashes needed to reconstruct the path from the leaf to the root.
  // Each hash is 32 bytes.
  repeated Hash256 siblings = 1;
}

// A single erasure-coded fragment of the original message.
message Shard {
  // The bytes of this fragment.
  bytes data = 1;
}

// A collection of shards assigned to a single peer.
// The proto-encoded bytes of this message are used as Merkle tree leaf data,
// ensuring cross-language determinism.
message ShardsOfPeer {
  // The shards in this collection.
  repeated Shard shards = 1;
}

// A single unit in the Propeller protocol containing shards of erasure-coded data
// along with cryptographic proofs for verification.
message PropellerUnit {
  // The shards assigned to this unit's peer.
  ShardsOfPeer shards = 1;
  // The position of this shard in the erasure coding scheme.
  uint64 index = 2;
  // The Merkle root of all shards, used to verify shard integrity.
  Hash256 merkle_root = 3;
  // The Merkle proof that this shard belongs to the tree with the given root.
  MerkleProof merkle_proof = 4;
  // The peer ID of the original publisher who created and signed this unit.
  PeerID publisher = 5;
  // Cryptographic signature from the publisher over the merkle_root.
  bytes signature = 6;
  // Committee identifier for multiplexing different message streams.
  Hash256 committee_id = 7;
  // A strictly increasing number, to avoid replays.
  // Current implementation is nanoseconds since UNIX_EPOCH.
  // TODO(guyn): CRITICAL: protect against replay attacks using this timestamp
  uint64 nonce = 8;
}

// A batch of PropellerUnits for efficient transmission.
message PropellerUnitBatch {
  // The units carried in this batch.
  repeated PropellerUnit batch = 1;
}
+func EncodeData( + data []byte, + numDataShards, + parity int, +) ([][]byte, error) { + if len(data) == 0 { + return nil, errors.New("received empty data") + } + + encoder, err := reedsolomon.New(numDataShards, parity) + if err != nil { + return nil, fmt.Errorf("creating Reed-Solomon encoder: %w", err) + } + + split, err := encoder.Split(data) + if err != nil { + return nil, fmt.Errorf("splitting the data into shards: %w", err) + } + + err = encoder.Encode(split) + if err != nil { + return nil, fmt.Errorf("encoding the data shards: %w", err) + } + + return split, nil +} + +// RecoverData restores the missing data using Reed-Solomon erasure codes. +// There cannot be more than `parity` shards missing otherwise the recover will fail. +// Data that is considered missing needs to be marked as `nil`. Returns the recovered data. +// The input data shards well be modified in place. +func RecoverData( + shards [][]byte, + numDataShards, + parity int, +) ([][]byte, error) { + if len(shards) == 0 { + return nil, errors.New("no data shards provided") + } + + // todo(rdr): numDataShards can be inferred by getting the length of the shards + decoder, err := reedsolomon.New(numDataShards, parity) + if err != nil { + return nil, fmt.Errorf("creating Reed-Solomon decoder: %w", err) + } + + // todo(rdr): this is a slow approach where we are reconstructing parity shards as + // well. This is safe because at the end we can verify that it is correct. We might + // want to speed this up using `ReconstructData` with no `Verify` which should be 3x faster. 
+ err = decoder.Reconstruct(shards) + if err != nil { + return nil, fmt.Errorf("recovering the data shards: %w", err) + } + + correct, err := decoder.Verify(shards) + if err != nil { + return nil, fmt.Errorf("verifying the data shards: %w", err) + } + + if !correct { + return nil, errors.New("data shard failed verification") + } + + return shards, nil +} diff --git a/consensus/propeller/reedsolomon/reedsolomon_test.go b/consensus/propeller/reedsolomon/reedsolomon_test.go new file mode 100644 index 0000000000..dc1dc27050 --- /dev/null +++ b/consensus/propeller/reedsolomon/reedsolomon_test.go @@ -0,0 +1,282 @@ +package reedsolomon_test + +import ( + "bytes" + "crypto/rand" + "slices" + "testing" + + "github.com/NethermindEth/juno/consensus/propeller/reedsolomon" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestEncodeData(t *testing.T) { + requireEqualLength := func(t *testing.T, fragments [][]byte) { + length := len(fragments[0]) + for i := range fragments { + require.Len(t, fragments[i], length) + } + } + + requireEqualPrefix := func(t *testing.T, expected []byte, fragments [][]byte) { + actual := bytes.Join(fragments, nil) + require.Truef( + t, + bytes.HasPrefix(actual, expected), + "expected to get prefix: %s in %s", + expected, + actual, + ) + } + + largeData := make([]byte, 10*1024) + _, err := rand.Read(largeData) + require.NoError(t, err) + + successTests := []struct { + name string + data []byte + numData int + parity int + }{ + { + name: "success", + data: []byte("A journey of a thousands shards begins with a single byte"), + numData: 5, + parity: 3, + }, + { + name: "single data shard and single parity", + data: []byte("some data"), + numData: 1, + parity: 1, + }, + { + name: "large data", + data: largeData, + numData: 8, + parity: 4, + }, + { + name: "above 256 total shards", + data: largeData, + numData: 200, + parity: 100, + }, + } + for _, tc := range successTests { + t.Run(tc.name, func(t *testing.T) { + 
// TestEncodeData checks that EncodeData returns numData+parity equally sized
// shards whose concatenation starts with the input, and that invalid inputs
// produce the expected error text.
func TestEncodeData(t *testing.T) {
	// requireEqualLength asserts every fragment has the length of the first one.
	requireEqualLength := func(t *testing.T, fragments [][]byte) {
		length := len(fragments[0])
		for i := range fragments {
			require.Len(t, fragments[i], length)
		}
	}

	// requireEqualPrefix asserts the joined fragments begin with expected.
	requireEqualPrefix := func(t *testing.T, expected []byte, fragments [][]byte) {
		actual := bytes.Join(fragments, nil)
		require.Truef(
			t,
			bytes.HasPrefix(actual, expected),
			"expected to get prefix: %s in %s",
			expected,
			actual,
		)
	}

	// 10 KiB of random bytes shared by the larger cases.
	largeData := make([]byte, 10*1024)
	_, err := rand.Read(largeData)
	require.NoError(t, err)

	successTests := []struct {
		name    string
		data    []byte
		numData int
		parity  int
	}{
		{
			name:    "success",
			data:    []byte("A journey of a thousands shards begins with a single byte"),
			numData: 5,
			parity:  3,
		},
		{
			name:    "single data shard and single parity",
			data:    []byte("some data"),
			numData: 1,
			parity:  1,
		},
		{
			name:    "large data",
			data:    largeData,
			numData: 8,
			parity:  4,
		},
		{
			name:    "above 256 total shards",
			data:    largeData,
			numData: 200,
			parity:  100,
		},
	}
	for _, tc := range successTests {
		t.Run(tc.name, func(t *testing.T) {
			shards, err := reedsolomon.EncodeData(tc.data, tc.numData, tc.parity)
			require.NoError(t, err)
			require.Len(t, shards, tc.numData+tc.parity)
			requireEqualLength(t, shards)
			requireEqualPrefix(t, tc.data, shards)
		})
	}

	errorTests := []struct {
		name        string
		data        []byte
		numData     int
		parity      int
		errContains string
	}{
		{
			name:        "empty data",
			data:        []byte{},
			numData:     5,
			parity:      3,
			errContains: "received empty data",
		},
		{
			name:        "zero data shards",
			data:        []byte("data"),
			numData:     0,
			parity:      3,
			errContains: "creating Reed-Solomon encoder",
		},
		{
			name:        "negative parity",
			data:        []byte("data"),
			numData:     5,
			parity:      -1,
			errContains: "creating Reed-Solomon encoder",
		},
		{
			name:        "exceeds max shard count",
			data:        []byte("data"),
			numData:     40000,
			parity:      40000,
			errContains: "creating Reed-Solomon encoder",
		},
	}
	for _, tc := range errorTests {
		t.Run(tc.name, func(t *testing.T) {
			_, err := reedsolomon.EncodeData(tc.data, tc.numData, tc.parity)
			require.ErrorContains(t, err, tc.errContains)
		})
	}
}

// TestRecoverData encodes data, blanks out selected shards, and checks that
// RecoverData restores every shard, or fails with the expected error text.
func TestRecoverData(t *testing.T) {
	// encode wraps EncodeData and fails the test on error.
	encode := func(t *testing.T, data []byte, numData, parity int) [][]byte {
		t.Helper()
		shards, err := reedsolomon.EncodeData(data, numData, parity)
		require.NoError(t, err)
		return shards
	}

	// buildDataShards deep-copies original, leaving the entries listed in
	// missingData as nil so the encoded reference stays untouched.
	buildDataShards := func(t *testing.T, original [][]byte, missingData ...int) [][]byte {
		t.Helper()
		dataShards := make([][]byte, len(original))
		for i := range original {
			if slices.Contains(missingData, i) {
				continue
			}
			dataShards[i] = make([]byte, len(original[i]))
			copy(dataShards[i], original[i])
		}
		return dataShards
	}

	// requireEqualShards compares shard by shard, reporting the failing index.
	requireEqualShards := func(t *testing.T, expected [][]byte, actual [][]byte) {
		t.Helper()
		for i := range expected {
			assert.Equalf(
				t, expected[i], actual[i],
				"at index %d, expected: %s, actual: %s",
				i, expected[i], actual[i],
			)
		}
	}

	successTests := []struct {
		name       string
		data       []byte
		numData    int
		parity     int
		missingIdx []int
	}{
		{
			name:    "no missing shards",
			data:    []byte("nothing is missing here"),
			numData: 4,
			parity:  2,
		},
		{
			name:       "missing parity shards",
			data:       []byte("recover parity shards"),
			numData:    4,
			parity:     3,
			missingIdx: []int{5, 6, 7},
		},
		{
			name:       "missing data shards within parity limit",
			data:       []byte("recover data shards from parity"),
			numData:    5,
			parity:     3,
			missingIdx: []int{0, 2, 4},
		},
		{
			name:       "missing mixed data and parity shards",
			data:       []byte("mixed missing shards scenario"),
			numData:    5,
			parity:     4,
			missingIdx: []int{1, 3, 5, 6},
		},
	}
	for _, tc := range successTests {
		t.Run(tc.name, func(t *testing.T) {
			expected := encode(t, tc.data, tc.numData, tc.parity)
			dataShards := buildDataShards(t, expected, tc.missingIdx...)

			recovered, err := reedsolomon.RecoverData(dataShards, tc.numData, tc.parity)
			require.NoError(t, err)
			requireEqualShards(t, expected, recovered)
		})
	}

	errorTests := []struct {
		name        string
		data        []byte
		numData     int
		parity      int
		missingIdx  []int
		errContains string
	}{
		{
			name:        "too many missing shards",
			data:        []byte("too many shards gone"),
			numData:     4,
			parity:      2,
			missingIdx:  []int{0, 1, 4},
			errContains: "recovering the data shards:",
		},
		{
			name:        "empty shards slice",
			numData:     4,
			parity:      2,
			errContains: "no data shards provided",
		},
	}
	for _, tc := range errorTests {
		t.Run(tc.name, func(t *testing.T) {
			// A case without data passes a nil shard slice to RecoverData.
			var shards [][]byte
			if tc.data != nil {
				expected := encode(t, tc.data, tc.numData, tc.parity)
				shards = buildDataShards(t, expected, tc.missingIdx...)
			}

			recovered, err := reedsolomon.RecoverData(shards, tc.numData, tc.parity)
			require.Nil(t, recovered)
			require.ErrorContains(t, err, tc.errContains)
		})
	}
}

// TestEncodeDecodeRoundTrip encodes, nils out some shards, recovers, and checks
// that the joined data shards start with the original input.
func TestEncodeDecodeRoundTrip(t *testing.T) {
	tests := []struct {
		name    string
		data    []byte
		numData int
		parity  int
		nilIdxs []int // indices to nil out before recovery
	}{
		{
			name:    "small data, lose 1 data shard",
			data:    []byte("round trip test"),
			numData: 4, parity: 2,
			nilIdxs: []int{0},
		},
		{
			name:    "medium data, lose max shards",
			data:    bytes.Repeat([]byte("abcdefghij"), 100),
			numData: 5, parity: 3,
			nilIdxs: []int{1, 3, 6},
		},
		{
			name:    "single byte",
			data:    []byte{0xff},
			numData: 2, parity: 1,
			nilIdxs: []int{0},
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			shards, err := reedsolomon.EncodeData(tc.data, tc.numData, tc.parity)
			require.NoError(t, err)

			for _, idx := range tc.nilIdxs {
				shards[idx] = nil
			}

			recovered, err := reedsolomon.RecoverData(shards, tc.numData, tc.parity)
			require.NoError(t, err)

			// The joined shards are trimmed to the input length before comparing.
			joined := bytes.Join(recovered[:tc.numData], nil)
			assert.Equal(t, tc.data, joined[:len(tc.data)])
		})
	}
}
+// Given a sorted set of peers and a publisher, it computes which peer is +// responsible for broadcasting each shard index. The mapping is deterministic +// so that all nodes agree on the assignment without coordination. +// +// The design relies on the invariant that there are N-1 shards for N peers, +// and each non-publisher peer gets exactly one shard. The publisher is "skipped" +// in the sorted peer list when assigning shard indices. +// +// Propeller uses a distributed broadcast approach where: +// - numDataShards = floor((N-1)/3) where N is total number of nodes +// - numDataShards represents both max faulty nodes AND number of data shards +// - numCodingShards = N-1-numDataShards (meaning, the rest) +// - Message is BUILT when numDataShards are received (can reconstruct) +// - Message is RECEIVED when 2*numDataShards shards are received (guarantees gossip property) +// - Each peer broadcasts received shards to all other peers (full mesh) +type Scheduler struct { + localPeerID peer.ID + localPeerIDIndex int + peers []PeerCommittee + numDataShards int + numCodingShards int +} + +// NewScheduler creates a schedule from a list of peers. The peers are sorted +// lexicographically by their string representation to ensure all nodes derive +// the same ordering regardless of discovery order. +// Note that `nodes` will be mutated after this function gets called +// todo(rdr): should we return scheduler by reference or by value? +func NewScheduler( + id peer.ID, + nodes []PeerCommittee, +) (*Scheduler, error) { + if len(nodes) < 2 { + return nil, fmt.Errorf( + "at least 2 peers are required to form a new committee: %d given", + len(nodes), + ) + } + + // todo(rdr): check with function is faster for sorting in our case: + // `slices.Sort` or `sort.Slice`. 
Alternative: + // sort.Slice(nodes, func(i, j int) bool { return nodes[i].ID < nodes[j].ID }) + slices.SortFunc(nodes, func(i, j PeerCommittee) int { return cmp.Compare(i.ID, j.ID) }) + + // check that the local peer ID is part of the peer committee + idIndex, exists := slices.BinarySearchFunc( + nodes, + id, + func(elem PeerCommittee, target peer.ID) int { + return cmp.Compare(elem.ID, target) + }, + ) + if !exists { + return nil, errors.New("the local peer id is not part of the supplied list of peeers") + } + + // check that there is no duplicated ID in the node list + for i := range len(nodes) - 1 { + if nodes[i].ID == nodes[i+1].ID { + return nil, fmt.Errorf("duplicated ids in the supplied list of peers: %s", nodes[i].ID) + } + } + + totalNodes := len(nodes) + // We guarantee always one data shard for small networks (N = 2 or N = 3) + numDataShards := max(1, (totalNodes-1)/3) + // We avoid the possibility of an underflow + numCodingShards := max(0, totalNodes-1-numDataShards) + + return &Scheduler{ + localPeerID: id, + localPeerIDIndex: idIndex, + peers: nodes, + numDataShards: numDataShards, + numCodingShards: numCodingShards, + }, nil +} + +// PeerID returns the Scheduler Peer ID +func (s *Scheduler) PeerID() peer.ID { + return s.localPeerID +} + +// Peers return the Scheduler list of nodes +func (s *Scheduler) Peers() []PeerCommittee { + return s.peers +} + +// DataShards returns the number of data (systematic) shards. +func (s *Scheduler) NumDataShards() int { return s.numDataShards } + +// CodingShards returns the number of parity (coding) shards. +func (s *Scheduler) NumCodingShards() int { return s.numCodingShards } + +// NumShards returns the total number of shards (data + coding = N-1). 
+func (s *Scheduler) NumTotalShards() int { return s.numDataShards + s.numCodingShards } + +// Minimum (inclusive) amount of shards required to build a message +func (s *Scheduler) BuildThreshold() int { return s.numDataShards } + +// Minimum (inclusive) amount of shards required to guarantee a message is received +func (s *Scheduler) ReceiveThreshold() int { + if len(s.peers) <= 3 { + return s.BuildThreshold() + } + return s.numDataShards * 2 +} + +func (s *Scheduler) publisherIndex(publisher peer.ID) (int, error) { + publisherIndex, found := slices.BinarySearchFunc( + s.peers, + publisher, + func(elem PeerCommittee, target peer.ID) int { + return cmp.Compare(elem.ID, target) + }, + ) + if !found { + return -1, fmt.Errorf("publisher with id \"%s\" not found in the peer list", publisher) + } + return publisherIndex, nil +} + +// PeerForShardIndex returns the peer responsible for broadcasting a given +// shard index to a given publisher. The mapping skips the publisher in the +// sorted list: +// +// if shardIndex < publisherIndex: peer = peers[shardIndex] +// if shardIndex >= publisherIndex: peer = peers[shardIndex + 1] +// +// Example with peers [A, B, C, D] and publisher C (index 2): +// +// shard 0 -> A, shard 1 -> B, shard 2 -> D +func (s *Scheduler) PeerForShardIndex( + publisher peer.ID, shardIndex ShardIndex, +) (peer.ID, error) { + if int(shardIndex) >= s.NumTotalShards() { + return "", fmt.Errorf( + "shard index %d out of range [0, %d)", shardIndex, s.NumTotalShards(), + ) + } + + pubIdx, err := s.publisherIndex(publisher) + if err != nil { + return "", err + } + + // Skip the publisher's position. + peerIdx := int(shardIndex) + if peerIdx >= pubIdx { + peerIdx++ + } + + return s.peers[peerIdx].ID, nil +} + +// ShardIndexForPublisher returns the shard index that shceduler is responsible for +// broadcasting for a given publisher. 
This is the inverse of PeerForShard: +// +// if localPeerIndex < publisherIndex: shard = localPeerIndex +// if localPeerIndex > publisherIndex: shard = localPeerIndex - 1 +// +// Returns an error if Scheduler's peer is the publisher (publishers don't have an +// assigned shard) or if the publisher is not in the list. +func (s *Scheduler) ShardIndexForPublisher( + publisher peer.ID, +) (ShardIndex, error) { + if s.localPeerID == publisher { + return 0, fmt.Errorf( + "scheduler peer is the same as the publisher and has no assigned shard: %s", + publisher, + ) + } + + pubIdx, err := s.publisherIndex(publisher) + if err != nil { + return 0, fmt.Errorf("couldn't locate shard index for publisher: %w", err) + } + + shardIdx := s.localPeerIDIndex + if s.localPeerIDIndex >= pubIdx { + shardIdx = s.localPeerIDIndex - 1 + } + + return ShardIndex(shardIdx), nil +} + +// ValidateShardOrigin verifies that a shard unit was received from the expected sender. +// The sender has to be either the publisher for direct shards or a designated +// broadcaster for the given shard index. 
+// todo(rdr): This implementation should probably be part of `UnitValidator` +func (s *Scheduler) ValidateShardOrigin( + sender peer.ID, + publisher peer.ID, + shardIndex ShardIndex, +) error { + if sender == s.localPeerID { + return fmt.Errorf("self sending message from %s", sender) + } + if publisher == s.localPeerID { + return fmt.Errorf("self published shard was sent back by %s", sender) + } + + expectedBroadcaster, err := s.PeerForShardIndex(publisher, shardIndex) + if err != nil { + return fmt.Errorf( + "couldn't validate publisher %s with shard %d: %w", + publisher, + shardIndex, + err, + ) + } + + validDirectShard := expectedBroadcaster == s.localPeerID && sender == publisher + if validDirectShard { + return nil + } + + validBroadcastShard := expectedBroadcaster == sender + if validBroadcastShard { + return nil + } + + return fmt.Errorf( + "received shard index %d from unexpected sender %s", + shardIndex, + sender, + ) +} + +// BroadcastTargets returns all peers whom to broadcast to, in shard-index order. +// The i-th element of the returned slice is the peer responsible for shard i. +func (s *Scheduler) BroadcastTargets() []peer.ID { + // todo(rdr): I would like to not use `append` and index directly instead (it's faster) + targets := make([]peer.ID, 0, s.NumTotalShards()) + for i, p := range s.peers { + if i == s.localPeerIDIndex { + continue + } + targets = append(targets, p.ID) + } + return targets +} diff --git a/consensus/propeller/scheduler_test.go b/consensus/propeller/scheduler_test.go new file mode 100644 index 0000000000..60bedf2d43 --- /dev/null +++ b/consensus/propeller/scheduler_test.go @@ -0,0 +1,381 @@ +package propeller + +import ( + "testing" + + "github.com/libp2p/go-libp2p/core/peer" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +// testPeers creates PeerCommittee entries from the given names. +// Each test chooses its own local peer explicitly. 
func testPeers(t *testing.T, names ...string) []PeerCommittee {
	t.Helper()
	peers := make([]PeerCommittee, len(names))
	for i, n := range names {
		// Every test peer gets the same stake of 1.
		peers[i] = PeerCommittee{ID: peer.ID(n), Stake: 1}
	}
	return peers
}

// TestScheduler_NewScheduler_Validation covers the constructor's rejection
// cases (too few peers, unknown local peer, duplicate IDs) and one valid build.
func TestScheduler_NewScheduler_Validation(t *testing.T) {
	t.Run("fewer than 2 peers", func(t *testing.T) {
		peers := testPeers(t, "A")
		_, err := NewScheduler(peer.ID("A"), peers)
		assert.Error(t, err)
	})

	t.Run("local peer not in list", func(t *testing.T) {
		peers := testPeers(t, "A", "B", "C")
		_, err := NewScheduler(peer.ID("Z"), peers)
		assert.Error(t, err)
	})

	t.Run("duplicate peers", func(t *testing.T) {
		peers := testPeers(t, "A", "B", "B")
		_, err := NewScheduler(peer.ID("A"), peers)
		assert.Error(t, err)
	})

	t.Run("valid construction", func(t *testing.T) {
		peers := testPeers(t, "A", "B", "C")
		s, err := NewScheduler(peer.ID("B"), peers)
		require.NoError(t, err)
		assert.Equal(t, peer.ID("B"), s.PeerID())
	})
}

// TestScheduler_ShardCounts pins the shard counts and thresholds derived from
// the committee size N.
func TestScheduler_ShardCounts(t *testing.T) {
	tests := []struct {
		name             string
		n                int
		numDataShards    int
		numCodingShards  int
		numTotalShards   int
		buildThreshold   int
		receiveThreshold int
	}{
		{
			name:             "N=2",
			n:                2,
			numDataShards:    1,
			numCodingShards:  0,
			numTotalShards:   1,
			buildThreshold:   1,
			receiveThreshold: 1, // len<=3, falls back to buildThreshold
		},
		{
			name:             "N=3",
			n:                3,
			numDataShards:    1,
			numCodingShards:  1,
			numTotalShards:   2,
			buildThreshold:   1,
			receiveThreshold: 1, // len<=3, falls back to buildThreshold
		},
		{
			name:             "N=4",
			n:                4,
			numDataShards:    1,
			numCodingShards:  2,
			numTotalShards:   3,
			buildThreshold:   1,
			receiveThreshold: 2,
		},
		{
			name:             "N=5",
			n:                5,
			numDataShards:    1,
			numCodingShards:  3,
			numTotalShards:   4,
			buildThreshold:   1,
			receiveThreshold: 2,
		},
		{
			name:             "N=7",
			n:                7,
			numDataShards:    2,
			numCodingShards:  4,
			numTotalShards:   6,
			buildThreshold:   2,
			receiveThreshold: 4,
		},
		{
			name:             "N=10",
			n:                10,
			numDataShards:    3,
			numCodingShards:  6,
			numTotalShards:   9,
			buildThreshold:   3,
			receiveThreshold: 6,
		},
		{
			name:             "N=31",
			n:                31,
			numDataShards:    10,
			numCodingShards:  20,
			numTotalShards:   30,
			buildThreshold:   10,
			receiveThreshold: 20,
		},
	}

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			// Peer names are consecutive single-rune strings starting at "A".
			names := make([]string, tc.n)
			for i := range tc.n {
				names[i] = string(rune('A' + i))
			}
			peers := testPeers(t, names...)

			// The first generated peer acts as the local peer.
			s, err := NewScheduler(peers[0].ID, peers)
			require.NoError(t, err)

			assert.Equal(t, tc.numDataShards, s.NumDataShards())
			assert.Equal(t, tc.numCodingShards, s.NumCodingShards())
			assert.Equal(t, tc.numTotalShards, s.NumTotalShards())
			assert.Equal(t, tc.buildThreshold, s.BuildThreshold())
			assert.Equal(t, tc.receiveThreshold, s.ReceiveThreshold())
		})
	}
}

func TestScheduler_DeterministicMapping(t *testing.T) {
	// Two schedulers built from differently-ordered peer lists must
	// produce identical shard-to-peer mappings for every publisher.
	s1, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C", "D"))
	require.NoError(t, err)

	s2, err := NewScheduler(peer.ID("A"), testPeers(t, "D", "B", "A", "C"))
	require.NoError(t, err)

	for _, pub := range s1.Peers() {
		for idx := range s1.NumTotalShards() {
			p1, err1 := s1.PeerForShardIndex(pub.ID, ShardIndex(idx))
			p2, err2 := s2.PeerForShardIndex(pub.ID, ShardIndex(idx))
			require.NoError(t, err1)
			require.NoError(t, err2)
			assert.Equal(t, p1, p2, "publisher=%s shard=%d", pub.ID, idx)
		}
	}
}

func TestScheduler_PeerForShardIndex(t *testing.T) {
	// peers [A, B, C, D]: the publisher is skipped in the sorted list,
	// so each remaining peer maps to shard indices 0..2 in order.
	tests := []struct {
		name      string
		publisher string
		expected  []peer.ID
	}{
		{
			name:      "publisher middle (C, index 2)",
			publisher: "C",
			expected:  []peer.ID{"A", "B", "D"},
		},
		{
			name:      "publisher first (A, index 0)",
			publisher: "A",
			expected:  []peer.ID{"B", "C", "D"},
		},
		{
			name:      "publisher last (D, index 3)",
			publisher: "D",
			expected:  []peer.ID{"A", "B", "C"},
		},
	}

	s, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C", "D"))
	require.NoError(t, err)

	for _, tc := range tests {
		t.Run(tc.name, func(t *testing.T) {
			for i, want := range tc.expected {
				got, err := s.PeerForShardIndex(peer.ID(tc.publisher), ShardIndex(i))
				require.NoError(t, err)
				assert.Equal(t, want, got, "shard %d", i)
			}
		})
	}
}

// TestScheduler_PeerForShardIndex_Errors checks the unknown-publisher case and
// the first index past the valid range.
func TestScheduler_PeerForShardIndex_Errors(t *testing.T) {
	s, err := NewScheduler(peer.ID("A"), testPeers(t, "A", "B", "C"))
	require.NoError(t, err)

	t.Run("publisher not in list", func(t *testing.T) {
		_, err := s.PeerForShardIndex(peer.ID("Z"), 0)
		assert.Error(t, err)
	})

	t.Run("shard index out of range", func(t *testing.T) {
		_, err := s.PeerForShardIndex(peer.ID("A"), ShardIndex(s.NumTotalShards()))
		assert.Error(t, err)
	})
}

func TestScheduler_ShardIndexForPublisher(t *testing.T) {
	// peers [A, B, C, D], publisher=C.
	// A(idx 0) -> shard 0, B(idx 1) -> shard 1, D(idx 3) -> shard 2
	tests := []struct {
		localPeer     string
		expectedShard ShardIndex
	}{
		{"A", 0},
		{"B", 1},
		{"D", 2},
	}
	publisher := peer.ID("C")

	for _, tc := range tests {
		t.Run("local="+tc.localPeer, func(t *testing.T) {
			s, err := NewScheduler(
				peer.ID(tc.localPeer), testPeers(t, "A", "B", "C", "D"),
			)
			require.NoError(t, err)

			got, err := s.ShardIndexForPublisher(publisher)
			require.NoError(t, err)
			assert.Equal(t, tc.expectedShard, got)
		})
	}
}

// TestScheduler_ShardIndexForPublisher_Errors checks that the local peer as
// publisher and an unknown publisher both produce an error.
func TestScheduler_ShardIndexForPublisher_Errors(t *testing.T) {
	s, err := NewScheduler(peer.ID("B"), testPeers(t, "A", "B", "C"))
	require.NoError(t, err)

	t.Run("local peer is the publisher", func(t *testing.T) {
		_, err := s.ShardIndexForPublisher(peer.ID("B"))
		assert.Error(t, err)
	})

	t.Run("publisher not in list", func(t *testing.T) {
		_, err := s.ShardIndexForPublisher(peer.ID("Z"))
		assert.Error(t, err)
	})
}

func TestScheduler_InverseProperty(t *testing.T) {
	// For every local peer and every publisher, verify that
	// PeerForShardIndex and ShardIndexForPublisher are inverses.
	names := []string{"A", "B", "C", "D", "E"}

	for _, local := range names {
		s, err := NewScheduler(peer.ID(local), testPeers(t, names...))
		require.NoError(t, err)

		for _, pub := range names {
			// The publisher has no assigned shard, so that pairing is skipped.
			if pub == local {
				continue
			}
			// ShardIndexForPublisher -> PeerForShardIndex should round-trip
			shardIdx, err := s.ShardIndexForPublisher(peer.ID(pub))
			require.NoError(t, err)

			gotPeer, err := s.PeerForShardIndex(peer.ID(pub), shardIdx)
			require.NoError(t, err)
			assert.Equal(t, peer.ID(local), gotPeer,
				"local=%s publisher=%s shard=%d", local, pub, shardIdx)
		}
	}
}
+ tests := []struct { + name string + local string + expected []peer.ID + }{ + { + name: "local first", + local: "A", + expected: []peer.ID{"B", "C", "D"}, + }, + { + name: "local middle", + local: "C", + expected: []peer.ID{"A", "B", "D"}, + }, + { + name: "local last", + local: "D", + expected: []peer.ID{"A", "B", "C"}, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + s, err := NewScheduler(peer.ID(tc.local), testPeers(t, "A", "B", "C", "D")) + require.NoError(t, err) + + targets := s.BroadcastTargets() + assert.Equal(t, tc.expected, targets) + }) + } +} + +func TestScheduler_ValidateShardOrigin(t *testing.T) { + // peers [A, B, C, D], local = C. + // For publisher A (index 0): shard 0 -> B, shard 1 -> C, shard 2 -> D + // So local=C is the designated broadcaster for shard 1 when publisher=A. + s, err := NewScheduler(peer.ID("C"), testPeers(t, "A", "B", "C", "D")) + require.NoError(t, err) + + tests := []struct { + name string + sender string + publisher string + shardIndex ShardIndex + wantErr bool + }{ + { + name: "valid direct shard from publisher", + sender: "A", + publisher: "A", + shardIndex: 1, // C is the designated broadcaster, so publisher sends directly + wantErr: false, + }, + { + name: "valid broadcast shard from designated peer", + sender: "B", + publisher: "A", + shardIndex: 0, // B is the designated broadcaster for shard 0 + wantErr: false, + }, + { + name: "self-send rejected", + sender: "C", + publisher: "A", + shardIndex: 0, + wantErr: true, + }, + { + name: "self-published shard sent back rejected", + sender: "A", + publisher: "C", // local peer is the publisher + shardIndex: 0, + wantErr: true, + }, + { + name: "wrong sender rejected", + sender: "D", + publisher: "A", + shardIndex: 0, // designated broadcaster is B, not D + wantErr: true, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := s.ValidateShardOrigin(peer.ID(tc.sender), peer.ID(tc.publisher), tc.shardIndex) + if 
tc.wantErr { + assert.Error(t, err) + } else { + assert.NoError(t, err) + } + }) + } +} diff --git a/consensus/propeller/sharding.go b/consensus/propeller/sharding.go new file mode 100644 index 0000000000..e2e555b6ae --- /dev/null +++ b/consensus/propeller/sharding.go @@ -0,0 +1,150 @@ +package propeller + +import ( + "errors" + "fmt" + + "github.com/NethermindEth/juno/consensus/propeller/merkle" + "github.com/NethermindEth/juno/consensus/propeller/reedsolomon" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/libp2p/go-libp2p/core/peer" +) + +// CreatePropellerUnits creates the PropellerUnits for publishing: it pads and +// Reed-Solomon encodes message, builds a merkle tree over the encoded shards, +// signs the merkle root (with committeeID and nonce) and returns one Unit per shard. +// todo(rdr): maybe call it create message for sharing or something like that +func CreatePropellerUnits( + privKey crypto.PrivKey, + committeeID *CommitteeID, + nonce Nonce, + message []byte, + numDataShards, + parity int, +) ([]Unit, error) { + publisherID, err := peer.IDFromPrivateKey(privKey) + if err != nil { + return nil, fmt.Errorf("getting publisher id from private key: %w", err) + } + + paddedMessage := PadMessage(message, numDataShards) + encodedMessage, err := reedsolomon.EncodeData(paddedMessage, numDataShards, parity) + if err != nil { + return nil, fmt.Errorf("encoding the message: %w", err) + } + + merkleRoot, merkleTree := merkle.New(encodedMessage) + messageRoot := MessageRoot(merkleRoot) + + signature, err := SignMessage(privKey, &messageRoot, committeeID, nonce) + if err != nil { + return nil, err + } + + units := make([]Unit, len(encodedMessage)) + for i, shard := range encodedMessage { + units[i] = Unit{ + CommitteeID: *committeeID, + Publisher: publisherID, + MessageRoot: messageRoot, + MerkleProof: merkleTree[i], + Signature: signature, + ShardIndex: ShardIndex(i), + // todo(rdr): assigning one shard per unit until multi shard algo per unit + // is clear to me. 
+ ShardData: []Shard{shard}, + } + } + return units, nil +} + +// ConstructMessageFromUnits receives Propeller units, recovers any missing data and returns +// the fully verified message, together with the corresponding shard data and merkle proof. +// +// units is indexed by shard position; a nil entry (or a unit carrying no shard data) is +// treated as a missing shard. An error is returned when no unit is present, recovery or +// unpadding fails, shard sizes disagree, the rebuilt merkle root differs from the one +// carried by the units, or localShardIndex is out of range. +func ConstructMessageFromUnits( + units []*Unit, + localShardIndex ShardIndex, + numDataShards int, + parity int, +) ([]byte, ShardData, merkle.Proof, error) { + // Missing shards show up as nil entries, so the message root is read from the + // first unit that is present instead of assuming units[0] exists. + var refUnit *Unit + for _, unit := range units { + if unit != nil { + refUnit = unit + break + } + } + if refUnit == nil { + return nil, nil, merkle.Proof{}, errors.New("no propeller units to decode") + } + + shards := make([][]byte, len(units)) + for i := range shards { + if units[i] != nil && len(units[i].ShardData) > 0 { + // todo(rdr): we are assuming that every unit only carries one shard data for now + // Not sure how the matrix is built when unit carries more than one + // Probably it is an algorithm based on stake levels (?) + shards[i] = units[i].ShardData[0] + } + } + + shards, err := reedsolomon.RecoverData(shards, numDataShards, parity) + if err != nil { + return nil, nil, merkle.Proof{}, fmt.Errorf("recovering shards data: %w", err) + } + shardSize := len(shards[0]) + for i := range numDataShards { + if shards[i] != nil && len(shards[i]) != shardSize { + return nil, nil, merkle.Proof{}, fmt.Errorf( + "missmatch on shard size: %d (at index 0) vs %d (at index %d)", + len(shards[0]), + len(shards[i]), + i, + ) + } + } + + merkleRoot, merkleTree := merkle.New(shards) + + messageRoot := refUnit.MessageRoot + expectedRoot := MessageRoot(merkleRoot) + if messageRoot != expectedRoot { + // todo(rdr): probably need to write string methods for the MessageRoot type + return nil, nil, merkle.Proof{}, fmt.Errorf( + "wrong message root hash. 
Expected %v but got %v", + &expectedRoot, + &messageRoot, + ) + } + + // NOTE(review): every recovered shard is concatenated here; if RecoverData also + // returns the parity shards, confirm UnpadMessage ignores those trailing bytes. + paddedMessage := make([]byte, len(shards[0])*len(shards)) + for i := range shards { + copy(paddedMessage[i*shardSize:], shards[i]) + } + message, err := UnpadMessage(paddedMessage) + if err != nil { + return nil, nil, merkle.Proof{}, fmt.Errorf("unpadding reconstructed message: %w", err) + } + + // The index comes from the caller, so bound it before indexing into the results. + localIdx := int(localShardIndex) + if localIdx < 0 || localIdx >= len(shards) || localIdx >= len(merkleTree) { + return nil, nil, merkle.Proof{}, fmt.Errorf("local shard index %d out of range", localIdx) + } + + // todo(rdr): only one for now, but there can be more.TBD how that works + localShard := []Shard{ + shards[localIdx], + } + localProof := merkleTree[localIdx] + + return message, localShard, localProof, nil +} diff --git a/consensus/propeller/sharding_test.go b/consensus/propeller/sharding_test.go new file mode 100644 index 0000000000..108f138015 --- /dev/null +++ b/consensus/propeller/sharding_test.go @@ -0,0 +1,204 @@ +package propeller + +// import ( +// "testing" +// +// "github.com/libp2p/go-libp2p/core/peer" +// "github.com/stretchr/testify/assert" +// "github.com/stretchr/testify/require" +// ) +// +// // makeSchedule is a test helper that creates a schedule from N single-char peers. 
+// func makeSchedule(n int) *Scheduler { +// names := make([]peer.ID, n) +// for i := range n { +// names[i] = peer.ID(string(rune('A' + i))) +// } +// return NewScheduler(names) +// } +// +// func TestEncodeMessage_RoundTrip(t *testing.T) { +// tests := []struct { +// name string +// n int +// msgLen int +// }{ +// {"4 peers, short message", 4, 10}, +// {"4 peers, medium message", 4, 500}, +// {"7 peers, short message", 7, 20}, +// {"10 peers, 1KB message", 10, 1024}, +// {"2 peers, tiny message", 2, 1}, +// {"3 peers, empty message", 3, 0}, +// } +// +// for _, tc := range tests { +// t.Run(tc.name, func(t *testing.T) { +// schedule := makeSchedule(tc.n) +// if schedule.NumShards() == 0 { +// t.Skip("no shards for single peer") +// } +// +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := make([]byte, tc.msgLen) +// for i := range msg { +// msg[i] = byte(i) +// } +// +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// assert.Len(t, units, schedule.NumShards()) +// +// // All units should reference the same root. +// for _, u := range units { +// assert.Equal(t, root, u.MerkleRoot) +// } +// +// // Reconstruct from all shards. +// shards := make([][]byte, schedule.NumShards()) +// for _, u := range units { +// shards[u.ShardIndex] = u.ShardData +// } +// +// recovered, err := ReconstructMessage( +// shards, schedule, enc, root, +// ) +// require.NoError(t, err) +// assert.Equal(t, msg, recovered) +// }) +// } +// } +// +// func TestEncodeMessage_ReconstructFromMinimumShards(t *testing.T) { +// // With N=10 we have 3 data shards and 6 coding shards. +// // We should be able to reconstruct from just the 3 data shards. 
+// schedule := makeSchedule(10) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("reconstruct me from minimum shards please") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// // Keep only the first numDataShards shards. +// shards := make([][]byte, schedule.NumShards()) +// for i := range schedule.NumDataShards() { +// shards[units[i].ShardIndex] = units[i].ShardData +// } +// +// recovered, err := ReconstructMessage(shards, schedule, enc, root) +// require.NoError(t, err) +// assert.Equal(t, msg, recovered) +// } +// +// func TestEncodeMessage_ReconstructWithMissingDataShards(t *testing.T) { +// // With N=7 we have 2 data shards and 4 coding shards. +// // Drop all data shards, keep only coding shards -> should reconstruct. +// schedule := makeSchedule(7) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("even without data shards") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// // Keep only coding shards (indices >= numDataShards). 
+// shards := make([][]byte, schedule.NumShards()) +// for _, u := range units { +// if int(u.ShardIndex) >= schedule.NumDataShards() { +// shards[u.ShardIndex] = u.ShardData +// } +// } +// +// recovered, err := ReconstructMessage(shards, schedule, enc, root) +// require.NoError(t, err) +// assert.Equal(t, msg, recovered) +// } +// +// func TestEncodeMessage_MerkleProofsVerify(t *testing.T) { +// schedule := makeSchedule(5) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("verify all proofs") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// for _, u := range units { +// ok := VerifyMerkleProof(root, u.ShardData, uint32(u.ShardIndex), u.MerkleProof) +// assert.True(t, ok, "proof for shard %d should verify", u.ShardIndex) +// } +// } +// +// func TestReconstructMessage_MismatchedRoot(t *testing.T) { +// schedule := makeSchedule(4) +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("good message") +// units, _, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// shards := make([][]byte, schedule.NumShards()) +// for _, u := range units { +// shards[u.ShardIndex] = u.ShardData +// } +// +// // Pass a wrong root. 
+// fakeRoot := MessageRoot{0xff} +// _, err = ReconstructMessage(shards, schedule, enc, fakeRoot) +// require.Error(t, err) +// +// var reconErr *ReconstructionError +// require.ErrorAs(t, err, &reconErr) +// assert.Equal(t, ReasonMismatchedMessageRoot, reconErr.Reason) +// } +// +// func TestReconstructMessage_InsufficientShards(t *testing.T) { +// schedule := makeSchedule(10) // 3 data, 6 coding +// enc, err := NewEncoder( +// schedule.NumDataShards(), schedule.NumCodingShards(), +// ) +// require.NoError(t, err) +// +// msg := []byte("not enough shards") +// units, root, err := EncodeMessage(msg, schedule, enc) +// require.NoError(t, err) +// +// // Provide only 2 shards when 3 are needed. +// shards := make([][]byte, schedule.NumShards()) +// shards[units[0].ShardIndex] = units[0].ShardData +// shards[units[1].ShardIndex] = units[1].ShardData +// +// _, err = ReconstructMessage(shards, schedule, enc, root) +// require.Error(t, err) +// +// var reconErr *ReconstructionError +// require.ErrorAs(t, err, &reconErr) +// assert.Equal(t, ReasonErasureReconstructionFailed, reconErr.Reason) +// } +// +// func TestEncodeMessage_NoShards(t *testing.T) { +// // A single-node schedule has no shards. +// schedule := makeSchedule(1) +// enc, err := NewEncoder(1, 0) +// require.NoError(t, err) +// +// _, _, err = EncodeMessage([]byte("x"), schedule, enc) +// require.Error(t, err) +// +// var pubErr *ShardPublishError +// require.ErrorAs(t, err, &pubErr) +// assert.Equal(t, ReasonInvalidDataSize, pubErr.Reason) +// } diff --git a/consensus/propeller/signing.go b/consensus/propeller/signing.go new file mode 100644 index 0000000000..9fb93d1f23 --- /dev/null +++ b/consensus/propeller/signing.go @@ -0,0 +1,77 @@ +package propeller + +import ( + "encoding/binary" + "errors" + "fmt" + + "github.com/libp2p/go-libp2p/core/crypto" +) + +const payloadLen = 95 + +// buildSignPayload constructs the byte sequence that the publisher signs. 
+// Does it in constant time without heap allocations +func buildSignPayload( + root *MessageRoot, committeeID *CommitteeID, nonce Nonce, +) [payloadLen]byte { + // The tags domain-separate propeller signatures from any other protocol + // that might use the same key, preventing cross-protocol signature reuse. NOTE(review): both tags are empty here, so only 32+32+8 = 72 of the payloadLen (95) bytes are written and the rest stay zero; confirm the tag literals were not lost. + const prefix = "" + const suffix = "" + + // cumulative end offsets: each constant marks where its field ends inside payload + const prefixLen = len(prefix) + const rootLen = prefixLen + 32 + const committeeIDLen = rootLen + 32 + const nonceLen = committeeIDLen + 8 + const suffixLen = nonceLen + len(suffix) + + var payload [payloadLen]byte + + copy(payload[0:prefixLen], prefix) + copy(payload[prefixLen:rootLen], root[:]) + copy(payload[rootLen:committeeIDLen], committeeID[:]) + binary.BigEndian.PutUint64(payload[committeeIDLen:nonceLen], uint64(nonce)) + copy(payload[nonceLen:suffixLen], suffix) + + return payload +} + +func SignMessage( + privKey crypto.PrivKey, + root *MessageRoot, + committeeID *CommitteeID, + nonce Nonce, +) (Signature, error) { + payload := buildSignPayload(root, committeeID, nonce) + sig, err := privKey.Sign(payload[:]) + if err != nil { + return nil, fmt.Errorf("signing message root: %w", err) + } + return sig, nil +} + +func VerifyMessageSignature( + pubKey crypto.PubKey, + root *MessageRoot, + committeeID *CommitteeID, + nonce Nonce, + signature Signature, +) error { + if len(signature) == 0 { + return errors.New("empty signature") + } + + payload := buildSignPayload(root, committeeID, nonce) + valid, err := pubKey.Verify(payload[:], signature) + if err != nil { + return fmt.Errorf("failed pub key verification: %w", err) + } + + if !valid { + return errors.New("signature is invalid") + } + + return nil +} diff --git a/consensus/propeller/signing_test.go b/consensus/propeller/signing_test.go new file mode 100644 index 0000000000..29e7f50b8d --- /dev/null +++ b/consensus/propeller/signing_test.go @@ -0,0 +1,135 @@ +package 
propeller_test + +import ( + "bytes" + "crypto/ed25519" + "testing" + "time" + + "github.com/NethermindEth/juno/consensus/propeller" + "github.com/libp2p/go-libp2p/core/crypto" + "github.com/stretchr/testify/require" +) + +func generateKey(t *testing.T, seed byte) (crypto.PrivKey, crypto.PubKey) { + t.Helper() + s := make([]byte, ed25519.SeedSize) + s[0] = seed + priv, pub, err := crypto.GenerateEd25519Key(bytes.NewReader(s)) + require.NoError(t, err) + return priv, pub +} + +func TestSignAndVerify(t *testing.T) { + privA, pubA := generateKey(t, 1) + + root := propeller.MessageRoot{0xAA} + committeeID := propeller.CommitteeID{0xBB} + nonce := propeller.Nonce(time.Second) + + sig, err := propeller.SignMessage(privA, &root, &committeeID, nonce) + require.NoError(t, err) + + t.Run("success", func(t *testing.T) { + tests := []struct { + name string + root propeller.MessageRoot + committeeID propeller.CommitteeID + nonce propeller.Nonce + }{ + { + name: "valid roundtrip", + root: root, + committeeID: committeeID, + nonce: nonce, + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := propeller.VerifyMessageSignature( + pubA, &tc.root, &tc.committeeID, tc.nonce, sig, + ) + require.NoError(t, err) + }) + } + }) + + t.Run("error", func(t *testing.T) { + _, pubB := generateKey(t, 2) + + tests := []struct { + name string + pubKey crypto.PubKey + root propeller.MessageRoot + committeeID propeller.CommitteeID + nonce propeller.Nonce + signature propeller.Signature + wantErr string + }{ + { + name: "wrong public key", + pubKey: pubB, + root: root, + committeeID: committeeID, + nonce: nonce, + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "tampered root", + pubKey: pubA, + root: propeller.MessageRoot{0xFF}, + committeeID: committeeID, + nonce: nonce, + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "tampered committee ID", + pubKey: pubA, + root: root, + committeeID: propeller.CommitteeID{0xFF}, + nonce: 
nonce, + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "tampered nonce", + pubKey: pubA, + root: root, + committeeID: committeeID, + nonce: propeller.Nonce(time.Hour), + signature: sig, + wantErr: "signature is invalid", + }, + { + name: "empty signature", + pubKey: pubA, + root: root, + committeeID: committeeID, + nonce: nonce, + signature: nil, + wantErr: "empty signature", + }, + { + name: "corrupted signature", + pubKey: pubA, + root: root, + committeeID: committeeID, + nonce: nonce, + signature: append(append([]byte{}, sig...), 0xFF), + wantErr: "signature is invalid", + }, + } + + for _, tc := range tests { + t.Run(tc.name, func(t *testing.T) { + err := propeller.VerifyMessageSignature( + tc.pubKey, &tc.root, &tc.committeeID, tc.nonce, tc.signature, + ) + require.ErrorContains(t, err, tc.wantErr) + }) + } + }) +} diff --git a/consensus/propeller/timecache/timecache.go b/consensus/propeller/timecache/timecache.go new file mode 100644 index 0000000000..371a49fd7f --- /dev/null +++ b/consensus/propeller/timecache/timecache.go @@ -0,0 +1,151 @@ +package timecache + +import ( + "sync" + "time" +) + +type index int + +type timedValue[K any] struct { + value K + expiry time.Time +} + +type TimeCache[K comparable] struct { + // todo(rdr): there is a possibility of make the value an `index` and find the time + // information on `timestamps`. This would reduce duplication of time.Time + // and it would also allow to easily detect expired values (no longer need to + // perform time.Time substractions since everything before `index` is expired) + // This is might be hyper-optimising.... + // Access valid keys in O(1) + values map[K]time.Time + // Clean expired keys O(k) where `k` is the amount of expired keys + timestamps []timedValue[K] + mu sync.RWMutex + + // Track currently stored values, first value at `start` and last + // one at `end`. 
`size` is the length of `timestamps`; one slot is always kept empty + start index + end index + size int + // used to delete any timed value which has been inserted more than + // `expiry` time.Duration ago. + expiry time.Duration +} + +// New allocates a new TimeCache with initial allocation size and expiry time. +// If `size` gets filled the timecache will allocate more memory to fit more +// elements into it. The cache will not shrink after regrowing. It is safe for +// concurrent use. +func New[K comparable](size int, expiry time.Duration) *TimeCache[K] { + // we allocate size+1 because we always leave one position empty + // to detect when the cache is full + return &TimeCache[K]{ + values: make(map[K]time.Time, size+1), + timestamps: make([]timedValue[K], size+1), + mu: sync.RWMutex{}, + + start: 0, + end: 0, + size: size + 1, + expiry: expiry, + } +} + +// Add adds a new key into the timecache, it doesn't guard against duplicated +// entries. Adding the same entry twice will result in undefined behaviour. 
+func (tc *TimeCache[K]) Add(value *K) { + tc.mu.Lock() + defer tc.mu.Unlock() + + now := time.Now() + tc.removeExpired(now) + if tc.almostFull() { + tc.regrowth() + } + + expiryTime := now.Add(tc.expiry) + tc.values[*value] = expiryTime + tc.timestamps[tc.end] = timedValue[K]{ + value: *value, + expiry: expiryTime, + } + tc.increaseIndex(&tc.end) +} + +// Get returns true if the entry exists and it hasn't expired, false otherwise +func (tc *TimeCache[K]) Get(value *K) bool { + tc.mu.RLock() + expiry, ok := tc.values[*value] + tc.mu.RUnlock() + + if !ok { + return false + } + + now := time.Now() + if expiry.After(now) { + return true + } + + // If we know we have an expired value + // let's clean the expired entries + tc.mu.Lock() + tc.removeExpired(now) + tc.mu.Unlock() + return false +} + +func (tc *TimeCache[K]) increaseIndex(idx *index) { + *idx = (*idx + 1) % index(tc.size) +} + +// removeExpired deletes all the elements that have already expired until it +// finds the first one that hasn't or the cache empties +func (tc *TimeCache[K]) removeExpired(now time.Time) { + for tc.start != tc.end { + tv := &tc.timestamps[tc.start] + if now.Before(tv.expiry) { + break + } + + delete(tc.values, tv.value) + tc.increaseIndex(&tc.start) + } +} + +// almostFull reports whether the time cache will get full on the next insertion +func (tc *TimeCache[K]) almostFull() bool { + return (tc.end+1)%index(tc.size) == tc.start +} + +func (tc *TimeCache[K]) regrowth() { + const standardSize = 1024 + + nextSize := tc.size * 2 + if tc.size > standardSize { + // growth by 20% + nextSize = (tc.size * 12) / 10 + } + + nextTimestamps := make([]timedValue[K], nextSize) + + // No wrap-around: start == 0 and end == size-1, or start == end on an empty cache (size == 1) + if tc.start <= tc.end { + copy(nextTimestamps, tc.timestamps) + tc.size = nextSize + tc.timestamps = nextTimestamps + return + } + + count := tc.size - int(tc.start) + copy(nextTimestamps[0:count], tc.timestamps[tc.start:tc.size]) + nextEnd := count + int(tc.end) + 
copy(nextTimestamps[count:nextEnd], tc.timestamps[0:tc.end]) + + tc.start = 0 + tc.end = index(nextEnd) + tc.size = nextSize + tc.timestamps = nextTimestamps +} diff --git a/consensus/propeller/timecache/timecache_bench_test.go b/consensus/propeller/timecache/timecache_bench_test.go new file mode 100644 index 0000000000..f885127cc5 --- /dev/null +++ b/consensus/propeller/timecache/timecache_bench_test.go @@ -0,0 +1,131 @@ +package timecache_test + +import ( + "math/rand/v2" + "sync/atomic" + "testing" + "time" + + "github.com/NethermindEth/juno/consensus/propeller/timecache" +) + +func BenchmarkTimeCacheAdd(b *testing.B) { + b.Run("small cache size", func(b *testing.B) { + tc := timecache.New[int](100, 3*time.Second) + for i := range b.N { + tc.Add(&i) + } + }) + + b.Run("big cache size", func(b *testing.B) { + tc := timecache.New[int](5000, 3*time.Second) + for i := range b.N { + tc.Add(&i) + } + }) + + b.Run("big cache size with some expired values", func(b *testing.B) { + const size = 2000 + tc := timecache.New[int](size, 2*time.Second) + for i := range size - 1 { + tc.Add(&i) + time.Sleep(1 * time.Millisecond) + } + + // Let some values expire + time.Sleep(500 * time.Millisecond) + b.ResetTimer() + + // Add a lot of new ones + for i := range b.N { + tc.Add(&i) + } + }) + + b.Run("custom key", func(b *testing.B) { + // the same size as `messageKey` + type key [10]uint64 + + tc := timecache.New[key](5000, 3*time.Second) + for i := range b.N { + key := key{} + key[0] = uint64(i) + tc.Add(&key) + } + }) +} + +func BenchmarkTimeCacheGet(b *testing.B) { + b.Run("empty cache", func(b *testing.B) { + tc := timecache.New[int](100, 3*time.Second) + for i := range b.N { + tc.Get(&i) + } + }) + + b.Run("full unexpired cache", func(b *testing.B) { + const size = 10000 + tc := timecache.New[int](size, 1*time.Hour) + for i := range size { + tc.Add(&i) + } + b.ResetTimer() + + for i := range b.N { + key := i % size + tc.Get(&key) + } + }) + + b.Run("full with some expired 
values", func(b *testing.B) { + const size = 2000 + tc := timecache.New[int](size, 2*time.Second) + for i := range size / 2 { + tc.Add(&i) + time.Sleep(time.Millisecond) + } + time.Sleep(200 * time.Millisecond) + + b.ResetTimer() + + for i := range b.N { + key := i % size + tc.Get(&key) + } + }) + + b.Run("while additions are ocurring at the same time", func(b *testing.B) { + const size = 100 + tc := timecache.New[int64](size, 100*time.Millisecond) + + var insertedKeys atomic.Int64 + // random val != 0 + insertedKeys.Store(13) + + done := make(chan struct{}) + go func() { + var key int64 + for { + select { + case <-done: + return + case <-time.After(time.Duration((rand.IntN(10))+1) * time.Millisecond): + tc.Add(&key) + key += 1 + insertedKeys.Store(key) + } + } + }() + + b.ResetTimer() + for range b.N { + n := insertedKeys.Load() + + key := rand.Int64N(n) + tc.Get(&key) + } + + b.StopTimer() + close(done) + }) +} diff --git a/consensus/propeller/timecache/timecache_test.go b/consensus/propeller/timecache/timecache_test.go new file mode 100644 index 0000000000..0f858f4fc9 --- /dev/null +++ b/consensus/propeller/timecache/timecache_test.go @@ -0,0 +1,165 @@ +package timecache_test + +import ( + "math/rand/v2" + "sync" + "testing" + "time" + + "github.com/NethermindEth/juno/consensus/propeller/timecache" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestTimeCacheSequentially(t *testing.T) { + t.Parallel() + t.Run("correctly expires keys", func(t *testing.T) { + t.Parallel() + + const expiry = 3 * time.Second + tc := timecache.New[int](3, expiry) + + key := 3 + tc.Add(&key) + + require.True(t, tc.Get(&key), "key should exists while it hasn't expired") + + time.Sleep(expiry + 1*time.Second) + + require.False(t, tc.Get(&key), "key shouldn't exist after expiration window") + }) + + t.Run("correctly increases in size the cache's exceeds orignal size", func(t *testing.T) { + t.Parallel() + + const size = 3 + const expiry = 3 * 
time.Second
+
+		// Fill the time cache with data to its maximum size
+		tc := timecache.New[int](size, expiry)
+		for i := range size {
+			tc.Add(&i)
+		}
+		for i := range size {
+			require.Truef(t, tc.Get(&i), "key %d should still exist", i)
+		}
+
+		// Add more elements and check that old and new keys both exist in the cache
+		time.Sleep(time.Second)
+		for i := range size {
+			newKey := i + size
+			tc.Add(&newKey)
+			require.Truef(t, tc.Get(&i), "key %d should exist", i)
+			require.Truef(t, tc.Get(&newKey), "new key %d should also exist ", newKey)
+		}
+
+		// Wait for old keys to expire
+		time.Sleep(2*time.Second + 200*time.Millisecond)
+		for i := range size {
+			newKey := i + size
+			require.Falsef(t, tc.Get(&i), "key %d shouldn't exist", i)
+			require.Truef(t, tc.Get(&newKey), "new key %d should exist ", newKey)
+		}
+
+		// Add back the initial keys and check that both sets of keys are being held
+		for i := range size {
+			newKey := i + size
+			tc.Add(&i)
+			require.Truef(t, tc.Get(&i), "key %d should exist again", i)
+			require.Truef(t, tc.Get(&newKey), "new key %d should still exist ", newKey)
+		}
+	})
+}
+
+func TestTimeCacheConcurrently(t *testing.T) {
+	const size = 100
+	const expiry = 1 * time.Second
+
+	tc := timecache.New[int](size, expiry)
+
+	// Go-routine A adds a new key at random intervals (1-250ms) for sendPeriod (3 seconds)
+	// Go-routine B checks that each key is present right after it was added
+	// Go-routine C checks that each key is gone once expiry has elapsed since it was added
+
+	fastCheckCh := make(chan int)
+
+	type timedInt struct {
+		val  int
+		time time.Time
+	}
+	slowCheckCh := make(chan timedInt)
+
+	var wg sync.WaitGroup
+
+	// Go-routine A
+	const sendPeriod = expiry * 3
+	wg.Go(func() {
+		key := 0
+		timeout := time.After(sendPeriod)
+		for {
+			select {
+			case <-timeout:
+				close(fastCheckCh)
+				close(slowCheckCh)
+				return
+			case <-time.After(time.Duration(rand.IntN(250)+1) * time.Millisecond):
+				tc.Add(&key)
+				fastCheckCh <- key // unbuffered: blocks until go-routine B takes the key
+				slowCheckCh <- timedInt{val: key, time: time.Now()}
+				key += 1
+			}
+		}
+	})
+
+	// Go-routine B
+	wg.Go(func() {
+		for v := range fastCheckCh {
+			assert.Truef(t, tc.Get(&v), "value %d should still exists", v)
+		}
+	})
+
+	// Go-routine C
+	wg.Go(func() {
+		valToReview := 0 // index in valsToCheck of the next value expected to expire
+		valsToCheck := make([]timedInt, 0, 100)
+		waitDuration := time.Now().Add(sendPeriod * 2) // a deadline (time.Time), not a duration
+
+		for {
+			select {
+			case timedInt, ok := <-slowCheckCh: // NOTE: this variable shadows the timedInt type
+				if !ok {
+					if valToReview == len(valsToCheck) {
+						// This shouldn't happen, as with the current config after sending
+						// is finished there should be more values to check
+						return
+					}
+					slowCheckCh = nil // receiving from a nil channel blocks: disables this case
+					continue
+				}
+				valsToCheck = append(valsToCheck, timedInt)
+				if valToReview == len(valsToCheck)-1 {
+					waitDuration = valsToCheck[valToReview].time.Add(expiry)
+				}
+			case <-time.After(time.Until(waitDuration)):
+				val := valsToCheck[valToReview]
+				assert.Falsef(t, tc.Get(&val.val), "value %d shouldn't exist", val.val)
+
+				valToReview += 1
+				if valToReview == len(valsToCheck) {
+					if slowCheckCh == nil {
+						// all values have been received and checked
+						return
+					}
+					// all values have been checked but there are still more to receive
+					// Set a long enough wait duration to avoid triggering this until
+					// a new element is received
+					waitDuration = time.Now().Add(sendPeriod * 2)
+					continue
+				}
+				waitDuration = valsToCheck[valToReview].time.Add(expiry)
+			}
+		}
+	})
+
+	wg.Wait()
+}
diff --git a/consensus/propeller/types.go b/consensus/propeller/types.go
new file mode 100644
index 0000000000..63b37c6093
--- /dev/null
+++ b/consensus/propeller/types.go
@@ -0,0 +1,271 @@
+// Package propeller implements an erasure-coding based message broadcast protocol
+// for Byzantine fault-tolerant consensus. A publisher splits a message into shards,
+// erasure-encodes them via Reed-Solomon, and distributes one shard per peer.
+// Any peer can reconstruct the full message from a threshold number of shards,
+// then forwards its own assigned shard to all others.
+//
+// The protocol tolerates up to f = floor((N-1)/3) Byzantine faulty nodes.
+package propeller
+
+import (
+	"fmt"
+	"time"
+
+	"github.com/libp2p/go-libp2p/core/peer"
+	"github.com/libp2p/go-libp2p/core/protocol"
+)
+
+// Config holds tunable parameters for the propeller engine. Sensible defaults
+// are provided by DefaultConfig().
+type Config struct {
+	// StaleMessageTimeout is how long the engine waits for a message to
+	// reach the receive threshold before giving up. This prevents memory
+	// leaks from partially-received messages that will never complete
+	// (e.g., due to a crashed publisher or network partition).
+	StaleMessageTimeout time.Duration
+
+	// StreamProtocol is the libp2p protocol identifier used for direct
+	// shard transfers between peers.
+	StreamProtocol protocol.ID
+
+	// MaxWireMessageSize caps the size of a single serialised PropellerUnit
+	// on the wire. Units exceeding this are rejected to prevent memory
+	// exhaustion from malicious peers.
+	MaxWireMessageSize int
+}
+
+// DefaultConfig returns the defaults: 120s stale timeout, "/propeller/0.1.0" protocol, 1 MiB cap.
+func DefaultConfig() Config {
+	return Config{
+		StaleMessageTimeout: 120 * time.Second,
+		StreamProtocol:      "/propeller/0.1.0",
+		MaxWireMessageSize:  1 << 20, // 1 MiB
+	}
+}
+
+// ---------------------------------------------------------------------------
+// Events: structured outputs from the engine to the application layer.
+// Each event is emitted at most once per message lifecycle.
+// ---------------------------------------------------------------------------
+
+// EventMessageReceived signals that a message has been fully reconstructed
+// and enough shards have been forwarded to guarantee delivery to all honest
+// nodes. The application can safely process the contained message bytes.
+type EventMessageReceived struct {
+	Publisher peer.ID
+	Root      MessageRoot
+	Message   []byte
+}
+
+// EventReconstructionFailed signals that Reed-Solomon reconstruction or
+// post-reconstruction verification failed. This typically indicates Byzantine
+// behaviour from the publisher (e.g., inconsistent shards).
+type EventReconstructionFailed struct {
+	Root      MessageRoot
+	Publisher peer.ID
+	Err       error
+}
+
+// EventShardPublishFailed signals that the local node failed to encode or
+// distribute shards when acting as publisher.
+type EventShardPublishFailed struct {
+	Err error
+}
+
+// EventShardSendFailed signals that sending a single shard to a specific
+// peer failed. The engine continues sending to other peers; this is
+// informational for monitoring.
+type EventShardSendFailed struct {
+	From peer.ID
+	To   peer.ID
+	Err  error
+}
+
+// EventShardValidationFailed signals that an incoming shard was rejected
+// during validation. This may indicate Byzantine behaviour from the sender
+// or publisher.
+type EventShardValidationFailed struct {
+	Sender           peer.ID
+	ClaimedRoot      MessageRoot // root as stated by the rejected unit
+	ClaimedPublisher peer.ID     // publisher as stated by the rejected unit
+	Err              error
+}
+
+// EventMessageTimeout signals that a message did not reach the receive
+// threshold before the stale message timeout elapsed. The engine cleans
+// up state for this message.
+type EventMessageTimeout struct {
+	Channel   CommitteeID // the "channel" is identified by a CommitteeID
+	Publisher peer.ID
+	Root      MessageRoot
+}
+
+// reasonUnknown is the string representation for unrecognised enum values.
+// Extracted as a constant to satisfy goconst.
+const reasonUnknown = "unknown"
+
+// ---------------------------------------------------------------------------
+// Error types: structured errors for each failure domain.
+// Using typed errors rather than sentinel values lets callers inspect the
+// specific failure reason programmatically.
+// ---------------------------------------------------------------------------
+
+// ShardValidationReason enumerates the specific causes of shard rejection.
+type ShardValidationReason int
+
+const (
+	// ReasonSelfSending means a peer sent us a unit claiming to be from us.
+	ReasonSelfSending ShardValidationReason = iota // also the zero value of the type
+	// ReasonReceivedSelfPublishedShard means we received a shard for a
+	// message we published ourselves -- we already have all shards.
+	ReasonReceivedSelfPublishedShard
+	// ReasonDuplicateShard means we already have a shard at this index
+	// for this message.
+	ReasonDuplicateShard
+	// ReasonUnexpectedSender means the sender is not the peer assigned
+	// to broadcast this shard index.
+	ReasonUnexpectedSender
+	// ReasonSignatureVerificationFailed means the publisher's signature
+	// over the Merkle root did not verify.
+	ReasonSignatureVerificationFailed
+	// ReasonMerkleProofVerificationFailed means the Merkle inclusion
+	// proof for this shard is invalid.
+	ReasonMerkleProofVerificationFailed
+	// ReasonScheduleError means the shard-to-peer mapping lookup failed
+	// (e.g., publisher not in the channel's peer set).
+	ReasonScheduleError
+)
+
+func (r ShardValidationReason) String() string {
+	switch r {
+	case ReasonSelfSending:
+		return "self_sending"
+	case ReasonReceivedSelfPublishedShard:
+		return "received_self_published_shard"
+	case ReasonDuplicateShard:
+		return "duplicate_shard"
+	case ReasonUnexpectedSender:
+		return "unexpected_sender"
+	case ReasonSignatureVerificationFailed:
+		return "signature_verification_failed"
+	case ReasonMerkleProofVerificationFailed:
+		return "merkle_proof_verification_failed"
+	case ReasonScheduleError:
+		return "schedule_error"
+	default:
+		return reasonUnknown // any value outside the constants declared above
+	}
+}
+
+// ShardValidationError is returned when an incoming PropellerUnit fails
+// validation. The Reason field allows programmatic inspection; the Detail
+// field carries human-readable context.
+type ShardValidationError struct {
+	Reason ShardValidationReason
+	Detail string
+}
+
+func (e *ShardValidationError) Error() string {
+	return fmt.Sprintf("shard validation failed (%s): %s", e.Reason, e.Detail) // %s uses Reason.String
+}
+
+// ReconstructionReason enumerates the specific causes of reconstruction failure.
+type ReconstructionReason int
+
+const (
+	// ReasonErasureReconstructionFailed means Reed-Solomon decoding failed,
+	// likely because too many shards are missing or corrupted.
+	ReasonErasureReconstructionFailed ReconstructionReason = iota // also the zero value of the type
+	// ReasonMismatchedMessageRoot means the Merkle root computed from the
+	// reconstructed shards does not match the claimed root. This indicates
+	// Byzantine behaviour from the publisher.
+	ReasonMismatchedMessageRoot
+	// ReasonUnequalShardLengths means shards have inconsistent lengths,
+	// which violates Reed-Solomon's equal-length requirement.
+	ReasonUnequalShardLengths
+	// ReasonMessagePaddingError means the varint length prefix in the
+	// unpadded message is malformed or points beyond the data.
+	ReasonMessagePaddingError
+)
+
+func (r ReconstructionReason) String() string {
+	switch r {
+	case ReasonErasureReconstructionFailed:
+		return "erasure_reconstruction_failed"
+	case ReasonMismatchedMessageRoot:
+		return "mismatched_message_root"
+	case ReasonUnequalShardLengths:
+		return "unequal_shard_lengths"
+	case ReasonMessagePaddingError:
+		return "message_padding_error"
+	default:
+		return reasonUnknown // any value outside the constants declared above
+	}
+}
+
+// ReconstructionError is returned when message reconstruction fails after
+// collecting enough shards.
+type ReconstructionError struct {
+	Reason ReconstructionReason
+	Detail string
+}
+
+func (e *ReconstructionError) Error() string {
+	return fmt.Sprintf("reconstruction failed (%s): %s", e.Reason, e.Detail) // %s uses Reason.String
+}
+
+// ShardPublishReason enumerates the specific causes of publish failure.
+type ShardPublishReason int
+
+const (
+	// ReasonLocalPeerNotInChannel means the local peer is not a member
+	// of the channel it is trying to broadcast on.
+	ReasonLocalPeerNotInChannel ShardPublishReason = iota // also the zero value of the type
+	// ReasonInvalidDataSize means the message is too large to encode.
+	ReasonInvalidDataSize
+	// ReasonSigningFailed means the local private key failed to sign.
+	ReasonSigningFailed
+	// ReasonEncodingFailed means Reed-Solomon encoding failed.
+	ReasonEncodingFailed
+	// ReasonNotConnectedToPeer means we have no open connection to a
+	// target peer.
+	ReasonNotConnectedToPeer
+	// ReasonChannelNotRegistered means the channel has not been registered
+	// with the engine.
+	ReasonChannelNotRegistered
+	// ReasonBroadcastFailed means the broadcast operation failed for an
+	// unspecified reason.
+	ReasonBroadcastFailed
+)
+
+func (r ShardPublishReason) String() string {
+	switch r {
+	case ReasonLocalPeerNotInChannel:
+		return "local_peer_not_in_channel"
+	case ReasonInvalidDataSize:
+		return "invalid_data_size"
+	case ReasonSigningFailed:
+		return "signing_failed"
+	case ReasonEncodingFailed:
+		return "encoding_failed"
+	case ReasonNotConnectedToPeer:
+		return "not_connected_to_peer"
+	case ReasonChannelNotRegistered:
+		return "channel_not_registered"
+	case ReasonBroadcastFailed:
+		return "broadcast_failed"
+	default:
+		return reasonUnknown
+	}
+}
+
+// ShardPublishError is returned when the local node fails to publish shards.
+// todo(rdr): check if we want to do this. I think it is better not, unless necessary
+type ShardPublishError struct {
+	Reason ShardPublishReason
+	Detail string
+}
+
+func (e *ShardPublishError) Error() string {
+	return fmt.Sprintf("shard publish failed (%s): %s", e.Reason, e.Detail)
+}
diff --git a/consensus/propeller/unit.go b/consensus/propeller/unit.go
new file mode 100644
index 0000000000..1c4e75bbaf
--- /dev/null
+++ b/consensus/propeller/unit.go
@@ -0,0 +1,163 @@
+package propeller
+
+import (
+	"errors"
+	"time"
+
+	"github.com/NethermindEth/juno/consensus/propeller/merkle"
+	pb "github.com/NethermindEth/juno/consensus/propeller/proto"
+	"github.com/libp2p/go-libp2p/core/peer"
+	"github.com/starknet-io/starknet-p2p-specs/p2p/proto/common"
+	"google.golang.org/protobuf/proto"
+)
+
+// CommitteeID identifies a committee or logical broadcast group. Multiple committees
+// can operate concurrently within the same engine, each with its own peer set.
+type CommitteeID [32]byte
+
+// MessageRoot is the SHA-256 Merkle root over all shard leaves. It uniquely
+// identifies a message and is signed by the publisher to bind authenticity.
+type MessageRoot merkle.Hash
+
+// Shard is a single shard fragment.
+type Shard []byte
+
+// ShardData is the set of shard fragments held by a propeller Unit.
+type ShardData []Shard
+
+// ShardIndex is the position of a shard within the erasure-coded output.
+// Valid range is [0, N-2] where N is the total number of peers.
+type ShardIndex uint32
+
+// MarshalProto encodes the shards as a protobuf ShardsOfPeer message and returns
+// the resulting bytes.
+func (sd ShardData) MarshalProto() []byte {
+	shards := make([]*pb.Shard, len(sd))
+	for i, s := range sd {
+		shards[i] = &pb.Shard{Data: s}
+	}
+	// We ignore the error because this data has already been converted and it is expected
+	// to be correct.
+	res, _ := proto.Marshal(&pb.ShardsOfPeer{Shards: shards})
+	return res
+}
+
+// Signature is the signature carried by a propeller Unit.
+type Signature []byte
+
+// Nonce is the nonce carried by a propeller Unit.
+type Nonce time.Duration
+
+// Unit is the atomic wire message: one erasure-coded shard plus
+// the metadata needed for independent verification. Each unit is self-contained
+// so a receiver can validate it without any other shards.
+type Unit struct {
+	CommitteeID CommitteeID  // Which committee this belongs to
+	Publisher   peer.ID      // Original message author
+	MessageRoot MessageRoot  // Merkle root binding all shards together
+	MerkleProof merkle.Proof // Merkle inclusion proof for this shard
+	Signature   Signature    // Publisher's Ed25519 signature over the root
+	ShardIndex  ShardIndex   // This shard's position in the erasure-coded output
+	ShardData   ShardData    // Shard fragments carried by this unit
+	// todo(rdr): calling it nonce because that's what is called on the rust side but
+	// time stamp or some other name would be better
+	Nonce Nonce // Strictly increasing number, starting from the Unix epoch
+}
+
+// UnitFromProto converts a wire-level PropellerUnit into a Unit.
+//
+// The message may come from an untrusted peer, so malformed input is reported as
+// an error instead of panicking: a nil message, a unit without shards, shards of
+// differing lengths, a Merkle root whose length differs from MessageRoot, or a
+// shard index that does not fit in a ShardIndex.
+//
+// todo(rdr): What other validations should I do?
+// todo(rdr): Should I do these validations here?
+func UnitFromProto(protoUnit *pb.PropellerUnit) (Unit, error) {
+	if protoUnit == nil {
+		return Unit{}, errors.New("unit is nil")
+	}
+
+	protoShards := protoUnit.Shards.GetShards()
+	if len(protoShards) == 0 {
+		return Unit{}, errors.New("unit has no shards")
+	}
+
+	shards := make(ShardData, len(protoShards))
+	for i, s := range protoShards {
+		shards[i] = Shard(s.GetData())
+		// Every shard, including the last one, must be as long as the first.
+		if len(shards[i]) != len(shards[0]) {
+			return Unit{}, errors.New("unit has shards of different length")
+		}
+	}
+
+	// Converting a short slice straight to the array type would panic, so the
+	// length is checked and the bytes are copied instead.
+	var root MessageRoot
+	rootBytes := protoUnit.MerkleRoot.GetElements()
+	if len(rootBytes) != len(root) {
+		return Unit{}, errors.New("unit has a merkle root of unexpected length")
+	}
+	copy(root[:], rootBytes)
+
+	// ShardIndex is narrower than the wire field; refuse values that would wrap.
+	if protoUnit.Index > uint64(^ShardIndex(0)) {
+		return Unit{}, errors.New("unit has a shard index out of range")
+	}
+
+	protoSiblings := protoUnit.MerkleProof.GetSiblings()
+	siblings := make([]merkle.Hash, len(protoSiblings))
+	for i, s := range protoSiblings {
+		copy(siblings[i][:], s.GetElements())
+	}
+
+	return Unit{
+		CommitteeID: committeeIDFromBytes(protoUnit.CommitteeId.GetElements()),
+		Publisher:   peer.ID(protoUnit.Publisher.GetId()),
+		MessageRoot: root,
+		MerkleProof: merkle.Proof{Siblings: siblings},
+		Signature:   protoUnit.Signature,
+		ShardIndex:  ShardIndex(protoUnit.Index),
+		ShardData:   shards,
+		Nonce:       Nonce(time.Duration(protoUnit.Nonce)),
+	}, nil
+}
+
+// ToProto converts the unit into its wire-level PropellerUnit representation.
+func (u *Unit) ToProto() *pb.PropellerUnit {
+	protoShards := make([]*pb.Shard, len(u.ShardData))
+	for i, s := range u.ShardData {
+		protoShards[i] = &pb.Shard{Data: s}
+	}
+
+	siblings := make([]*common.Hash256, len(u.MerkleProof.Siblings))
+	for i, s := range u.MerkleProof.Siblings {
+		siblings[i] = &common.Hash256{Elements: s[:]}
+	}
+
+	root := merkle.Hash(u.MessageRoot)
+	return &pb.PropellerUnit{
+		Shards:      &pb.ShardsOfPeer{Shards: protoShards},
+		Index:       uint64(u.ShardIndex),
+		MerkleRoot:  &common.Hash256{Elements: root[:]},
+		MerkleProof: &pb.MerkleProof{Siblings: siblings},
+		Publisher:   &common.PeerID{Id: []byte(u.Publisher)},
+		Signature:   u.Signature,
+		CommitteeId: &common.Hash256{Elements: committeeIDToBytes(u.CommitteeID)},
+		Nonce:       uint64(u.Nonce),
+	}
+}
+
+// committeeIDFromBytes copies b into a CommitteeID; shorter input leaves the
+// remaining bytes zero and longer input is truncated.
+func committeeIDFromBytes(b []byte) CommitteeID {
+	var id CommitteeID
+	copy(id[:], b)
+	return id
+}
+
+// committeeIDToBytes returns the bytes of id as a slice.
+func committeeIDToBytes(id CommitteeID) []byte {
+	return id[:]
+}
diff --git a/consensus/propeller/unit_test.go b/consensus/propeller/unit_test.go
new file mode 100644
index 0000000000..bb61865bd4
--- /dev/null
+++ b/consensus/propeller/unit_test.go
@@ -0,0 +1 @@
+package propeller_test
diff --git a/consensus/propeller/unit_validator.go
b/consensus/propeller/unit_validator.go
new file mode 100644
index 0000000000..695e8beff9
--- /dev/null
+++ b/consensus/propeller/unit_validator.go
@@ -0,0 +1,147 @@
+package propeller
+
+import (
+	"bytes"
+	"errors"
+	"fmt"
+
+	"github.com/NethermindEth/juno/consensus/propeller/merkle"
+	"github.com/libp2p/go-libp2p/core/crypto"
+	"github.com/libp2p/go-libp2p/core/peer"
+)
+
+// todo(rdr): A validator lifetime is attached to a `subprocessor`. A `subprocessor` is attached
+// to a message key field. This logic is handled by a `Processor`. This means that a validator will
+// always be given units that have the same committeeID, publisher, messageRoot and Nonce (the
+// current fields of a `messageKey`). Does it makes sense for the validator to also hold a copy
+// of this. Is there a way of testing this invariant – where a validator only sees the same
+// fields. I need to add a test for that invariant
+
+// UnitValidator validates all the incoming units / shards given a committee and the publisher.
+type UnitValidator struct {
+	publisherPubKey crypto.PubKey
+	scheduler       *Scheduler
+
+	// todo(rdr): `receivedShards` can surely be an boolean array (cheaper than map)
+	// track of every shard index received
+	receivedShards map[ShardIndex]struct{}
+	// Once the validation is done it's stored here, subsequent validation
+	// compare against it
+	verifiedSignature Signature
+	// Root, committee and nonce that verifiedSignature was verified over. A unit
+	// reusing the cached signature must carry exactly these values.
+	verifiedRoot        MessageRoot
+	verifiedCommitteeID CommitteeID
+	verifiedNonce       Nonce
+}
+
+// NewValidator builds a UnitValidator for units published by publisher.
+//
+// It panics when the public key cannot be extracted from the publisher's peer ID.
+func NewValidator(publisher peer.ID, scheduler *Scheduler) UnitValidator {
+	// for now we are assuming that extracting a publisher key is always successful
+	// and done in constant time
+	pubKey, err := publisher.ExtractPublicKey()
+	if err != nil {
+		panic(err)
+	}
+	return UnitValidator{
+		publisherPubKey:   pubKey,
+		scheduler:         scheduler,
+		receivedShards:    make(map[ShardIndex]struct{}, scheduler.NumDataShards()),
+		verifiedSignature: nil,
+	}
+}
+
+// verifyDataShards checks that the unit carries exactly one shard and that its
+// Merkle proof verifies against the unit's message root at the unit's shard index.
+func (v *UnitValidator) verifyDataShards(unit *Unit) error {
+	if len(unit.ShardData) != 1 {
+		return fmt.Errorf(
+			"unexpected amount of shards. Expected %d. Received %d",
+			1,
+			len(unit.ShardData),
+		)
+	}
+
+	proof := unit.MerkleProof
+	root := merkle.Hash(unit.MessageRoot)
+	// We marshal to Proto bytes to make the verification language agnostic
+	if proof.Verify(&root, unit.ShardData.MarshalProto(), uint32(unit.ShardIndex)) {
+		return nil
+	}
+
+	return errors.New("data shards verification failed")
+}
+
+// verifySignature checks the publisher's signature of the unit.
+//
+// The first successful verification is cached together with the root, committee
+// and nonce it covers. Later units skip the signature check only if they carry
+// the same signature bytes and the same covered fields.
+func (v *UnitValidator) verifySignature(unit *Unit) error {
+	if v.verifiedSignature != nil {
+		if !bytes.Equal(v.verifiedSignature, unit.Signature) {
+			// %x keeps the message readable; %v would print each byte in decimal.
+			return fmt.Errorf(
+				"signature mismatch. Expected: %x. Received %x",
+				v.verifiedSignature,
+				unit.Signature,
+			)
+		}
+		// Equal signature bytes prove nothing about a different root, committee
+		// or nonce, so the signed fields must match the verified ones as well.
+		if unit.MessageRoot != v.verifiedRoot ||
+			unit.CommitteeID != v.verifiedCommitteeID ||
+			unit.Nonce != v.verifiedNonce {
+			return errors.New("unit fields differ from those covered by the verified signature")
+		}
+		return nil
+	}
+
+	err := VerifyMessageSignature(
+		v.publisherPubKey,
+		&unit.MessageRoot,
+		&unit.CommitteeID,
+		unit.Nonce,
+		unit.Signature,
+	)
+	if err != nil {
+		return fmt.Errorf("failed message signature verification: %w", err)
+	}
+
+	// Keep a private copy so later changes to the unit's buffer cannot alter the cache.
+	v.verifiedSignature = bytes.Clone(unit.Signature)
+	v.verifiedRoot = unit.MessageRoot
+	v.verifiedCommitteeID = unit.CommitteeID
+	v.verifiedNonce = unit.Nonce
+	return nil
+}
+
+// Validate runs every check on a unit received from sender: the shard index must
+// not have been accepted before, scheduler.ValidateShardOrigin must accept the
+// sender for that publisher and index, and the Merkle proof and the signature
+// must verify. On success the shard index is recorded as received.
+func (v *UnitValidator) Validate(unit *Unit, sender peer.ID) error {
+	if _, ok := v.receivedShards[unit.ShardIndex]; ok {
+		return fmt.Errorf("duplicated shard %d received", unit.ShardIndex)
+	}
+
+	err := v.scheduler.ValidateShardOrigin(sender, unit.Publisher, unit.ShardIndex)
+	if err != nil {
+		return err
+	}
+
+	if err = v.verifyDataShards(unit); err != nil {
+		return err
+	}
+
+	if err = v.verifySignature(unit); err != nil {
+		return err
+	}
+
+	// Store the verified shard to avoid re-verification
+	v.receivedShards[unit.ShardIndex] = struct{}{}
+
+	return nil
+}
diff --git a/consensus/propeller/unit_validator_test.go b/consensus/propeller/unit_validator_test.go
new file mode 100644
index 0000000000..bb61865bd4
--- /dev/null
+++ b/consensus/propeller/unit_validator_test.go
@@ -0,0 +1 @@
+package propeller_test
diff --git a/go.mod b/go.mod
index eca739c819..749aa1926d 100644
--- a/go.mod
+++ b/go.mod
@@ -15,6 +15,7 @@ require (
 	github.com/ethereum/go-ethereum v1.17.2
 	github.com/fxamacker/cbor/v2
v2.9.1 github.com/go-playground/validator/v10 v10.30.2 + github.com/klauspost/reedsolomon v1.14.0 github.com/libp2p/go-libp2p v0.48.0 github.com/libp2p/go-libp2p-kad-dht v0.39.1 github.com/libp2p/go-libp2p-pubsub v0.16.0 diff --git a/go.sum b/go.sum index eb2ec50c55..0f81defce2 100644 --- a/go.sum +++ b/go.sum @@ -379,6 +379,8 @@ github.com/klauspost/compress v1.18.4 h1:RPhnKRAQ4Fh8zU2FY/6ZFDwTVTxgJ/EMydqSTzE github.com/klauspost/compress v1.18.4/go.mod h1:R0h/fSBs8DE4ENlcrlib3PsXS61voFxhIs2DeRhCvJ4= github.com/klauspost/cpuid/v2 v2.3.0 h1:S4CRMLnYUhGeDFDqkGriYKdfoFlDnMtqTiI/sFzhA9Y= github.com/klauspost/cpuid/v2 v2.3.0/go.mod h1:hqwkgyIinND0mEev00jJYCxPNVRVXFQeu1XKlok6oO0= +github.com/klauspost/reedsolomon v1.14.0 h1:5YSZeclzSYg5nl349+GDG/agDtQ6MZiwUYXvVKN1Jx0= +github.com/klauspost/reedsolomon v1.14.0/go.mod h1:yjqqjgMTQkBUHSG97/rm4zipffCNbCiZcB3kTqr++sQ= github.com/koron/go-ssdp v0.1.0 h1:ckl5x5H6qSNFmi+wCuROvvGUu2FQnMbQrU95IHCcv3Y= github.com/koron/go-ssdp v0.1.0/go.mod h1:GltaDBjtK1kemZOusWYLGotV0kBeEf59Bp0wtSB0uyU= github.com/kr/fs v0.1.0/go.mod h1:FFnZGqtBN9Gxj7eW1uZ42v5BccTP0vu6NEaFoC2HwRg= diff --git a/p2p/pubsub/pubsub.go b/p2p/pubsub/pubsub.go index a2ee1f1419..7a2f98a7ff 100644 --- a/p2p/pubsub/pubsub.go +++ b/p2p/pubsub/pubsub.go @@ -15,8 +15,6 @@ import ( "github.com/libp2p/go-libp2p/p2p/discovery/routing" ) -const gossipSubHistory = 60 - func GetHost(hostPrivateKey crypto.PrivKey, hostAddress string) (host.Host, error) { return libp2p.New( libp2p.ListenAddrStrings(hostAddress), @@ -49,6 +47,8 @@ func Run( } params := pubsub.DefaultGossipSubParams() + + const gossipSubHistory = 60 params.HistoryLength = gossipSubHistory params.HistoryGossip = gossipSubHistory diff --git a/p2p/server/server.go b/p2p/server/server.go index d8f4621646..6c6301e873 100644 --- a/p2p/server/server.go +++ b/p2p/server/server.go @@ -216,53 +216,60 @@ func (h *Server) onHeadersRequest( HeaderMessage: &header.BlockHeadersResponse_Fin{}, } - return 
h.processIterationRequest(req.Iteration, finMsg, func(it blockDataAccessor) (proto.Message, error) { - blockHeader, err := it.Header() - if err != nil { - return nil, err - } - - h.logger.Debug("Created Header Iterator", zap.Uint64("blockNumber", blockHeader.Number)) - - stateUpdate, err := h.bcReader.StateUpdateByNumber(blockHeader.Number) - if err != nil { - return nil, err - } - - blockVer, err := core.ParseBlockVersion(blockHeader.ProtocolVersion) - if err != nil { - return nil, err - } - - var commitments *core.BlockCommitments - if blockVer.LessThan(core.Ver0_13_2) { - block, err := it.Block() + return h.processIterationRequest( + req.Iteration, + finMsg, + func(it blockDataAccessor) (proto.Message, error) { + blockHeader, err := it.Header() if err != nil { return nil, err } - // TODO: switch to core.NewTrieBackend once the legacy trie and state are removed. - _, commitments, err = core.Post0132Hash(block, stateUpdate.StateDiff, core.DeprecatedTrieBackend) + + h.logger.Debug("Created Header Iterator", zap.Uint64("blockNumber", blockHeader.Number)) + + stateUpdate, err := h.bcReader.StateUpdateByNumber(blockHeader.Number) if err != nil { return nil, err } - } else { - commitments, err = h.bcReader.BlockCommitmentsByNumber(blockHeader.Number) + + blockVer, err := core.ParseBlockVersion(blockHeader.ProtocolVersion) if err != nil { return nil, err } - } - stateDiffCommitment := stateUpdate.StateDiff.Hash() - return &header.BlockHeadersResponse{ - HeaderMessage: &header.BlockHeadersResponse_Header{ - Header: core2p2p.AdaptHeader( - blockHeader, - commitments, - &stateDiffCommitment, - stateUpdate.StateDiff.Length()), - }, - }, nil - }) + var commitments *core.BlockCommitments + if blockVer.LessThan(core.Ver0_13_2) { + block, err := it.Block() + if err != nil { + return nil, err + } + // TODO: switch to core.NewTrieBackend once the legacy trie and state are removed. 
+ _, commitments, err = core.Post0132Hash( + block, stateUpdate.StateDiff, core.DeprecatedTrieBackend, + ) + if err != nil { + return nil, err + } + } else { + commitments, err = h.bcReader.BlockCommitmentsByNumber(blockHeader.Number) + if err != nil { + return nil, err + } + } + + stateDiffCommitment := stateUpdate.StateDiff.Hash() + return &header.BlockHeadersResponse{ + HeaderMessage: &header.BlockHeadersResponse_Header{ + Header: core2p2p.AdaptHeader( + blockHeader, + commitments, + &stateDiffCommitment, + stateUpdate.StateDiff.Length(), + ), + }, + }, nil + }, + ) } func (h *Server) onEventsRequest( @@ -329,114 +336,120 @@ func (h *Server) onStateDiffRequest( finMsg := &state.StateDiffsResponse{ StateDiffMessage: &state.StateDiffsResponse_Fin{}, } - return h.processIterationRequestMulti(req.Iteration, finMsg, func(it blockDataAccessor) ([]proto.Message, error) { - block, err := it.Block() - if err != nil { - return nil, err - } - blockNumber := block.Number + return h.processIterationRequestMulti( + req.Iteration, + finMsg, + func(it blockDataAccessor) ([]proto.Message, error) { + block, err := it.Block() + if err != nil { + return nil, err + } + blockNumber := block.Number - stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) - if err != nil { - return nil, err - } - diff := stateUpdate.StateDiff + stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) + if err != nil { + return nil, err + } + diff := stateUpdate.StateDiff - type contractDiff struct { - address *felt.Felt - storageDiffs map[felt.Felt]*felt.Felt - nonce *felt.Felt - classHash *felt.Felt // set only if contract deployed or replaced - } - modifiedContracts := make(map[felt.Felt]*contractDiff) + type contractDiff struct { + address *felt.Felt + storageDiffs map[felt.Felt]*felt.Felt + nonce *felt.Felt + classHash *felt.Felt // set only if contract deployed or replaced + } + modifiedContracts := make(map[felt.Felt]*contractDiff) - initContractDiff := func(addr *felt.Felt) 
*contractDiff { - return &contractDiff{address: addr} - } - updateModifiedContracts := func(addr felt.Felt, f func(*contractDiff)) error { - cDiff, ok := modifiedContracts[addr] - if !ok { - cDiff = initContractDiff(&addr) - if err != nil { - return err - } - modifiedContracts[addr] = cDiff + initContractDiff := func(addr *felt.Felt) *contractDiff { + return &contractDiff{address: addr} } + updateModifiedContracts := func(addr felt.Felt, f func(*contractDiff)) error { + cDiff, ok := modifiedContracts[addr] + if !ok { + cDiff = initContractDiff(&addr) + if err != nil { + return err + } + modifiedContracts[addr] = cDiff + } - f(cDiff) - return nil - } + f(cDiff) + return nil + } - for addr, n := range diff.Nonces { - err = updateModifiedContracts(addr, func(diff *contractDiff) { - diff.nonce = n - }) - if err != nil { - return nil, err + for addr, n := range diff.Nonces { + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.nonce = n + }) + if err != nil { + return nil, err + } } - } - for addr, sDiff := range diff.StorageDiffs { - err = updateModifiedContracts(addr, func(diff *contractDiff) { - diff.storageDiffs = sDiff - }) - if err != nil { - return nil, err + for addr, sDiff := range diff.StorageDiffs { + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.storageDiffs = sDiff + }) + if err != nil { + return nil, err + } } - } - for addr, classHash := range diff.DeployedContracts { - classHashCopy := classHash - err = updateModifiedContracts(addr, func(diff *contractDiff) { - diff.classHash = classHashCopy - }) - if err != nil { - return nil, err + for addr, classHash := range diff.DeployedContracts { + classHashCopy := classHash + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.classHash = classHashCopy + }) + if err != nil { + return nil, err + } } - } - for addr, classHash := range diff.ReplacedClasses { - classHashCopy := classHash - err = updateModifiedContracts(addr, func(diff *contractDiff) { - 
diff.classHash = classHashCopy - }) - if err != nil { - return nil, err + for addr, classHash := range diff.ReplacedClasses { + classHashCopy := classHash + err = updateModifiedContracts(addr, func(diff *contractDiff) { + diff.classHash = classHashCopy + }) + if err != nil { + return nil, err + } } - } - var responses []proto.Message - for _, c := range modifiedContracts { - responses = append(responses, &state.StateDiffsResponse{ - StateDiffMessage: &state.StateDiffsResponse_ContractDiff{ - ContractDiff: core2p2p.AdaptContractDiff(c.address, c.nonce, c.classHash, c.storageDiffs), - }, - }) - } + var responses []proto.Message + for _, c := range modifiedContracts { + responses = append(responses, &state.StateDiffsResponse{ + StateDiffMessage: &state.StateDiffsResponse_ContractDiff{ + ContractDiff: core2p2p.AdaptContractDiff( + c.address, c.nonce, c.classHash, c.storageDiffs, + ), + }, + }) + } - for _, classHash := range diff.DeclaredV0Classes { - responses = append(responses, &state.StateDiffsResponse{ - StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ - DeclaredClass: &state.DeclaredClass{ - ClassHash: core2p2p.AdaptHash(classHash), - CompiledClassHash: nil, // for cairo0 it's nil + for _, classHash := range diff.DeclaredV0Classes { + responses = append(responses, &state.StateDiffsResponse{ + StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ + DeclaredClass: &state.DeclaredClass{ + ClassHash: core2p2p.AdaptHash(classHash), + CompiledClassHash: nil, // for cairo0 it's nil + }, }, - }, - }) - } - for classHash, compiledHash := range diff.DeclaredV1Classes { - responses = append(responses, &state.StateDiffsResponse{ - StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ - DeclaredClass: &state.DeclaredClass{ - ClassHash: core2p2p.AdaptHash(&classHash), - CompiledClassHash: core2p2p.AdaptHash(compiledHash), + }) + } + for classHash, compiledHash := range diff.DeclaredV1Classes { + responses = append(responses, &state.StateDiffsResponse{ + 
StateDiffMessage: &state.StateDiffsResponse_DeclaredClass{ + DeclaredClass: &state.DeclaredClass{ + ClassHash: core2p2p.AdaptHash(&classHash), + CompiledClassHash: core2p2p.AdaptHash(compiledHash), + }, }, - }, - }) - } + }) + } - return responses, nil - }) + return responses, nil + }, + ) } func (h *Server) onClassesRequest( @@ -445,58 +458,62 @@ func (h *Server) onClassesRequest( finMsg := &syncclass.ClassesResponse{ ClassMessage: &syncclass.ClassesResponse_Fin{}, } - return h.processIterationRequestMulti(req.Iteration, finMsg, func(it blockDataAccessor) ([]proto.Message, error) { - block, err := it.Block() - if err != nil { - return nil, err - } - blockNumber := block.Number - - stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) - if err != nil { - return nil, err - } - - stateReader, closer, err := h.bcReader.StateAtBlockNumber(blockNumber) - if err != nil { - return nil, err - } - defer func() { - if closeErr := closer(); closeErr != nil { - h.logger.Error("Failed to close state reader", zap.Error(closeErr)) + return h.processIterationRequestMulti( + req.Iteration, + finMsg, + func(it blockDataAccessor) ([]proto.Message, error) { + block, err := it.Block() + if err != nil { + return nil, err } - }() + blockNumber := block.Number - stateDiff := stateUpdate.StateDiff - - var responses []proto.Message - for _, hash := range stateDiff.DeclaredV0Classes { - cls, err := stateReader.Class(hash) + stateUpdate, err := h.bcReader.StateUpdateByNumber(blockNumber) if err != nil { return nil, err } - responses = append(responses, &syncclass.ClassesResponse{ - ClassMessage: &syncclass.ClassesResponse_Class{ - Class: core2p2p.AdaptClass(cls.Class), - }, - }) - } - for classHash := range stateDiff.DeclaredV1Classes { - cls, err := stateReader.Class(&classHash) + stateReader, closer, err := h.bcReader.StateAtBlockNumber(blockNumber) if err != nil { return nil, err } + defer func() { + if closeErr := closer(); closeErr != nil { + h.logger.Error("Failed to close 
state reader", zap.Error(closeErr)) + } + }() - responses = append(responses, &syncclass.ClassesResponse{ - ClassMessage: &syncclass.ClassesResponse_Class{ - Class: core2p2p.AdaptClass(cls.Class), - }, - }) - } + stateDiff := stateUpdate.StateDiff - return responses, nil - }) + var responses []proto.Message + for _, hash := range stateDiff.DeclaredV0Classes { + cls, err := stateReader.Class(hash) + if err != nil { + return nil, err + } + + responses = append(responses, &syncclass.ClassesResponse{ + ClassMessage: &syncclass.ClassesResponse_Class{ + Class: core2p2p.AdaptClass(cls.Class), + }, + }) + } + for classHash := range stateDiff.DeclaredV1Classes { + cls, err := stateReader.Class(&classHash) + if err != nil { + return nil, err + } + + responses = append(responses, &syncclass.ClassesResponse{ + ClassMessage: &syncclass.ClassesResponse_Class{ + Class: core2p2p.AdaptClass(cls.Class), + }, + }) + } + + return responses, nil + }, + ) } // blockDataAccessor provides access to either entire block or header @@ -570,7 +587,8 @@ func (h *Server) processIterationRequestMulti(iteration *synccommon.Iteration, f return func(yield yieldFunc) { // while iterator is valid for it.Valid() { - // pass it to handler function (some might be interested in header, others in entire block) + // pass it to handler function; some might be interested in header, + // others in entire block messages, err := getMsg(it) if err != nil { if !errors.Is(err, db.ErrKeyNotFound) { @@ -586,7 +604,8 @@ func (h *Server) processIterationRequestMulti(iteration *synccommon.Iteration, f for _, msg := range messages { // push generated msg to caller if !yield(msg) { - // if caller is not interested in remaining data (example: connection to a peer is closed) exit + // if caller is not interested in the remaining data, exit. 
+ // (example: connection to a peer is closed) // note that in this case we won't send finMsg return } diff --git a/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go b/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go index d1ae793cea..374497c461 100644 --- a/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go +++ b/starknet-p2p-specs/p2p/proto/capabilities/capabilities.pb.go @@ -7,11 +7,12 @@ package capabilities import ( - protoreflect "google.golang.org/protobuf/reflect/protoreflect" - protoimpl "google.golang.org/protobuf/runtime/protoimpl" reflect "reflect" sync "sync" unsafe "unsafe" + + protoreflect "google.golang.org/protobuf/reflect/protoreflect" + protoimpl "google.golang.org/protobuf/runtime/protoimpl" ) const ( @@ -333,14 +334,16 @@ func file_p2p_proto_capabilities_proto_rawDescGZIP() []byte { return file_p2p_proto_capabilities_proto_rawDescData } -var file_p2p_proto_capabilities_proto_msgTypes = make([]protoimpl.MessageInfo, 5) -var file_p2p_proto_capabilities_proto_goTypes = []any{ - (*SyncCapability)(nil), // 0: SyncCapability - (*SyncCapability_ArchiveStrategy)(nil), // 1: SyncCapability.ArchiveStrategy - (*SyncCapability_L1PruneStrategy)(nil), // 2: SyncCapability.L1PruneStrategy - (*SyncCapability_ConstSizePruneStrategy)(nil), // 3: SyncCapability.ConstSizePruneStrategy - (*SyncCapability_StaticPruneStrategy)(nil), // 4: SyncCapability.StaticPruneStrategy -} +var ( + file_p2p_proto_capabilities_proto_msgTypes = make([]protoimpl.MessageInfo, 5) + file_p2p_proto_capabilities_proto_goTypes = []any{ + (*SyncCapability)(nil), // 0: SyncCapability + (*SyncCapability_ArchiveStrategy)(nil), // 1: SyncCapability.ArchiveStrategy + (*SyncCapability_L1PruneStrategy)(nil), // 2: SyncCapability.L1PruneStrategy + (*SyncCapability_ConstSizePruneStrategy)(nil), // 3: SyncCapability.ConstSizePruneStrategy + (*SyncCapability_StaticPruneStrategy)(nil), // 4: SyncCapability.StaticPruneStrategy + } +) var 
file_p2p_proto_capabilities_proto_depIdxs = []int32{ 1, // 0: SyncCapability.archive_strategy:type_name -> SyncCapability.ArchiveStrategy 2, // 1: SyncCapability.l1_prune_strategy:type_name -> SyncCapability.L1PruneStrategy diff --git a/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go b/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go index 6a13e77e79..10be303364 100644 --- a/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go +++ b/starknet-p2p-specs/p2p/proto/consensus/consensus/consensus.pb.go @@ -7,13 +7,14 @@ package consensus import ( + reflect "reflect" + sync "sync" + unsafe "unsafe" + common "github.com/starknet-io/starknet-p2p-specs/p2p/proto/common" transaction "github.com/starknet-io/starknet-p2p-specs/p2p/proto/transaction" protoreflect "google.golang.org/protobuf/reflect/protoreflect" protoimpl "google.golang.org/protobuf/runtime/protoimpl" - reflect "reflect" - sync "sync" - unsafe "unsafe" ) const ( @@ -1109,31 +1110,33 @@ func file_p2p_proto_consensus_consensus_proto_rawDescGZIP() []byte { return file_p2p_proto_consensus_consensus_proto_rawDescData } -var file_p2p_proto_consensus_consensus_proto_enumTypes = make([]protoimpl.EnumInfo, 1) -var file_p2p_proto_consensus_consensus_proto_msgTypes = make([]protoimpl.MessageInfo, 10) -var file_p2p_proto_consensus_consensus_proto_goTypes = []any{ - (Vote_VoteType)(0), // 0: Vote.VoteType - (*ConsensusTransaction)(nil), // 1: ConsensusTransaction - (*Vote)(nil), // 2: Vote - (*ConsensusStreamId)(nil), // 3: ConsensusStreamId - (*ProposalPart)(nil), // 4: ProposalPart - (*ProposalInit)(nil), // 5: ProposalInit - (*ProposalFin)(nil), // 6: ProposalFin - (*TransactionBatch)(nil), // 7: TransactionBatch - (*StreamMessage)(nil), // 8: StreamMessage - (*ProposalCommitment)(nil), // 9: ProposalCommitment - (*BlockInfo)(nil), // 10: BlockInfo - (*transaction.DeclareV3WithClass)(nil), // 11: DeclareV3WithClass - (*transaction.DeployAccountV3)(nil), // 12: DeployAccountV3 - 
(*transaction.InvokeV3)(nil), // 13: InvokeV3 - (*transaction.L1HandlerV0)(nil), // 14: L1HandlerV0 - (*common.Hash)(nil), // 15: Hash - (*common.Address)(nil), // 16: Address - (*common.Fin)(nil), // 17: Fin - (*common.Felt252)(nil), // 18: Felt252 - (*common.Uint128)(nil), // 19: Uint128 - (common.L1DataAvailabilityMode)(0), // 20: L1DataAvailabilityMode -} +var ( + file_p2p_proto_consensus_consensus_proto_enumTypes = make([]protoimpl.EnumInfo, 1) + file_p2p_proto_consensus_consensus_proto_msgTypes = make([]protoimpl.MessageInfo, 10) + file_p2p_proto_consensus_consensus_proto_goTypes = []any{ + (Vote_VoteType)(0), // 0: Vote.VoteType + (*ConsensusTransaction)(nil), // 1: ConsensusTransaction + (*Vote)(nil), // 2: Vote + (*ConsensusStreamId)(nil), // 3: ConsensusStreamId + (*ProposalPart)(nil), // 4: ProposalPart + (*ProposalInit)(nil), // 5: ProposalInit + (*ProposalFin)(nil), // 6: ProposalFin + (*TransactionBatch)(nil), // 7: TransactionBatch + (*StreamMessage)(nil), // 8: StreamMessage + (*ProposalCommitment)(nil), // 9: ProposalCommitment + (*BlockInfo)(nil), // 10: BlockInfo + (*transaction.DeclareV3WithClass)(nil), // 11: DeclareV3WithClass + (*transaction.DeployAccountV3)(nil), // 12: DeployAccountV3 + (*transaction.InvokeV3)(nil), // 13: InvokeV3 + (*transaction.L1HandlerV0)(nil), // 14: L1HandlerV0 + (*common.Hash)(nil), // 15: Hash + (*common.Address)(nil), // 16: Address + (*common.Fin)(nil), // 17: Fin + (*common.Felt252)(nil), // 18: Felt252 + (*common.Uint128)(nil), // 19: Uint128 + (common.L1DataAvailabilityMode)(0), // 20: L1DataAvailabilityMode + } +) var file_p2p_proto_consensus_consensus_proto_depIdxs = []int32{ 11, // 0: ConsensusTransaction.declare_v3:type_name -> DeclareV3WithClass 12, // 1: ConsensusTransaction.deploy_account_v3:type_name -> DeployAccountV3