From 212fc8ad2a828d2a76bc5770a298812398fbb7b0 Mon Sep 17 00:00:00 2001 From: Stojan Dimitrovski Date: Tue, 25 Oct 2022 11:46:18 +0200 Subject: [PATCH] x/crypto/bcrypt: add cooperative scheduling and cancellation to bcrypt --- bcrypt/bcrypt.go | 57 ++++++++++++++++++++++++++++++++++++------- bcrypt/bcrypt_test.go | 42 +++++++++++++++++++++++++------ 2 files changed, 83 insertions(+), 16 deletions(-) diff --git a/bcrypt/bcrypt.go b/bcrypt/bcrypt.go index aeb73f81a1..dc855721aa 100644 --- a/bcrypt/bcrypt.go +++ b/bcrypt/bcrypt.go @@ -8,11 +8,13 @@ package bcrypt // import "golang.org/x/crypto/bcrypt" // The code is a port of Provos and Mazières's C implementation. import ( + "context" "crypto/rand" "crypto/subtle" "errors" "fmt" "io" + "runtime" "strconv" "golang.org/x/crypto/blowfish" @@ -24,6 +26,10 @@ const ( DefaultCost int = 10 // the cost that will actually be set if a cost below MinCost is passed into GenerateFromPassword ) +const ( + cooperativeRounds uint64 = 1 << 6 // the number of rounds after which cooperative scheduling is invoked +) + // The error returned from CompareHashAndPassword when a password and hash do // not match. var ErrMismatchedHashAndPassword = errors.New("crypto/bcrypt: hashedPassword is not the hash of the given password") @@ -85,9 +91,21 @@ type hashed struct { // GenerateFromPassword returns the bcrypt hash of the password at the given // cost. If the cost given is less than MinCost, the cost will be set to // DefaultCost, instead. Use CompareHashAndPassword, as defined in this package, -// to compare the returned hashed password with its cleartext version. +// to compare the returned hashed password with its cleartext version. This +// function is not cooperative and can stall other goroutines while it is executing. +// Use GenerateFromPasswordWithContext for the cooperative version. func GenerateFromPassword(password []byte, cost int) ([]byte, error) { - p, err := newFromPassword(password, cost) + return GenerateFromPasswordWithContext(nil, password, cost) +} + +// GenerateFromPassword returns the bcrypt hash of the password at the given +// cost. If the cost given is less than MinCost, the cost will be set to +// DefaultCost, instead. Use CompareHashAndPasswordWithContext, as defined in this package, +// to compare the returned hashed password with its cleartext version. +// Passing in a context lets the hashing algorithm cooperate with other goroutines, +// and can be used to cancel a long-running operation. +func GenerateFromPasswordWithContext(ctx context.Context, password []byte, cost int) ([]byte, error) { + p, err := newFromPassword(ctx, password, cost) if err != nil { return nil, err } @@ -95,14 +113,24 @@ func GenerateFromPassword(password []byte, cost int) ([]byte, error) { } // CompareHashAndPassword compares a bcrypt hashed password with its possible -// plaintext equivalent. Returns nil on success, or an error on failure. +// plaintext equivalent. Returns nil on success, or an error on failure. This +// function is not cooperative and can stall other goroutines while it is executing. +// Use CompareHashAndPasswordWithContext for the cooperative version. func CompareHashAndPassword(hashedPassword, password []byte) error { + return CompareHashAndPasswordWithContext(nil, hashedPassword, password) +} + +// CompareHashAndPasswordWithContext compares a bcrypt hashed password with its possible +// plaintext equivalent. Returns nil on success, or an error on failure. Passing in a context +// lets the hashing algorithm cooperate with other goroutines, and can be used to cancel a +// long-running operation. +func CompareHashAndPasswordWithContext(ctx context.Context, hashedPassword, password []byte) error { p, err := newFromHash(hashedPassword) if err != nil { return err } - otherHash, err := bcrypt(password, p.cost, p.salt) + otherHash, err := bcrypt(ctx, password, p.cost, p.salt) if err != nil { return err } @@ -127,7 +155,7 @@ func Cost(hashedPassword []byte) (int, error) { return p.cost, nil } -func newFromPassword(password []byte, cost int) (*hashed, error) { +func newFromPassword(ctx context.Context, password []byte, cost int) (*hashed, error) { if cost < MinCost { cost = DefaultCost } @@ -148,7 +176,7 @@ func newFromPassword(password []byte, cost int) (*hashed, error) { } p.salt = base64Encode(unencodedSalt) - hash, err := bcrypt(password, p.cost, p.salt) + hash, err := bcrypt(ctx, password, p.cost, p.salt) if err != nil { return nil, err } @@ -184,11 +212,11 @@ func newFromHash(hashedSecret []byte) (*hashed, error) { return p, nil } -func bcrypt(password []byte, cost int, salt []byte) ([]byte, error) { +func bcrypt(ctx context.Context, password []byte, cost int, salt []byte) ([]byte, error) { cipherData := make([]byte, len(magicCipherData)) copy(cipherData, magicCipherData) - c, err := expensiveBlowfishSetup(password, uint32(cost), salt) + c, err := expensiveBlowfishSetup(ctx, password, uint32(cost), salt) if err != nil { return nil, err } @@ -205,7 +233,7 @@ func bcrypt(password []byte, cost int, salt []byte) ([]byte, error) { return hsh, nil } -func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cipher, error) { +func expensiveBlowfishSetup(ctx context.Context, key []byte, cost uint32, salt []byte) (*blowfish.Cipher, error) { csalt, err := base64Decode(salt) if err != nil { return nil, err @@ -226,6 +254,17 @@ func expensiveBlowfishSetup(key []byte, cost uint32, salt []byte) (*blowfish.Cip for i = 0; i < rounds; i++ { blowfish.ExpandKey(ckey, c) blowfish.ExpandKey(csalt, c) + + if ctx != nil && (i+1)%cooperativeRounds == 0 { + // i+1 because we want to invoke after cooperativeRounds have processed, not immediately + select { + case <-ctx.Done(): + return nil, ctx.Err() + + default: + runtime.Gosched() + } + } } return c, nil diff --git a/bcrypt/bcrypt_test.go b/bcrypt/bcrypt_test.go index b7162d8217..05d870cad1 100644 --- a/bcrypt/bcrypt_test.go +++ b/bcrypt/bcrypt_test.go @@ -6,6 +6,8 @@ package bcrypt import ( "bytes" + "context" + "errors" "fmt" "testing" ) @@ -33,7 +35,7 @@ func TestBcryptingIsCorrect(t *testing.T) { salt := []byte("XajjQvNhvvRt5GSeFk1xFe") expectedHash := []byte("$2a$10$XajjQvNhvvRt5GSeFk1xFeyqRrsxkhBkUiQeg0dt.wU1qD4aFDcga") - hash, err := bcrypt(pass, 10, salt) + hash, err := bcrypt(nil, pass, 10, salt) if err != nil { t.Fatalf("bcrypt blew up: %v", err) } @@ -56,7 +58,7 @@ func TestBcryptingIsCorrect(t *testing.T) { func TestVeryShortPasswords(t *testing.T) { key := []byte("k") salt := []byte("XajjQvNhvvRt5GSeFk1xFe") - _, err := bcrypt(key, 10, salt) + _, err := bcrypt(nil, key, 10, salt) if err != nil { t.Errorf("One byte key resulted in error: %s", err) } @@ -67,7 +69,7 @@ func TestTooLongPasswordsWork(t *testing.T) { // One byte over the usual 56 byte limit that blowfish has tooLongPass := []byte("012345678901234567890123456789012345678901234567890123456") tooLongExpected := []byte("$2a$10$XajjQvNhvvRt5GSeFk1xFe5l47dONXg781AmZtd869sO8zfsHuw7C") - hash, err := bcrypt(tooLongPass, 10, salt) + hash, err := bcrypt(nil, tooLongPass, 10, salt) if err != nil { t.Fatalf("bcrypt blew up on long password: %v", err) } @@ -156,13 +158,13 @@ func TestCostValidationInHash(t *testing.T) { pass := []byte("mypassword") for c := 0; c < MinCost; c++ { - p, _ := newFromPassword(pass, c) + p, _ := newFromPassword(nil, pass, c) if p.cost != DefaultCost { t.Errorf("newFromPassword should default costs below %d to %d, but was %d", MinCost, DefaultCost, p.cost) } } - p, _ := newFromPassword(pass, 14) + p, _ := newFromPassword(nil, pass, 14) if p.cost != 14 { t.Errorf("newFromPassword should default cost to 14, but was %d", p.cost) } @@ -172,7 +174,7 @@ func TestCostValidationInHash(t *testing.T) { t.Errorf("newFromHash should maintain the cost at %d, but was %d", p.cost, hp.cost) } - _, err := newFromPassword(pass, 32) + _, err := newFromPassword(nil, pass, 32) if err == nil { t.Fatalf("newFromPassword: should return a cost error") } @@ -182,7 +184,7 @@ func TestCostValidationInHash(t *testing.T) { } func TestCostReturnsWithLeadingZeroes(t *testing.T) { - hp, _ := newFromPassword([]byte("abcdefgh"), 7) + hp, _ := newFromPassword(nil, []byte("abcdefgh"), 7) cost := hp.Hash()[4:7] expected := []byte("07$") @@ -241,3 +243,29 @@ func TestNoSideEffectsFromCompare(t *testing.T) { t.Errorf("got=%q want=%q", got, want) } } + +func TestCancellationLongDuration(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // using DefaultCost means that cooperative scheduling will be used at least once + // as there are many rounds to go through + _, err := GenerateFromPasswordWithContext(ctx, []byte("mylongpassword1234"), DefaultCost) + + if !errors.Is(err, context.Canceled) { + t.Errorf("got=%v want=%v", err, context.Canceled) + } +} + +func TestCancellationShortDuration(t *testing.T) { + ctx, cancel := context.WithCancel(context.Background()) + cancel() + + // using MinCost means that cooperative scheduling won't be used + // as there are a few rounds to go through + _, err := GenerateFromPasswordWithContext(ctx, []byte("mylongpassword1234"), MinCost) + + if err != nil { + t.Errorf("got=%v want=%v", err, nil) + } +}