Skip to content

Commit

Permalink
fix: cumulative payment dynamo db unit conversion (#979)
Browse files Browse the repository at this point in the history
  • Loading branch information
hopeyen authored Dec 12, 2024
1 parent 1430d56 commit be47a6c
Show file tree
Hide file tree
Showing 3 changed files with 78 additions and 66 deletions.
11 changes: 7 additions & 4 deletions core/meterer/meterer.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,6 +3,7 @@ package meterer
import (
"context"
"fmt"
"math/big"
"slices"
"time"

Expand Down Expand Up @@ -230,20 +231,22 @@ func (m *Meterer) ValidatePayment(ctx context.Context, header core.PaymentMetada
return fmt.Errorf("failed to get relevant on-demand records: %w", err)
}
// the current request must increment cumulative payment by a magnitude sufficient to cover the blob size
if prevPmt+m.PaymentCharged(numSymbols) > header.CumulativePayment.Uint64() {
if prevPmt.Add(prevPmt, m.PaymentCharged(numSymbols)).Cmp(header.CumulativePayment) > 0 {
return fmt.Errorf("insufficient cumulative payment increment")
}
// the current request must not break the payment magnitude for the next payment if the two requests were delivered out-of-order
if nextPmt != 0 && header.CumulativePayment.Uint64()+m.PaymentCharged(uint(nextPmtnumSymbols)) > nextPmt {
if nextPmt.Cmp(big.NewInt(0)) != 0 && header.CumulativePayment.Add(header.CumulativePayment, m.PaymentCharged(uint(nextPmtnumSymbols))).Cmp(nextPmt) > 0 {
return fmt.Errorf("breaking cumulative payment invariants")
}
// check passed: blob can be safely inserted into the set of payments
return nil
}

// PaymentCharged returns the chargeable price for a given data length
func (m *Meterer) PaymentCharged(numSymbols uint) uint64 {
return uint64(m.SymbolsCharged(numSymbols)) * uint64(m.ChainPaymentState.GetPricePerSymbol())
func (m *Meterer) PaymentCharged(numSymbols uint) *big.Int {
symbolsCharged := big.NewInt(int64(m.SymbolsCharged(numSymbols)))
pricePerSymbol := big.NewInt(int64(m.ChainPaymentState.GetPricePerSymbol()))
return symbolsCharged.Mul(symbolsCharged, pricePerSymbol)
}

// SymbolsCharged returns the number of symbols charged for a given data length
Expand Down
52 changes: 26 additions & 26 deletions core/meterer/meterer_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -186,16 +186,16 @@ func TestMetererReservations(t *testing.T) {
paymentChainState.On("GetActiveReservationByAccount", testifymock.Anything, testifymock.Anything).Return(&core.ActiveReservation{}, fmt.Errorf("reservation not found"))

// test invalid quorom ID
header := createPaymentHeader(1, 0, accountID1)
header := createPaymentHeader(1, big.NewInt(0), accountID1)
err := mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2})
assert.ErrorContains(t, err, "quorum number mismatch")

// overwhelming bin overflow for empty bins
header = createPaymentHeader(reservationPeriod-1, 0, accountID2)
header = createPaymentHeader(reservationPeriod-1, big.NewInt(0), accountID2)
err = mt.MeterRequest(ctx, *header, 10, quoromNumbers)
assert.NoError(t, err)
// overwhelming bin overflow for empty bins
header = createPaymentHeader(reservationPeriod-1, 0, accountID2)
header = createPaymentHeader(reservationPeriod-1, big.NewInt(0), accountID2)
err = mt.MeterRequest(ctx, *header, 1000, quoromNumbers)
assert.ErrorContains(t, err, "overflow usage exceeds bin limit")

Expand All @@ -204,21 +204,21 @@ func TestMetererReservations(t *testing.T) {
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
header = createPaymentHeader(1, 0, crypto.PubkeyToAddress(unregisteredUser.PublicKey))
header = createPaymentHeader(1, big.NewInt(0), crypto.PubkeyToAddress(unregisteredUser.PublicKey))
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2})
assert.ErrorContains(t, err, "failed to get active reservation by account: reservation not found")

// test invalid bin index
header = createPaymentHeader(reservationPeriod, 0, accountID1)
header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID1)
err = mt.MeterRequest(ctx, *header, 2000, quoromNumbers)
assert.ErrorContains(t, err, "invalid bin index for reservation")

// test bin usage metering
symbolLength := uint(20)
requiredLength := uint(21) // 21 should be charged for length of 20 since minNumSymbols is 3
for i := 0; i < 9; i++ {
header = createPaymentHeader(reservationPeriod, 0, accountID2)
header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quoromNumbers)
assert.NoError(t, err)
item, err := dynamoClient.GetItem(ctx, reservationTableName, commondynamodb.Key{
Expand All @@ -232,7 +232,7 @@ func TestMetererReservations(t *testing.T) {

}
// first over flow is allowed
header = createPaymentHeader(reservationPeriod, 0, accountID2)
header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID2)
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 25, quoromNumbers)
assert.NoError(t, err)
Expand All @@ -248,7 +248,7 @@ func TestMetererReservations(t *testing.T) {
assert.Equal(t, strconv.Itoa(int(16)), item["BinUsage"].(*types.AttributeValueMemberN).Value)

// second over flow
header = createPaymentHeader(reservationPeriod, 0, accountID2)
header = createPaymentHeader(reservationPeriod, big.NewInt(0), accountID2)
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 1, quoromNumbers)
assert.ErrorContains(t, err, "bin has already been filled")
Expand All @@ -275,18 +275,18 @@ func TestMetererOnDemand(t *testing.T) {
if err != nil {
t.Fatalf("Failed to generate key: %v", err)
}
header := createPaymentHeader(reservationPeriod, 2, crypto.PubkeyToAddress(unregisteredUser.PublicKey))
header := createPaymentHeader(reservationPeriod, big.NewInt(2), crypto.PubkeyToAddress(unregisteredUser.PublicKey))
assert.NoError(t, err)
err = mt.MeterRequest(ctx, *header, 1000, quorumNumbers)
assert.ErrorContains(t, err, "failed to get on-demand payment by account: payment not found")

// test invalid quorom ID
header = createPaymentHeader(reservationPeriod, 1, accountID1)
header = createPaymentHeader(reservationPeriod, big.NewInt(2), accountID1)
err = mt.MeterRequest(ctx, *header, 1000, []uint8{0, 1, 2})
assert.ErrorContains(t, err, "invalid quorum for On-Demand Request")

// test insufficient cumulative payment
header = createPaymentHeader(reservationPeriod, 1, accountID1)
header = createPaymentHeader(reservationPeriod, big.NewInt(1), accountID1)
err = mt.MeterRequest(ctx, *header, 1000, quorumNumbers)
assert.ErrorContains(t, err, "insufficient cumulative payment increment")
// No rollback after meter request
Expand All @@ -300,7 +300,7 @@ func TestMetererOnDemand(t *testing.T) {
// test duplicated cumulative payments
symbolLength := uint(100)
priceCharged := mt.PaymentCharged(symbolLength)
assert.Equal(t, uint64(102*mt.ChainPaymentState.GetPricePerSymbol()), priceCharged)
assert.Equal(t, big.NewInt(int64(102*mt.ChainPaymentState.GetPricePerSymbol())), priceCharged)
header = createPaymentHeader(reservationPeriod, priceCharged, accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.NoError(t, err)
Expand All @@ -310,24 +310,24 @@ func TestMetererOnDemand(t *testing.T) {

// test valid payments
for i := 1; i < 9; i++ {
header = createPaymentHeader(reservationPeriod, uint64(priceCharged)*uint64(i+1), accountID2)
header = createPaymentHeader(reservationPeriod, new(big.Int).Mul(priceCharged, big.NewInt(int64(i+1))), accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.NoError(t, err)
}

// test cumulative payment on-chain constraint
header = createPaymentHeader(reservationPeriod, 2023, accountID2)
header = createPaymentHeader(reservationPeriod, big.NewInt(2023), accountID2)
err = mt.MeterRequest(ctx, *header, 1, quorumNumbers)
assert.ErrorContains(t, err, "invalid on-demand payment: request claims a cumulative payment greater than the on-chain deposit")

// test insufficient increment in cumulative payment
previousCumulativePayment := uint64(priceCharged) * uint64(9)
previousCumulativePayment := priceCharged.Mul(priceCharged, big.NewInt(9))
symbolLength = uint(2)
priceCharged = mt.PaymentCharged(symbolLength)
header = createPaymentHeader(reservationPeriod, previousCumulativePayment+priceCharged-1, accountID2)
header = createPaymentHeader(reservationPeriod, big.NewInt(0).Add(previousCumulativePayment, big.NewInt(0).Sub(priceCharged, big.NewInt(1))), accountID2)
err = mt.MeterRequest(ctx, *header, symbolLength, quorumNumbers)
assert.ErrorContains(t, err, "invalid on-demand payment: insufficient cumulative payment increment")
previousCumulativePayment = previousCumulativePayment + priceCharged
previousCumulativePayment = big.NewInt(0).Add(previousCumulativePayment, priceCharged)

// test cannot insert cumulative payment in out of order
header = createPaymentHeader(reservationPeriod, mt.PaymentCharged(50), accountID2)
Expand All @@ -342,7 +342,7 @@ func TestMetererOnDemand(t *testing.T) {
assert.NoError(t, err)
assert.Equal(t, numPrevRecords, len(result))
// test failed global rate limit (previously payment recorded: 2, global limit: 1009)
header = createPaymentHeader(reservationPeriod, previousCumulativePayment+mt.PaymentCharged(1010), accountID1)
header = createPaymentHeader(reservationPeriod, big.NewInt(0).Add(previousCumulativePayment, mt.PaymentCharged(1010)), accountID1)
err = mt.MeterRequest(ctx, *header, 1010, quorumNumbers)
assert.ErrorContains(t, err, "failed global rate limiting")
// Correct rollback
Expand All @@ -360,42 +360,42 @@ func TestMeterer_paymentCharged(t *testing.T) {
symbolLength uint
pricePerSymbol uint32
minNumSymbols uint32
expected uint64
expected *big.Int
}{
{
name: "Data length equal to min chargeable size",
symbolLength: 1024,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1024,
expected: big.NewInt(1024),
},
{
name: "Data length less than min chargeable size",
symbolLength: 512,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1024,
expected: big.NewInt(1024),
},
{
name: "Data length greater than min chargeable size",
symbolLength: 2048,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 2048,
expected: big.NewInt(2048),
},
{
name: "Large data length",
symbolLength: 1 << 20, // 1 MB
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 1 << 20,
expected: big.NewInt(1 << 20),
},
{
name: "Price not evenly divisible by min chargeable size",
symbolLength: 1536,
pricePerSymbol: 1,
minNumSymbols: 1024,
expected: 2048,
expected: big.NewInt(2048),
},
}

Expand Down Expand Up @@ -465,10 +465,10 @@ func TestMeterer_symbolsCharged(t *testing.T) {
}
}

func createPaymentHeader(reservationPeriod uint32, cumulativePayment uint64, accountID gethcommon.Address) *core.PaymentMetadata {
func createPaymentHeader(reservationPeriod uint32, cumulativePayment *big.Int, accountID gethcommon.Address) *core.PaymentMetadata {
return &core.PaymentMetadata{
AccountID: accountID.Hex(),
ReservationPeriod: reservationPeriod,
CumulativePayment: big.NewInt(int64(cumulativePayment)),
CumulativePayment: cumulativePayment,
}
}
81 changes: 45 additions & 36 deletions core/meterer/offchain_store.go
Original file line number Diff line number Diff line change
Expand Up @@ -6,7 +6,6 @@ import (
"fmt"
"math/big"
"strconv"
"time"

pb "github.com/Layr-Labs/eigenda/api/grpc/disperser/v2"
commonaws "github.com/Layr-Labs/eigenda/common/aws"
Expand Down Expand Up @@ -65,24 +64,6 @@ func NewOffchainStore(
}, nil
}

type ReservationBin struct {
AccountID string
ReservationPeriod uint32
BinUsage uint32
UpdatedAt time.Time
}

type PaymentTuple struct {
CumulativePayment uint64
DataLength uint32
}

type GlobalBin struct {
ReservationPeriod uint32
BinUsage uint64
UpdatedAt time.Time
}

func (s *OffchainStore) UpdateReservationBin(ctx context.Context, accountID string, reservationPeriod uint64, size uint64) (uint64, error) {
key := map[string]types.AttributeValue{
"AccountID": &types.AttributeValueMemberS{Value: accountID},
Expand Down Expand Up @@ -185,7 +166,7 @@ func (s *OffchainStore) RemoveOnDemandPayment(ctx context.Context, accountID str

// GetRelevantOnDemandRecords gets previous cumulative payment, next cumulative payment, blob size of next payment
// The queries are done sequentially instead of one-go for efficient querying and would not cause race condition errors for honest requests
func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountID string, cumulativePayment *big.Int) (uint64, uint64, uint32, error) {
func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountID string, cumulativePayment *big.Int) (*big.Int, *big.Int, uint32, error) {
// Fetch the largest entry smaller than the given cumulativePayment
queryInput := &dynamodb.QueryInput{
TableName: aws.String(s.onDemandTableName),
Expand All @@ -199,14 +180,23 @@ func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountI
}
smallerResult, err := s.dynamoClient.QueryWithInput(ctx, queryInput)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to query smaller payments for account: %w", err)
return nil, nil, 0, fmt.Errorf("failed to query smaller payments for account: %w", err)
}
var prevPayment uint64
prevPayment := big.NewInt(0)
if len(smallerResult) > 0 {
prevPayment, err = strconv.ParseUint(smallerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10, 64)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse previous payment: %w", err)
cumulativePaymentsAttr, ok := smallerResult[0]["CumulativePayments"]
if !ok {
return nil, nil, 0, fmt.Errorf("CumulativePayments field not found in result")
}
cumulativePaymentsNum, ok := cumulativePaymentsAttr.(*types.AttributeValueMemberN)
if !ok {
return nil, nil, 0, fmt.Errorf("CumulativePayments has invalid type")
}
setPrevPayment, success := prevPayment.SetString(cumulativePaymentsNum.Value, 10)
if !success {
return nil, nil, 0, fmt.Errorf("failed to parse previous payment: %w", err)
}
prevPayment = setPrevPayment
}

// Fetch the smallest entry larger than the given cumulativePayment
Expand All @@ -222,18 +212,36 @@ func (s *OffchainStore) GetRelevantOnDemandRecords(ctx context.Context, accountI
}
largerResult, err := s.dynamoClient.QueryWithInput(ctx, queryInput)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to query the next payment for account: %w", err)
return nil, nil, 0, fmt.Errorf("failed to query the next payment for account: %w", err)
}
var nextPayment uint64
var nextDataLength uint32
nextPayment := big.NewInt(0)
nextDataLength := uint32(0)
if len(largerResult) > 0 {
nextPayment, err = strconv.ParseUint(largerResult[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10, 64)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse next payment: %w", err)
cumulativePaymentsAttr, ok := largerResult[0]["CumulativePayments"]
if !ok {
return nil, nil, 0, fmt.Errorf("CumulativePayments field not found in result")
}
cumulativePaymentsNum, ok := cumulativePaymentsAttr.(*types.AttributeValueMemberN)
if !ok {
return nil, nil, 0, fmt.Errorf("CumulativePayments has invalid type")
}
setNextPayment, success := nextPayment.SetString(cumulativePaymentsNum.Value, 10)
if !success {
return nil, nil, 0, fmt.Errorf("failed to parse previous payment: %w", err)
}
dataLength, err := strconv.ParseUint(largerResult[0]["DataLength"].(*types.AttributeValueMemberN).Value, 10, 32)
nextPayment = setNextPayment

dataLengthAttr, ok := largerResult[0]["DataLength"]
if !ok {
return nil, nil, 0, fmt.Errorf("DataLength field not found in result")
}
dataLengthNum, ok := dataLengthAttr.(*types.AttributeValueMemberN)
if !ok {
return nil, nil, 0, fmt.Errorf("DataLength has invalid type")
}
dataLength, err := strconv.ParseUint(dataLengthNum.Value, 10, 32)
if err != nil {
return 0, 0, 0, fmt.Errorf("failed to parse blob size: %w", err)
return nil, nil, 0, fmt.Errorf("failed to parse data length: %w", err)
}
nextDataLength = uint32(dataLength)
}
Expand Down Expand Up @@ -290,12 +298,13 @@ func (s *OffchainStore) GetLargestCumulativePayment(ctx context.Context, account
return nil, nil
}

payment, err := strconv.ParseUint(payments[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10, 64)
if err != nil {
var payment *big.Int
_, success := payment.SetString(payments[0]["CumulativePayments"].(*types.AttributeValueMemberN).Value, 10)
if !success {
return nil, fmt.Errorf("failed to parse payment: %w", err)
}

return new(big.Int).SetUint64(payment), nil
return payment, nil
}

func parseBinRecord(bin map[string]types.AttributeValue) (*pb.BinRecord, error) {
Expand Down

0 comments on commit be47a6c

Please sign in to comment.