diff --git a/client/register.go b/client/register.go index 2f8885e..79da6de 100644 --- a/client/register.go +++ b/client/register.go @@ -4,6 +4,7 @@ import ( "encoding/json" "fmt" "path" + "strconv" "time" ) @@ -20,7 +21,7 @@ Register will cache the key in the file system and keep it up to date using the -r removes all existing registered keys. -k or -f will instead replace all registered keys with those specified -k specifies a specific key identifier to register -f specifies a file containing a new line separated list of key identifiers --t specifies a timeout for getting the key from the daemon in seconds +-t specifies a timeout for getting the key from the daemon (e.g. '5s', '500ms') -g gets the key as well For a machine to access a certain key, it needs permissions on that key. @@ -39,11 +40,28 @@ var registerRemove = cmdRegister.Flag.Bool("r", false, "") var registerKey = cmdRegister.Flag.String("k", "", "") var registerKeyFile = cmdRegister.Flag.String("f", "", "") var registerAndGet = cmdRegister.Flag.Bool("g", false, "") -var registerTimeout = cmdRegister.Flag.Int("t", 5, "") +var registerTimeout = cmdRegister.Flag.String("t", "5s", "") const registerRecheckTime = 10 * time.Millisecond +func parseTimeout(val string) (time.Duration, error) { + // For backwards-compatibility, a timeout value that is a simple integer will + // be treated as a number of seconds. This ensures that the historical usage + // of the timeout flag like '-t5' retains the same meaning. + if secs, err := strconv.Atoi(val); err == nil { + return time.Duration(secs) * time.Second, nil + } + + // For all other values, use time.ParseDuration. + return time.ParseDuration(val) +} + func runRegister(cmd *Command, args []string) *ErrorStatus { + timeout, err := parseTimeout(*registerTimeout) + if err != nil { + return &ErrorStatus{fmt.Errorf("Invalid value for timeout flag: %s", err.Error()), false} + } + k := NewKeysFile(path.Join(daemonFolder, daemonToRegister)) if *registerRemove && *registerKey == "" && *registerKeyFile == "" { // Short circuit & handle `knox register -r`, which is expected to remove all keys @@ -66,7 +84,6 @@ func runRegister(cmd *Command, args []string) *ErrorStatus { return &ErrorStatus{fmt.Errorf("You must include a key or key file to register. see 'knox help register'"), false} } // Get the list of keys to add - var err error var ks []string if *registerKey == "" { f := NewKeysFile(*registerKeyFile) @@ -99,13 +116,13 @@ func runRegister(cmd *Command, args []string) *ErrorStatus { // If specified, force retrieval of keys if *registerAndGet { key, err := cli.CacheGetKey(*registerKey) - c := time.After(time.Duration(*registerTimeout) * time.Second) + c := time.After(timeout) for err != nil { select { case <-c: return &ErrorStatus{fmt.Errorf( - "Error getting key from daemon (hit timeout after %d seconds); check knox logs for details (most recent error: %v)", - *registerTimeout, err), false} + "Error getting key from daemon (hit timeout after %s seconds); check knox logs for details (most recent error: %v)", + timeout.String(), err), false} case <-time.After(registerRecheckTime): key, err = cli.CacheGetKey(*registerKey) } diff --git a/client/register_test.go b/client/register_test.go new file mode 100644 index 0000000..a9e15c9 --- /dev/null +++ b/client/register_test.go @@ -0,0 +1,29 @@ +package client + +import ( + "testing" + "time" +) + +func TestParseTimeout(t *testing.T) { + testCases := []struct { + str string + dur time.Duration + }{ + {"5", 5 * time.Second}, + {"5s", 5 * time.Second}, + {"0.5s", 500 * time.Millisecond}, + {"500ms", 500 * time.Millisecond}, + } + + for _, tc := range testCases { + r, err := parseTimeout(tc.str) + if err != nil { + t.Errorf("error parsing value %s: %s", tc.str, err) + continue + } + if r != tc.dur { + t.Errorf("mismatch: %s should parse to %s", tc.str, tc.dur.String()) + } + } +}