error_test.go 1.1 KB

12345678910111213141516171819202122232425262728293031323334353637383940414243444546474849
  1. package validate
  2. import (
  3. "errors"
  4. "fmt"
  5. "testing"
  6. )
  7. func TestErrorIs(t *testing.T) {
  8. type TestCase struct {
  9. A error
  10. B error
  11. Want bool
  12. }
  13. testCases := []TestCase{
  14. // Want any validation error
  15. {A: Err, B: Err, Want: true},
  16. {A: ErrDisallowedChars, B: Err, Want: true},
  17. {A: ErrMustBeGreater, B: Err, Want: true},
  18. // Want specific validation error
  19. {A: ErrDisallowedChars, B: ErrDisallowedChars, Want: true},
  20. {A: ErrMustBeGreater, B: ErrMustBeGreater, Want: true},
  21. // Want not specific validation error
  22. {A: Err, B: ErrDisallowedChars},
  23. {A: Err, B: ErrMustBeGreater},
  24. {A: ErrMustBeGreater, B: ErrDisallowedChars},
  25. {A: ErrDisallowedChars, B: ErrMustBeGreater},
  26. // Want not any other error
  27. {A: ErrDisallowedChars, B: errors.New("contains disallowed characters")},
  28. {A: ErrMustBeGreater, B: errors.New("must be greater than %v")},
  29. }
  30. for _, testCase := range testCases {
  31. a, b, want := testCase.A, testCase.B, testCase.Want
  32. t.Run(fmt.Sprintf("%v/%v", a, b), func(t *testing.T) {
  33. got := errors.Is(a, b)
  34. if got != want {
  35. t.Error("got", got)
  36. t.Error("want", want)
  37. }
  38. })
  39. }
  40. }