Skip to content

Commit

Permalink
Quick stab at making watchtower tasks generic
Browse files Browse the repository at this point in the history
  • Loading branch information
jshufro committed Apr 10, 2024
1 parent 445465f commit 0042a9c
Show file tree
Hide file tree
Showing 3 changed files with 138 additions and 63 deletions.
77 changes: 77 additions & 0 deletions rocketpool-daemon/task/task.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,77 @@
package task

import (
"context"
"errors"
"log/slog"
"sync"

"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/state"
)

var (
	// ErrAlreadyRunning is returned by Run when a background task is kicked off
	// while a previous invocation is still in progress (Done has not been called yet).
	ErrAlreadyRunning = errors.New("task is already running")
)

// BackgroundTaskContext is passed to the task's Run function when the invoker
// wishes the task to be kicked off.
//
// Its fields are things that are variable and may change between invocations of a task.
type BackgroundTaskContext struct {
	// Ctx is a context provided by the invoker of this task.
	// May be nil, and cancellations should be respected.
	Ctx context.Context
	// IsOnOdao reports whether the node was on the oDAO at the time the task was invoked.
	IsOnOdao bool
	// State is a recent network state so each task need not query it redundantly.
	State *state.NetworkState
}

// BackgroundTask is a task that can be started on demand and that performs
// some or all of its work asynchronously.
type BackgroundTask interface {
	// Run starts the task. It may return ErrAlreadyRunning if a previous
	// invocation has not yet finished.
	Run(*BackgroundTaskContext) error
	// Done must be called by the task when all async portions are completed,
	// allowing a subsequent Run to start the task again.
	Done()
}

// LockingBackgroundTask wraps a task function with a mutex-guarded "running"
// flag so that only one invocation can be in flight at a time. While a prior
// invocation has not yet called Done, Run returns ErrAlreadyRunning.
type LockingBackgroundTask struct {
	// logger receives task lifecycle messages
	logger *slog.Logger
	// description identifies the task in log output
	description string
	// run is the function invoked by Run
	run func(*BackgroundTaskContext) error

	// lock guards isRunning
	lock sync.Mutex
	// isRunning is true from a successful Run until Done is called
	isRunning bool
}

// NewLockingBackgroundTask creates a LockingBackgroundTask that executes f
// when Run is called, logging lifecycle events under the given description.
// The zero "not running" state means the first Run will start the task.
func NewLockingBackgroundTask(logger *slog.Logger, description string, f func(*BackgroundTaskContext) error) *LockingBackgroundTask {
	out := &LockingBackgroundTask{
		logger:      logger,
		description: description,
		run:         f,
	}
	return out
}

// Run attempts to start the task. If a prior invocation is still in progress,
// it logs and returns ErrAlreadyRunning without invoking the task function.
//
// On success the task is considered running until Done is called (by the async
// portion of the task). If the task function returns an error synchronously,
// the running flag is cleared immediately so a later Run can retry.
func (lbt *LockingBackgroundTask) Run(taskContext *BackgroundTaskContext) error {
	lbt.lock.Lock()
	defer lbt.lock.Unlock()

	if lbt.isRunning {
		lbt.logger.Info("Task is already running", "description", lbt.description)
		return ErrAlreadyRunning
	}
	// Only log "Starting" once we know we will actually start the task.
	lbt.logger.Info("Starting task", "description", lbt.description)

	lbt.isRunning = true
	err := lbt.run(taskContext)
	if err != nil {
		// Clear the flag directly rather than calling Done(): the lock is
		// still held here (released by the defer above), and sync.Mutex is
		// not reentrant, so calling Done() — which locks the same mutex —
		// would deadlock on every synchronous task failure.
		lbt.isRunning = false
	}
	return err
}

// Done marks the task as no longer running, allowing a subsequent Run to
// start it again. Tasks must call this when all async portions are completed.
func (lbt *LockingBackgroundTask) Done() {
	lbt.lock.Lock()
	lbt.isRunning = false
	lbt.lock.Unlock()
}
82 changes: 33 additions & 49 deletions rocketpool-daemon/watchtower/generate-rewards-tree.go
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,6 @@ import (
"path/filepath"
"strconv"
"strings"
"sync"
"time"

"github.com/ethereum/go-ethereum/accounts/abi/bind"
Expand All @@ -24,52 +23,43 @@ import (
rprewards "github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/rewards"
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/services"
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/state"
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/task"
"github.com/rocket-pool/smartnode/v2/shared/config"
"github.com/rocket-pool/smartnode/v2/shared/keys"
sharedtypes "github.com/rocket-pool/smartnode/v2/shared/types"
)

// Generate rewards Merkle Tree task
type GenerateRewardsTree struct {
ctx context.Context
sp *services.ServiceProvider
logger *slog.Logger
cfg *config.SmartNodeConfig
rp *rocketpool.RocketPool
ec eth.IExecutionClient
bc beacon.IBeaconClient
lock *sync.Mutex
isRunning bool
*task.LockingBackgroundTask
sp *services.ServiceProvider
logger *slog.Logger
cfg *config.SmartNodeConfig
rp *rocketpool.RocketPool
ec eth.IExecutionClient
bc beacon.IBeaconClient
}

// Create generate rewards Merkle Tree task
func NewGenerateRewardsTree(ctx context.Context, sp *services.ServiceProvider, logger *log.Logger) *GenerateRewardsTree {
lock := &sync.Mutex{}
return &GenerateRewardsTree{
ctx: ctx,
sp: sp,
logger: logger.With(slog.String(keys.RoutineKey, "Generate Rewards Tree")),
cfg: sp.GetConfig(),
rp: sp.GetRocketPool(),
ec: sp.GetEthClient(),
bc: sp.GetBeaconClient(),
lock: lock,
isRunning: false,
out := &GenerateRewardsTree{
sp: sp,
logger: logger.With(slog.String(keys.RoutineKey, "Generate Rewards Tree")),
cfg: sp.GetConfig(),
rp: sp.GetRocketPool(),
ec: sp.GetEthClient(),
bc: sp.GetBeaconClient(),
}
out.LockingBackgroundTask = task.NewLockingBackgroundTask(
logger.With(slog.String(keys.RoutineKey, "Generate Rewards Tree")),
"manual rewards tree generation request check",
out.Run,
)
return out
}

// Check for generation requests
func (t *GenerateRewardsTree) Run() error {
t.logger.Info("Starting manual rewards tree generation request check.")

// Check if rewards generation is already running
t.lock.Lock()
if t.isRunning {
t.logger.Info("Tree generation is already running.")
t.lock.Unlock()
return nil
}
t.lock.Unlock()
func (t *GenerateRewardsTree) Run(taskCtx *task.BackgroundTaskContext) error {

// Check for requests
requestDir := t.cfg.GetWatchtowerFolder()
Expand Down Expand Up @@ -102,10 +92,7 @@ func (t *GenerateRewardsTree) Run() error {
}

// Generate the rewards tree
t.lock.Lock()
t.isRunning = true
t.lock.Unlock()
go t.generateRewardsTree(index)
go t.generateRewardsTree(taskCtx.Ctx, index)

// Return after the first request, do others at other intervals
return nil
Expand All @@ -115,7 +102,10 @@ func (t *GenerateRewardsTree) Run() error {
return nil
}

func (t *GenerateRewardsTree) generateRewardsTree(index uint64) {
func (t *GenerateRewardsTree) generateRewardsTree(ctx context.Context, index uint64) {
// This function is the async portion of the task, so make sure to signal completion
defer t.LockingBackgroundTask.Done()

// Begin generation of the tree
logger := t.logger.With(slog.Uint64(keys.IntervalKey, index))
logger.Info("Starting generation of Merkle rewards tree.")
Expand Down Expand Up @@ -150,7 +140,7 @@ func (t *GenerateRewardsTree) generateRewardsTree(index uint64) {
}, opts)
if err == nil {
// Create the state manager with using the primary or fallback (not necessarily archive) EC
stateManager, err = state.NewNetworkStateManager(t.ctx, client, t.cfg, t.rp.Client, t.bc, logger)
stateManager, err = state.NewNetworkStateManager(ctx, client, t.cfg, t.rp.Client, t.bc, logger)
if err != nil {
t.handleError(fmt.Errorf("error creating new NetworkStateManager with Archive EC: %w", err), logger)
return
Expand Down Expand Up @@ -189,7 +179,7 @@ func (t *GenerateRewardsTree) generateRewardsTree(index uint64) {
return
}
// Create the state manager with the archive EC
stateManager, err = state.NewNetworkStateManager(t.ctx, client, t.cfg, ec, t.bc, logger)
stateManager, err = state.NewNetworkStateManager(ctx, client, t.cfg, ec, t.bc, logger)
if err != nil {
t.handleError(fmt.Errorf("error creating new NetworkStateManager with Archive EC: %w", err), logger)
return
Expand All @@ -210,26 +200,26 @@ func (t *GenerateRewardsTree) generateRewardsTree(index uint64) {
}

// Get the state for the target slot
state, err := stateManager.GetStateForSlot(t.ctx, rewardsEvent.ConsensusBlock.Uint64())
state, err := stateManager.GetStateForSlot(ctx, rewardsEvent.ConsensusBlock.Uint64())
if err != nil {
t.handleError(fmt.Errorf("error getting state for beacon slot %d: %w", rewardsEvent.ConsensusBlock.Uint64(), err), logger)
return
}

// Generate the tree
t.generateRewardsTreeImpl(logger, client, index, rewardsEvent, elBlockHeader, state)
t.generateRewardsTreeImpl(ctx, logger, client, index, rewardsEvent, elBlockHeader, state)
}

// Implementation for rewards tree generation using a viable EC
func (t *GenerateRewardsTree) generateRewardsTreeImpl(logger *slog.Logger, rp *rocketpool.RocketPool, index uint64, rewardsEvent rewards.RewardsEvent, elBlockHeader *types.Header, state *state.NetworkState) {
func (t *GenerateRewardsTree) generateRewardsTreeImpl(ctx context.Context, logger *slog.Logger, rp *rocketpool.RocketPool, index uint64, rewardsEvent rewards.RewardsEvent, elBlockHeader *types.Header, state *state.NetworkState) {
// Generate the rewards file
start := time.Now()
treegen, err := rprewards.NewTreeGenerator(t.logger, rp, t.cfg, t.bc, index, rewardsEvent.IntervalStartTime, rewardsEvent.IntervalEndTime, rewardsEvent.ConsensusBlock.Uint64(), elBlockHeader, rewardsEvent.IntervalsPassed.Uint64(), state, nil)
if err != nil {
t.handleError(fmt.Errorf("Error creating Merkle tree generator: %w", err), logger)
return
}
rewardsFile, err := treegen.GenerateTree(t.ctx)
rewardsFile, err := treegen.GenerateTree(ctx)
if err != nil {
t.handleError(fmt.Errorf("%s Error generating Merkle tree: %w", err), logger)
return
Expand Down Expand Up @@ -273,14 +263,8 @@ func (t *GenerateRewardsTree) generateRewardsTreeImpl(logger *slog.Logger, rp *r
}

t.logger.Info("Merkle tree generation complete!")
t.lock.Lock()
t.isRunning = false
t.lock.Unlock()
}

func (t *GenerateRewardsTree) handleError(err error, logger *slog.Logger) {
logger.Error("*** Rewards tree generation failed. ***", log.Err(err))
t.lock.Lock()
t.isRunning = false
t.lock.Unlock()
}
42 changes: 28 additions & 14 deletions rocketpool-daemon/watchtower/watchtower.go
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,7 @@ import (
"github.com/rocket-pool/rocketpool-go/v2/rocketpool"
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/services"
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/common/state"
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/task"
"github.com/rocket-pool/smartnode/v2/rocketpool-daemon/watchtower/collectors"
"github.com/rocket-pool/smartnode/v2/shared/config"
)
Expand All @@ -32,8 +33,10 @@ type TaskManager struct {
rp *rocketpool.RocketPool
bc beacon.IBeaconClient

// Generic Tasks to run
tasks []task.BackgroundTask

// Tasks
generateRewardsTree *GenerateRewardsTree
respondChallenges *RespondChallenges
submitRplPrice *SubmitRplPrice
submitNetworkBalances *SubmitNetworkBalances
Expand Down Expand Up @@ -84,7 +87,6 @@ func NewTaskManager(
}

// Initialize tasks
generateRewardsTree := NewGenerateRewardsTree(ctx, sp, logger)
respondChallenges := NewRespondChallenges(sp, logger, stateMgr)
submitRplPrice := NewSubmitRplPrice(ctx, sp, logger)
submitNetworkBalances := NewSubmitNetworkBalances(ctx, sp, logger)
Expand All @@ -101,13 +103,15 @@ func NewTaskManager(
finalizePdaoProposals := NewFinalizePdaoProposals(sp, logger)

return &TaskManager{
sp: sp,
logger: logger,
ctx: ctx,
cfg: cfg,
rp: rp,
bc: bc,
generateRewardsTree: generateRewardsTree,
sp: sp,
logger: logger,
ctx: ctx,
cfg: cfg,
rp: rp,
bc: bc,
tasks: []task.BackgroundTask{
NewGenerateRewardsTree(ctx, sp, logger).LockingBackgroundTask,
},
respondChallenges: respondChallenges,
submitRplPrice: submitRplPrice,
submitNetworkBalances: submitNetworkBalances,
Expand Down Expand Up @@ -146,12 +150,22 @@ func (t *TaskManager) Initialize(stateMgr *state.NetworkStateManager) error {

// Run the task loop
func (t *TaskManager) Run(isOnOdao bool, state *state.NetworkState) error {
// Run the manual rewards tree generation
if err := t.generateRewardsTree.Run(); err != nil {
t.logger.Error(err.Error())
taskCtx := &task.BackgroundTaskContext{
// TODO: having a single global context stemming from
// context.Background is basically the same as passing around nil,
// and we should remove ctx from t and add it to Run()
Ctx: t.ctx,
IsOnOdao: isOnOdao,
State: state,
}
if utils.SleepWithCancel(t.ctx, taskCooldown) {
return nil
// Run the generic tasks
for _, taskItem := range t.tasks {
if err := taskItem.Run(taskCtx); err != nil && err != task.ErrAlreadyRunning {
t.logger.Error(err.Error())
}
if utils.SleepWithCancel(t.ctx, taskCooldown) {
return nil
}
}

if isOnOdao {
Expand Down

0 comments on commit 0042a9c

Please sign in to comment.