diff --git a/common/retry/retry.go b/common/retry/retry.go index 34d2ba609..4233b2158 100644 --- a/common/retry/retry.go +++ b/common/retry/retry.go @@ -6,10 +6,12 @@ import ( ) var ( - RetryFailed = errors.New("All retry attempts failed.") + errorRetryFailed = errors.New("All retry attempts failed.") ) +// Strategy is a way to retry on a specific function. type Strategy interface { + // On performs a retry on a specific function, until it doesn't return any error. On(func() error) error } @@ -17,6 +19,7 @@ type retryer struct { NextDelay func(int) int } +// On implements Strategy.On. func (r *retryer) On(method func() error) error { attempt := 0 for { @@ -26,13 +29,14 @@ func (r *retryer) On(method func() error) error { } delay := r.NextDelay(attempt) if delay < 0 { - return RetryFailed + return errorRetryFailed } <-time.After(time.Duration(delay) * time.Millisecond) attempt++ } } +// Timed returns a retry strategy with fixed interval. func Timed(attempts int, delay int) Strategy { return &retryer{ NextDelay: func(attempt int) int { diff --git a/common/retry/retry_test.go b/common/retry/retry_test.go index ed8249736..1c0be1a9a 100644 --- a/common/retry/retry_test.go +++ b/common/retry/retry_test.go @@ -9,7 +9,7 @@ import ( ) var ( - TestError = errors.New("This is a fake error.") + errorTestOnly = errors.New("This is a fake error.") ) func TestNoRetry(t *testing.T) { @@ -33,7 +33,7 @@ func TestRetryOnce(t *testing.T) { err := Timed(10, 1000).On(func() error { if called == 0 { called++ - return TestError + return errorTestOnly } return nil }) @@ -51,7 +51,7 @@ func TestRetryMultiple(t *testing.T) { err := Timed(10, 1000).On(func() error { if called < 5 { called++ - return TestError + return errorTestOnly } return nil }) @@ -69,12 +69,12 @@ func TestRetryExhausted(t *testing.T) { err := Timed(2, 1000).On(func() error { if called < 5 { called++ - return TestError + return errorTestOnly } return nil }) duration := time.Since(startTime) - assert.Error(err).Equals(RetryFailed) + assert.Error(err).Equals(errorRetryFailed) assert.Int64(int64(duration / time.Millisecond)).AtLeast(1900) }