Skip to content

Commit

Permalink
Merge pull request #87 from pinterest/cs/timeout
Browse files Browse the repository at this point in the history
Add support for fractional timeouts
  • Loading branch information
csstaub authored Oct 5, 2022
2 parents 3f2d988 + 0287fc8 commit b4363a4
Show file tree
Hide file tree
Showing 2 changed files with 52 additions and 6 deletions.
29 changes: 23 additions & 6 deletions client/register.go
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@ import (
"encoding/json"
"fmt"
"path"
"strconv"
"time"
)

Expand All @@ -20,7 +21,7 @@ Register will cache the key in the file system and keep it up to date using the
-r removes all existing registered keys. -k or -f will instead replace all registered keys with those specified
-k specifies a specific key identifier to register
-f specifies a file containing a new line separated list of key identifiers
-t specifies a timeout for getting the key from the daemon in seconds
-t specifies a timeout for getting the key from the daemon (e.g. '5s', '500ms')
-g gets the key as well
For a machine to access a certain key, it needs permissions on that key.
Expand All @@ -39,11 +40,28 @@ var registerRemove = cmdRegister.Flag.Bool("r", false, "")
var registerKey = cmdRegister.Flag.String("k", "", "")
var registerKeyFile = cmdRegister.Flag.String("f", "", "")
var registerAndGet = cmdRegister.Flag.Bool("g", false, "")
var registerTimeout = cmdRegister.Flag.Int("t", 5, "")
var registerTimeout = cmdRegister.Flag.String("t", "5s", "")

const registerRecheckTime = 10 * time.Millisecond

func parseTimeout(val string) (time.Duration, error) {
// For backwards-compatibility, a timeout value that is a simple integer will
// be treated as a number of seconds. This ensures that the historical usage
// of the timeout flag like '-t5' retains the same meaning.
if secs, err := strconv.Atoi(val); err == nil {
return time.Duration(secs) * time.Second, nil
}

// For all other values, use time.ParseDuration.
return time.ParseDuration(val)
}

func runRegister(cmd *Command, args []string) *ErrorStatus {
timeout, err := parseTimeout(*registerTimeout)
if err != nil {
return &ErrorStatus{fmt.Errorf("Invalid value for timeout flag: %s", err.Error()), false}
}

k := NewKeysFile(path.Join(daemonFolder, daemonToRegister))
if *registerRemove && *registerKey == "" && *registerKeyFile == "" {
// Short circuit & handle `knox register -r`, which is expected to remove all keys
Expand All @@ -66,7 +84,6 @@ func runRegister(cmd *Command, args []string) *ErrorStatus {
return &ErrorStatus{fmt.Errorf("You must include a key or key file to register. see 'knox help register'"), false}
}
// Get the list of keys to add
var err error
var ks []string
if *registerKey == "" {
f := NewKeysFile(*registerKeyFile)
Expand Down Expand Up @@ -99,13 +116,13 @@ func runRegister(cmd *Command, args []string) *ErrorStatus {
// If specified, force retrieval of keys
if *registerAndGet {
key, err := cli.CacheGetKey(*registerKey)
c := time.After(time.Duration(*registerTimeout) * time.Second)
c := time.After(timeout)
for err != nil {
select {
case <-c:
return &ErrorStatus{fmt.Errorf(
"Error getting key from daemon (hit timeout after %d seconds); check knox logs for details (most recent error: %v)",
*registerTimeout, err), false}
"Error getting key from daemon (hit timeout after %s seconds); check knox logs for details (most recent error: %v)",
timeout.String(), err), false}
case <-time.After(registerRecheckTime):
key, err = cli.CacheGetKey(*registerKey)
}
Expand Down
29 changes: 29 additions & 0 deletions client/register_test.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,29 @@
package client

import (
"testing"
"time"
)

func TestParseTimeout(t *testing.T) {
testCases := []struct {
str string
dur time.Duration
}{
{"5", 5 * time.Second},
{"5s", 5 * time.Second},
{"0.5s", 500 * time.Millisecond},
{"500ms", 500 * time.Millisecond},
}

for _, tc := range testCases {
r, err := parseTimeout(tc.str)
if err != nil {
t.Errorf("error parsing value %s: %s", tc.str, err)
continue
}
if r != tc.dur {
t.Errorf("mismatch: %s should parse to %s", tc.str, tc.dur.String())
}
}
}

0 comments on commit b4363a4

Please sign in to comment.