From 0042a9ca2e1af82ee5dc82198042b722595aa657 Mon Sep 17 00:00:00 2001 From: Jacob Shufro Date: Wed, 10 Apr 2024 01:39:52 +0000 Subject: [PATCH] Quick stab at making watchtower tasks generic --- rocketpool-daemon/task/task.go | 77 +++++++++++++++++ .../watchtower/generate-rewards-tree.go | 82 ++++++++----------- rocketpool-daemon/watchtower/watchtower.go | 42 ++++++---- 3 files changed, 138 insertions(+), 63 deletions(-) create mode 100644 rocketpool-daemon/task/task.go diff --git a/rocketpool-daemon/task/task.go b/rocketpool-daemon/task/task.go new file mode 100644 index 000000000..739cec838 --- /dev/null +++ b/rocketpool-daemon/task/task.go @@ -0,0 +1,77 @@ +package task + +import ( + "context" + "errors" + "log/slog" + "sync" + + "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/state" +) + +var ( + // ErrAlreadyRunning is returned when a background task is kicked off, but it is already in progress. + ErrAlreadyRunning = errors.New("task is already running") +) + +// BackgroundTaskContext is passed to the task's callback function when the invoker wishes the task +// to be kicked off. +// +// Its fields hold values that may change between invocations of a task. +type BackgroundTaskContext struct { + // A context provided by the invoker of this task. + // May be nil; when non-nil, its cancellation should be respected. 
+ Ctx context.Context + // Whether or not the node is on the oDAO at the time the task was invoked + IsOnOdao bool + // A recent network state so each task need not query it redundantly + State *state.NetworkState +} + +type BackgroundTask interface { + // Starts the task; any asynchronous portion may keep running after Run returns + Run(*BackgroundTaskContext) error + // A function that tasks must call when all async portions are completed + Done() +} + +type LockingBackgroundTask struct { + logger *slog.Logger + description string + run func(*BackgroundTaskContext) error + + lock sync.Mutex + isRunning bool +} + +func NewLockingBackgroundTask(logger *slog.Logger, description string, f func(*BackgroundTaskContext) error) *LockingBackgroundTask { + return &LockingBackgroundTask{ + description: description, + logger: logger, + run: f, + } +} + +func (lbt *LockingBackgroundTask) Run(taskContext *BackgroundTaskContext) error { + lbt.lock.Lock() + defer lbt.lock.Unlock() + + lbt.logger.Info("Starting task", "description", lbt.description) + if lbt.isRunning { + lbt.logger.Info("Task is already running", "description", lbt.description) + return ErrAlreadyRunning + } + + lbt.isRunning = true + err := lbt.run(taskContext) + if err != nil { + lbt.Done() + } + return err +} + +func (lbt *LockingBackgroundTask) Done() { + lbt.lock.Lock() + defer lbt.lock.Unlock() + lbt.isRunning = false +} diff --git a/rocketpool-daemon/watchtower/generate-rewards-tree.go b/rocketpool-daemon/watchtower/generate-rewards-tree.go index c5ca59e86..12f113bf9 100644 --- a/rocketpool-daemon/watchtower/generate-rewards-tree.go +++ b/rocketpool-daemon/watchtower/generate-rewards-tree.go @@ -8,7 +8,6 @@ import ( "path/filepath" "strconv" "strings" - "sync" "time" "github.com/ethereum/go-ethereum/accounts/abi/bind" @@ -24,6 +23,7 @@ import ( rprewards "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/rewards" "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/services" 
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/state" + "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/task" "github.com/rocket-pool/smartnode/v2/shared/config" "github.com/rocket-pool/smartnode/v2/shared/keys" sharedtypes "github.com/rocket-pool/smartnode/v2/shared/types" @@ -31,45 +31,35 @@ import ( // Generate rewards Merkle Tree task type GenerateRewardsTree struct { - ctx context.Context - sp *services.ServiceProvider - logger *slog.Logger - cfg *config.SmartNodeConfig - rp *rocketpool.RocketPool - ec eth.IExecutionClient - bc beacon.IBeaconClient - lock *sync.Mutex - isRunning bool + *task.LockingBackgroundTask + sp *services.ServiceProvider + logger *slog.Logger + cfg *config.SmartNodeConfig + rp *rocketpool.RocketPool + ec eth.IExecutionClient + bc beacon.IBeaconClient } // Create generate rewards Merkle Tree task func NewGenerateRewardsTree(ctx context.Context, sp *services.ServiceProvider, logger *log.Logger) *GenerateRewardsTree { - lock := &sync.Mutex{} - return &GenerateRewardsTree{ - ctx: ctx, - sp: sp, - logger: logger.With(slog.String(keys.RoutineKey, "Generate Rewards Tree")), - cfg: sp.GetConfig(), - rp: sp.GetRocketPool(), - ec: sp.GetEthClient(), - bc: sp.GetBeaconClient(), - lock: lock, - isRunning: false, + out := &GenerateRewardsTree{ + sp: sp, + logger: logger.With(slog.String(keys.RoutineKey, "Generate Rewards Tree")), + cfg: sp.GetConfig(), + rp: sp.GetRocketPool(), + ec: sp.GetEthClient(), + bc: sp.GetBeaconClient(), } + out.LockingBackgroundTask = task.NewLockingBackgroundTask( + logger.With(slog.String(keys.RoutineKey, "Generate Rewards Tree")), + "manual rewards tree generation request check", + out.Run, + ) + return out } // Check for generation requests -func (t *GenerateRewardsTree) Run() error { - t.logger.Info("Starting manual rewards tree generation request check.") - - // Check if rewards generation is already running - t.lock.Lock() - if t.isRunning { - t.logger.Info("Tree generation is already running.") 
- t.lock.Unlock() - return nil - } - t.lock.Unlock() +func (t *GenerateRewardsTree) Run(taskCtx *task.BackgroundTaskContext) error { // Check for requests requestDir := t.cfg.GetWatchtowerFolder() @@ -102,10 +92,7 @@ func (t *GenerateRewardsTree) Run() error { } // Generate the rewards tree - t.lock.Lock() - t.isRunning = true - t.lock.Unlock() - go t.generateRewardsTree(index) + go t.generateRewardsTree(taskCtx.Ctx, index) // Return after the first request, do others at other intervals return nil @@ -115,7 +102,10 @@ func (t *GenerateRewardsTree) Run() error { return nil } -func (t *GenerateRewardsTree) generateRewardsTree(index uint64) { +func (t *GenerateRewardsTree) generateRewardsTree(ctx context.Context, index uint64) { + // This function is the async portion of the task, so make sure to signal completion + defer t.LockingBackgroundTask.Done() + // Begin generation of the tree logger := t.logger.With(slog.Uint64(keys.IntervalKey, index)) logger.Info("Starting generation of Merkle rewards tree.") @@ -150,7 +140,7 @@ func (t *GenerateRewardsTree) generateRewardsTree(index uint64) { }, opts) if err == nil { // Create the state manager with using the primary or fallback (not necessarily archive) EC - stateManager, err = state.NewNetworkStateManager(t.ctx, client, t.cfg, t.rp.Client, t.bc, logger) + stateManager, err = state.NewNetworkStateManager(ctx, client, t.cfg, t.rp.Client, t.bc, logger) if err != nil { t.handleError(fmt.Errorf("error creating new NetworkStateManager with Archive EC: %w", err), logger) return @@ -189,7 +179,7 @@ func (t *GenerateRewardsTree) generateRewardsTree(index uint64) { return } // Create the state manager with the archive EC - stateManager, err = state.NewNetworkStateManager(t.ctx, client, t.cfg, ec, t.bc, logger) + stateManager, err = state.NewNetworkStateManager(ctx, client, t.cfg, ec, t.bc, logger) if err != nil { t.handleError(fmt.Errorf("error creating new NetworkStateManager with Archive EC: %w", err), logger) return @@ -210,18 
+200,18 @@ func (t *GenerateRewardsTree) generateRewardsTree(index uint64) { } // Get the state for the target slot - state, err := stateManager.GetStateForSlot(t.ctx, rewardsEvent.ConsensusBlock.Uint64()) + state, err := stateManager.GetStateForSlot(ctx, rewardsEvent.ConsensusBlock.Uint64()) if err != nil { t.handleError(fmt.Errorf("error getting state for beacon slot %d: %w", rewardsEvent.ConsensusBlock.Uint64(), err), logger) return } // Generate the tree - t.generateRewardsTreeImpl(logger, client, index, rewardsEvent, elBlockHeader, state) + t.generateRewardsTreeImpl(ctx, logger, client, index, rewardsEvent, elBlockHeader, state) } // Implementation for rewards tree generation using a viable EC -func (t *GenerateRewardsTree) generateRewardsTreeImpl(logger *slog.Logger, rp *rocketpool.RocketPool, index uint64, rewardsEvent rewards.RewardsEvent, elBlockHeader *types.Header, state *state.NetworkState) { +func (t *GenerateRewardsTree) generateRewardsTreeImpl(ctx context.Context, logger *slog.Logger, rp *rocketpool.RocketPool, index uint64, rewardsEvent rewards.RewardsEvent, elBlockHeader *types.Header, state *state.NetworkState) { // Generate the rewards file start := time.Now() treegen, err := rprewards.NewTreeGenerator(t.logger, rp, t.cfg, t.bc, index, rewardsEvent.IntervalStartTime, rewardsEvent.IntervalEndTime, rewardsEvent.ConsensusBlock.Uint64(), elBlockHeader, rewardsEvent.IntervalsPassed.Uint64(), state, nil) @@ -229,7 +219,7 @@ func (t *GenerateRewardsTree) generateRewardsTreeImpl(logger *slog.Logger, rp *r t.handleError(fmt.Errorf("Error creating Merkle tree generator: %w", err), logger) return } - rewardsFile, err := treegen.GenerateTree(t.ctx) + rewardsFile, err := treegen.GenerateTree(ctx) if err != nil { t.handleError(fmt.Errorf("%s Error generating Merkle tree: %w", err), logger) return @@ -273,14 +263,8 @@ func (t *GenerateRewardsTree) generateRewardsTreeImpl(logger *slog.Logger, rp *r } t.logger.Info("Merkle tree generation complete!") - 
t.lock.Lock() - t.isRunning = false - t.lock.Unlock() } func (t *GenerateRewardsTree) handleError(err error, logger *slog.Logger) { logger.Error("*** Rewards tree generation failed. ***", log.Err(err)) - t.lock.Lock() - t.isRunning = false - t.lock.Unlock() } diff --git a/rocketpool-daemon/watchtower/watchtower.go b/rocketpool-daemon/watchtower/watchtower.go index de7f6e856..c7c8fae77 100644 --- a/rocketpool-daemon/watchtower/watchtower.go +++ b/rocketpool-daemon/watchtower/watchtower.go @@ -11,6 +11,7 @@ import ( "github.com/rocket-pool/rocketpool-go/v2/rocketpool" "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/services" "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/state" + "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/task" "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/watchtower/collectors" "github.com/rocket-pool/smartnode/v2/shared/config" ) @@ -32,8 +33,10 @@ type TaskManager struct { rp *rocketpool.RocketPool bc beacon.IBeaconClient + // Generic Tasks to run + tasks []task.BackgroundTask + // Tasks - generateRewardsTree *GenerateRewardsTree respondChallenges *RespondChallenges submitRplPrice *SubmitRplPrice submitNetworkBalances *SubmitNetworkBalances @@ -84,7 +87,6 @@ func NewTaskManager( } // Initialize tasks - generateRewardsTree := NewGenerateRewardsTree(ctx, sp, logger) respondChallenges := NewRespondChallenges(sp, logger, stateMgr) submitRplPrice := NewSubmitRplPrice(ctx, sp, logger) submitNetworkBalances := NewSubmitNetworkBalances(ctx, sp, logger) @@ -101,13 +103,15 @@ func NewTaskManager( finalizePdaoProposals := NewFinalizePdaoProposals(sp, logger) return &TaskManager{ - sp: sp, - logger: logger, - ctx: ctx, - cfg: cfg, - rp: rp, - bc: bc, - generateRewardsTree: generateRewardsTree, + sp: sp, + logger: logger, + ctx: ctx, + cfg: cfg, + rp: rp, + bc: bc, + tasks: []task.BackgroundTask{ + NewGenerateRewardsTree(ctx, sp, logger).LockingBackgroundTask, + }, respondChallenges: respondChallenges, 
submitRplPrice: submitRplPrice, submitNetworkBalances: submitNetworkBalances, @@ -146,12 +150,22 @@ func (t *TaskManager) Initialize(stateMgr *state.NetworkStateManager) error { // Run the task loop func (t *TaskManager) Run(isOnOdao bool, state *state.NetworkState) error { - // Run the manual rewards tree generation - if err := t.generateRewardsTree.Run(); err != nil { - t.logger.Error(err.Error()) + taskCtx := &task.BackgroundTaskContext{ + // TODO: having a single global context stemming from + // context.Background is basically the same as passing around nil, + // and we should remove ctx from t and add it to Run() + Ctx: t.ctx, + IsOnOdao: isOnOdao, + State: state, } - if utils.SleepWithCancel(t.ctx, taskCooldown) { - return nil + // Run the generic tasks + for _, taskItem := range t.tasks { + if err := taskItem.Run(taskCtx); err != nil && err != task.ErrAlreadyRunning { + t.logger.Error(err.Error()) + } + if utils.SleepWithCancel(t.ctx, taskCooldown) { + return nil + } } if isOnOdao {