소스 검색

add error json write tests

Aneurin Barker Snook 1 년 전
부모
커밋
a7a430a188
2개의 변경된 파일87개의 추가작업 그리고 3개의 파일을 삭제
  1. 7 3
      error.go
  2. 80 0
      error_test.go

+ 7 - 3
error.go

@@ -11,7 +11,7 @@ var (
 	ErrMovedPermanently  = NewError(http.StatusMovedPermanently, "")  // 301
 	ErrFound             = NewError(http.StatusFound, "")             // 302
 	ErrTemporaryRedirect = NewError(http.StatusTemporaryRedirect, "") // 307
-	ErrPermamentRedirect = NewError(http.StatusPermanentRedirect, "") // 308
+	ErrPermanentRedirect = NewError(http.StatusPermanentRedirect, "") // 308
 
 	ErrBadRequest       = NewError(http.StatusBadRequest, "")       // 400
 	ErrUnauthorized     = NewError(http.StatusUnauthorized, "")     // 401
@@ -31,7 +31,7 @@ var (
 // Error represents a REST API error.
 // It can be marshaled to JSON with ease and provides a standard format for printing errors and additional data.
 type Error struct {
-	StatusCode int                    `json:"statusCode"`     // HTTP status code (200, 404, 500 etc.)
+	StatusCode int                    `json:"-"`              // HTTP status code (200, 404, 500 etc.)
 	Message    string                 `json:"message"`        // Status message ("OK", "Not found", "Internal server error" etc.)
 	Data       map[string]interface{} `json:"data,omitempty"` // Optional additional data.
 }
@@ -106,7 +106,11 @@ func (e Error) Write(w http.ResponseWriter) {
 
 // WriteJSON writes the HTTP error to an HTTP response as JSON.
 func (e Error) WriteJSON(w http.ResponseWriter) error {
-	return WriteResponseJSON(w, e.StatusCode, e)
+	statusCode := e.StatusCode
+	if statusCode == 0 {
+		statusCode = 200
+	}
+	return WriteResponseJSON(w, statusCode, e)
 }
 
 // NewError creates a new REST API error.

+ 80 - 0
error_test.go

@@ -0,0 +1,80 @@
+package rest
+
+import (
+	"errors"
+	"io"
+	"net/http/httptest"
+	"testing"
+)
+
+func TestErrorWriteJSON(t *testing.T) {
+	type TestCase struct {
+		Input  Error
+		C      int
+		Output string
+		Err    error
+	}
+
+	testCases := []TestCase{
+		// Empty error
+		{Input: Err, C: 200, Output: `{"message":""}`},
+
+		// Standard errors
+		{Input: ErrPermanentRedirect, C: 308, Output: `{"message":"Permanent Redirect"}`},
+		{Input: ErrNotFound, C: 404, Output: `{"message":"Not Found"}`},
+		{Input: ErrInternalServerError, C: 500, Output: `{"message":"Internal Server Error"}`},
+
+		// Error with changed message
+		{Input: ErrBadRequest.WithMessage("Invalid Recipe"), C: 400, Output: `{"message":"Invalid Recipe"}`},
+
+		// Error with data
+		{
+			Input:  ErrGatewayTimeout.WithData(map[string]any{"service": "RecipeDatabase"}),
+			C:      504,
+			Output: `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
+		},
+
+		// Error with value
+		{
+			Input:  ErrGatewayTimeout.WithValue("service", "RecipeDatabase"),
+			C:      504,
+			Output: `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
+		},
+
+		// Error with error
+		{
+			Input:  ErrInternalServerError.WithError(errors.New("recipe is too delicious")),
+			C:      500,
+			Output: `{"message":"Internal Server Error","data":{"error":"recipe is too delicious"}}`,
+		},
+	}
+
+	for i, tc := range testCases {
+		t.Logf("(%d) Testing %v", i, tc.Input)
+
+		rec := httptest.NewRecorder()
+
+		err := tc.Input.WriteJSON(rec)
+		if err != tc.Err {
+			t.Errorf("Expected error %v, got %v", tc.Err, err)
+		}
+		if err != nil {
+			continue
+		}
+
+		res := rec.Result()
+		if res.StatusCode != tc.C {
+			t.Errorf("Expected status code %d, got %d", tc.C, res.StatusCode)
+		}
+
+		body, err := io.ReadAll(res.Body)
+		if err != nil {
+			t.Errorf("Unexpected error reading response body: %v", err)
+			continue
+		}
+
+		if string(body) != tc.Output {
+			t.Errorf("Expected body %q, got %q", tc.Output, string(body))
+		}
+	}
+}