error_test.go 4.1 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687888990919293949596979899100101102103104105106107108109110111112113114115116117118119120121122123124125126127128129130131132133134135136137138139140141142143144145146147148149150151152153154155156157158159160161162163164165166167168169170171172173174175176177178179180181182183184185186187188
  1. package rest
  2. import (
  3. "errors"
  4. "io"
  5. "net/http/httptest"
  6. "testing"
  7. )
  8. type ErrorTestCase struct {
  9. Input Error
  10. Code int
  11. Str string
  12. JSON string
  13. Err error
  14. }
  15. var errorTestCases = []ErrorTestCase{
  16. // Empty error
  17. {
  18. Input: Err,
  19. Code: 200,
  20. Str: "",
  21. JSON: `{"message":""}`,
  22. },
  23. // Standard errors
  24. {
  25. Input: ErrPermanentRedirect,
  26. Code: 308,
  27. Str: "Permanent Redirect",
  28. JSON: `{"message":"Permanent Redirect"}`,
  29. },
  30. {
  31. Input: ErrNotFound,
  32. Code: 404,
  33. Str: "Not Found",
  34. JSON: `{"message":"Not Found"}`,
  35. },
  36. {
  37. Input: ErrInternalServerError,
  38. Code: 500,
  39. Str: "Internal Server Error",
  40. JSON: `{"message":"Internal Server Error"}`,
  41. },
  42. // Error with changed message
  43. {
  44. Input: ErrBadRequest.WithMessage("Invalid Recipe"),
  45. Code: 400,
  46. Str: "Invalid Recipe",
  47. JSON: `{"message":"Invalid Recipe"}`,
  48. },
  49. // Error with data
  50. {
  51. Input: ErrGatewayTimeout.WithData(map[string]any{"service": "RecipeDatabase"}),
  52. Code: 504,
  53. Str: "Gateway Timeout",
  54. JSON: `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
  55. },
  56. // Error with value
  57. {
  58. Input: ErrGatewayTimeout.WithValue("service", "RecipeDatabase"),
  59. Code: 504,
  60. Str: "Gateway Timeout",
  61. JSON: `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
  62. },
  63. // Error with error
  64. {
  65. Input: ErrInternalServerError.WithError(errors.New("recipe is too delicious")),
  66. Code: 500,
  67. Str: "Internal Server Error",
  68. JSON: `{"message":"Internal Server Error","data":{"error":"recipe is too delicious"}}`,
  69. },
  70. }
  71. func TestErrorIs(t *testing.T) {
  72. type TestCase struct {
  73. Err error
  74. Target error
  75. Is bool
  76. }
  77. testCases := []TestCase{
  78. // Is any REST API error
  79. {Err: Err, Target: Err, Is: true},
  80. {Err: ErrNotFound, Target: Err, Is: true},
  81. {Err: ErrBadGateway, Target: Err, Is: true},
  82. // Is specific REST API error
  83. {Err: ErrNotFound, Target: ErrNotFound, Is: true},
  84. {Err: ErrBadGateway, Target: ErrBadGateway, Is: true},
  85. // Is not specific REST API error
  86. {Err: Err, Target: ErrNotFound},
  87. {Err: Err, Target: ErrBadGateway},
  88. {Err: ErrPermanentRedirect, Target: ErrNotFound},
  89. {Err: ErrGatewayTimeout, Target: ErrBadGateway},
  90. // Is not any other error
  91. {Err: ErrNotFound, Target: errors.New("Not Found")},
  92. {Err: ErrBadGateway, Target: errors.New("Bad Gateway")},
  93. // Any other error is not a REST API Error
  94. {Err: errors.New("Not Found"), Target: Err},
  95. {Err: errors.New("Not Found"), Target: ErrNotFound},
  96. {Err: errors.New("Bad Gateway"), Target: ErrBadGateway},
  97. }
  98. for i, tc := range testCases {
  99. t.Logf("(%d) Testing %v against %v", i, tc.Err, tc.Target)
  100. if errors.Is(tc.Err, tc.Target) {
  101. if !tc.Is {
  102. t.Errorf("%v should not equal %v", tc.Err, tc.Target)
  103. }
  104. } else {
  105. if tc.Is {
  106. t.Errorf("%v should equal %v", tc.Err, tc.Target)
  107. }
  108. }
  109. }
  110. }
  111. func TestErrorWrite(t *testing.T) {
  112. for i, tc := range errorTestCases {
  113. t.Logf("(%d) Testing %v", i, tc.Input)
  114. rec := httptest.NewRecorder()
  115. _, err := tc.Input.Write(rec)
  116. if !errors.Is(err, tc.Err) {
  117. t.Errorf("Expected error %v, got %v", tc.Err, err)
  118. }
  119. if err != nil {
  120. continue
  121. }
  122. res := rec.Result()
  123. if res.StatusCode != tc.Code {
  124. t.Errorf("Expected status code %d, got %d", tc.Code, res.StatusCode)
  125. }
  126. body, err := io.ReadAll(res.Body)
  127. if err != nil {
  128. t.Errorf("Unexpected error reading response body: %v", err)
  129. continue
  130. }
  131. if string(body) != tc.Str {
  132. t.Errorf("Expected body %q, got %q", tc.Str, string(body))
  133. }
  134. }
  135. }
  136. func TestErrorWriteJSON(t *testing.T) {
  137. for i, tc := range errorTestCases {
  138. t.Logf("(%d) Testing %v", i, tc.Input)
  139. rec := httptest.NewRecorder()
  140. err := tc.Input.WriteJSON(rec)
  141. if !errors.Is(err, tc.Err) {
  142. t.Errorf("Expected error %v, got %v", tc.Err, err)
  143. }
  144. if err != nil {
  145. continue
  146. }
  147. res := rec.Result()
  148. if res.StatusCode != tc.Code {
  149. t.Errorf("Expected status code %d, got %d", tc.Code, res.StatusCode)
  150. }
  151. body, err := io.ReadAll(res.Body)
  152. if err != nil {
  153. t.Errorf("Unexpected error reading response body: %v", err)
  154. continue
  155. }
  156. if string(body) != tc.JSON {
  157. t.Errorf("Expected body %q, got %q", tc.JSON, string(body))
  158. }
  159. }
  160. }