1
0
Эх сурвалжийг харах

add standard error write tests

Aneurin Barker Snook 1 жил өмнө
parent
commit
7655bdf73d
2 өөрчлөгдсөн 112 нэмэгдсэн , 50 устгасан
  1. 8 6
      error.go
  2. 104 44
      error_test.go

+ 8 - 6
error.go

@@ -99,18 +99,20 @@ func (e Error) WithValue(name string, value any) Error {
 
 // Write writes the HTTP error to an HTTP response as plain text.
 // Additional data is omitted.
-func (e Error) Write(w http.ResponseWriter) {
+func (e Error) Write(w http.ResponseWriter) (int, error) {
+	if e.StatusCode == 0 {
+		e.StatusCode = 200
+	}
 	w.WriteHeader(e.StatusCode)
-	w.Write([]byte(e.Message))
+	return w.Write([]byte(e.Message))
 }
 
 // WriteJSON writes the HTTP error to an HTTP response as JSON.
 func (e Error) WriteJSON(w http.ResponseWriter) error {
-	statusCode := e.StatusCode
-	if statusCode == 0 {
-		statusCode = 200
+	if e.StatusCode == 0 {
+		e.StatusCode = 200
 	}
-	return WriteResponseJSON(w, statusCode, e)
+	return WriteResponseJSON(w, e.StatusCode, e)
 }
 
 // NewError creates a new REST API error.

+ 104 - 44
error_test.go

@@ -7,49 +7,109 @@ import (
 	"testing"
 )
 
-func TestErrorWriteJSON(t *testing.T) {
-	type TestCase struct {
-		Input  Error
-		C      int
-		Output string
-		Err    error
-	}
+type ErrorTestCase struct {
+	Input Error
+	Code  int
+	Str   string
+	JSON  string
+	Err   error
+}
+
+var errorTestCases = []ErrorTestCase{
+	// Empty error
+	{
+		Input: Err,
+		Code:  200,
+		Str:   "",
+		JSON:  `{"message":""}`,
+	},
+
+	// Standard errors
+	{
+		Input: ErrPermanentRedirect,
+		Code:  308,
+		Str:   "Permanent Redirect",
+		JSON:  `{"message":"Permanent Redirect"}`,
+	},
+	{
+		Input: ErrNotFound,
+		Code:  404,
+		Str:   "Not Found",
+		JSON:  `{"message":"Not Found"}`,
+	},
+	{
+		Input: ErrInternalServerError,
+		Code:  500,
+		Str:   "Internal Server Error",
+		JSON:  `{"message":"Internal Server Error"}`,
+	},
 
-	testCases := []TestCase{
-		// Empty error
-		{Input: Err, C: 200, Output: `{"message":""}`},
-
-		// Standard errors
-		{Input: ErrPermanentRedirect, C: 308, Output: `{"message":"Permanent Redirect"}`},
-		{Input: ErrNotFound, C: 404, Output: `{"message":"Not Found"}`},
-		{Input: ErrInternalServerError, C: 500, Output: `{"message":"Internal Server Error"}`},
-
-		// Error with changed message
-		{Input: ErrBadRequest.WithMessage("Invalid Recipe"), C: 400, Output: `{"message":"Invalid Recipe"}`},
-
-		// Error with data
-		{
-			Input:  ErrGatewayTimeout.WithData(map[string]any{"service": "RecipeDatabase"}),
-			C:      504,
-			Output: `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
-		},
-
-		// Error with value
-		{
-			Input:  ErrGatewayTimeout.WithValue("service", "RecipeDatabase"),
-			C:      504,
-			Output: `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
-		},
-
-		// Error with error
-		{
-			Input:  ErrInternalServerError.WithError(errors.New("recipe is too delicious")),
-			C:      500,
-			Output: `{"message":"Internal Server Error","data":{"error":"recipe is too delicious"}}`,
-		},
+	// Error with changed message
+	{
+		Input: ErrBadRequest.WithMessage("Invalid Recipe"),
+		Code:  400,
+		Str:   "Invalid Recipe",
+		JSON:  `{"message":"Invalid Recipe"}`,
+	},
+
+	// Error with data
+	{
+		Input: ErrGatewayTimeout.WithData(map[string]any{"service": "RecipeDatabase"}),
+		Code:  504,
+		Str:   "Gateway Timeout",
+		JSON:  `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
+	},
+
+	// Error with value
+	{
+		Input: ErrGatewayTimeout.WithValue("service", "RecipeDatabase"),
+		Code:  504,
+		Str:   "Gateway Timeout",
+		JSON:  `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
+	},
+
+	// Error with error
+	{
+		Input: ErrInternalServerError.WithError(errors.New("recipe is too delicious")),
+		Code:  500,
+		Str:   "Internal Server Error",
+		JSON:  `{"message":"Internal Server Error","data":{"error":"recipe is too delicious"}}`,
+	},
+}
+
+func TestErrorWrite(t *testing.T) {
+	for i, tc := range errorTestCases {
+		t.Logf("(%d) Testing %v", i, tc.Input)
+
+		rec := httptest.NewRecorder()
+
+		_, err := tc.Input.Write(rec)
+		if err != tc.Err {
+			t.Errorf("Expected error %v, got %v", tc.Err, err)
+		}
+		if err != nil {
+			continue
+		}
+
+		res := rec.Result()
+		if res.StatusCode != tc.Code {
+			t.Errorf("Expected status code %d, got %d", tc.Code, res.StatusCode)
+		}
+
+		body, err := io.ReadAll(res.Body)
+		if err != nil {
+			t.Errorf("Unexpected error reading response body: %v", err)
+			continue
+		}
+
+		if string(body) != tc.Str {
+			t.Errorf("Expected body %q, got %q", tc.Str, string(body))
+		}
 	}
+}
 
-	for i, tc := range testCases {
+func TestErrorWriteJSON(t *testing.T) {
+	for i, tc := range errorTestCases {
 		t.Logf("(%d) Testing %v", i, tc.Input)
 
 		rec := httptest.NewRecorder()
@@ -63,8 +123,8 @@ func TestErrorWriteJSON(t *testing.T) {
 		}
 
 		res := rec.Result()
-		if res.StatusCode != tc.C {
-			t.Errorf("Expected status code %d, got %d", tc.C, res.StatusCode)
+		if res.StatusCode != tc.Code {
+			t.Errorf("Expected status code %d, got %d", tc.Code, res.StatusCode)
 		}
 
 		body, err := io.ReadAll(res.Body)
@@ -73,8 +133,8 @@ func TestErrorWriteJSON(t *testing.T) {
 			continue
 		}
 
-		if string(body) != tc.Output {
-			t.Errorf("Expected body %q, got %q", tc.Output, string(body))
+		if string(body) != tc.JSON {
+			t.Errorf("Expected body %q, got %q", tc.JSON, string(body))
 		}
 	}
 }