package rest

import (
	"errors"
	"io"
	"net/http/httptest"
	"testing"
)

// ErrorTestCase describes one table-driven case shared by the Error write
// tests in this file: an Error to write and the response it should produce.
type ErrorTestCase struct {
	// Input is the Error whose Write / WriteJSON method is exercised.
	Input Error
	// Code is the HTTP status code expected on the recorded response.
	Code  int
	// Str is the response body expected after calling Write.
	Str   string
	// JSON is the response body expected after calling WriteJSON.
	JSON  string
	// Err is the error the write call is expected to return, matched with
	// errors.Is. Leaving it nil means the write is expected to succeed.
	Err   error
}

// errorTestCases is the table iterated by TestErrorWrite and
// TestErrorWriteJSON. None of the cases sets Err, so every write in the table
// is expected to return a nil error.
var errorTestCases = []ErrorTestCase{
	// Empty error: expects status 200, an empty plain body, and a JSON body
	// with an empty message.
	{
		Input: Err,
		Code:  200,
		Str:   "",
		JSON:  `{"message":""}`,
	},

	// Standard errors
	{
		Input: ErrPermanentRedirect,
		Code:  308,
		Str:   "Permanent Redirect",
		JSON:  `{"message":"Permanent Redirect"}`,
	},
	{
		Input: ErrNotFound,
		Code:  404,
		Str:   "Not Found",
		JSON:  `{"message":"Not Found"}`,
	},
	{
		Input: ErrInternalServerError,
		Code:  500,
		Str:   "Internal Server Error",
		JSON:  `{"message":"Internal Server Error"}`,
	},

	// Error with changed message: the status code stays 400 while both bodies
	// carry the replacement message.
	{
		Input: ErrBadRequest.WithMessage("Invalid Recipe"),
		Code:  400,
		Str:   "Invalid Recipe",
		JSON:  `{"message":"Invalid Recipe"}`,
	},

	// Error with data: the data is expected only in the JSON body, under "data".
	{
		Input: ErrGatewayTimeout.WithData(map[string]any{"service": "RecipeDatabase"}),
		Code:  504,
		Str:   "Gateway Timeout",
		JSON:  `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
	},

	// Error with value: expects the same JSON body as the WithData case above.
	{
		Input: ErrGatewayTimeout.WithValue("service", "RecipeDatabase"),
		Code:  504,
		Str:   "Gateway Timeout",
		JSON:  `{"message":"Gateway Timeout","data":{"service":"RecipeDatabase"}}`,
	},

	// Error with error: the attached error's text is expected under
	// "data"."error" in the JSON body.
	{
		Input: ErrInternalServerError.WithError(errors.New("recipe is too delicious")),
		Code:  500,
		Str:   "Internal Server Error",
		JSON:  `{"message":"Internal Server Error","data":{"error":"recipe is too delicious"}}`,
	},
}

// TestErrorIs checks the errors.Is relationships between REST API errors:
// every REST API error matches the generic Err and itself, but not a different
// REST API error, and REST API errors never match unrelated errors (nor the
// other way around), even when the message text is the same.
func TestErrorIs(t *testing.T) {
	type isCase struct {
		err    error
		target error
		want   bool
	}

	cases := []isCase{
		// Is any REST API error
		{err: Err, target: Err, want: true},
		{err: ErrNotFound, target: Err, want: true},
		{err: ErrBadGateway, target: Err, want: true},

		// Is specific REST API error
		{err: ErrNotFound, target: ErrNotFound, want: true},
		{err: ErrBadGateway, target: ErrBadGateway, want: true},

		// Is not specific REST API error
		{err: Err, target: ErrNotFound, want: false},
		{err: Err, target: ErrBadGateway, want: false},
		{err: ErrPermanentRedirect, target: ErrNotFound, want: false},
		{err: ErrGatewayTimeout, target: ErrBadGateway, want: false},

		// Is not any other error
		{err: ErrNotFound, target: errors.New("Not Found"), want: false},
		{err: ErrBadGateway, target: errors.New("Bad Gateway"), want: false},

		// Any other error is not a REST API Error
		{err: errors.New("Not Found"), target: Err, want: false},
		{err: errors.New("Not Found"), target: ErrNotFound, want: false},
		{err: errors.New("Bad Gateway"), target: ErrBadGateway, want: false},
	}

	for idx, c := range cases {
		t.Logf("(%d) Testing %v against %v", idx, c.err, c.target)

		matched := errors.Is(c.err, c.target)
		switch {
		case matched && !c.want:
			// Matched although the case expects no match.
			t.Errorf("%v should not equal %v", c.err, c.target)
		case !matched && c.want:
			// Did not match although the case expects a match.
			t.Errorf("%v should equal %v", c.err, c.target)
		}
	}
}

// TestErrorWrite runs every entry of errorTestCases through Error.Write
// against an httptest recorder and checks the returned error, the recorded
// status code, and the recorded body (compared with the case's Str field).
func TestErrorWrite(t *testing.T) {
	for i, tc := range errorTestCases {
		t.Logf("(%d) Testing %v", i, tc.Input)

		rec := httptest.NewRecorder()

		_, err := tc.Input.Write(rec)
		if !errors.Is(err, tc.Err) {
			t.Errorf("Expected error %v, got %v", tc.Err, err)
		}
		if err != nil {
			// The write failed, so the recorded response is not inspected.
			continue
		}

		res := rec.Result()
		if res.StatusCode != tc.Code {
			t.Errorf("Expected status code %d, got %d", tc.Code, res.StatusCode)
		}

		body, err := io.ReadAll(res.Body)
		// Close explicitly instead of deferring: a defer inside this loop
		// would only run once the whole test function returns.
		if closeErr := res.Body.Close(); closeErr != nil {
			t.Errorf("Unexpected error closing response body: %v", closeErr)
		}
		if err != nil {
			t.Errorf("Unexpected error reading response body: %v", err)
			continue
		}

		if got := string(body); got != tc.Str {
			t.Errorf("Expected body %q, got %q", tc.Str, got)
		}
	}
}

// TestErrorWriteJSON runs every entry of errorTestCases through
// Error.WriteJSON against an httptest recorder and checks the returned error,
// the recorded status code, and the recorded body (compared with the case's
// JSON field).
func TestErrorWriteJSON(t *testing.T) {
	for i, tc := range errorTestCases {
		t.Logf("(%d) Testing %v", i, tc.Input)

		rec := httptest.NewRecorder()

		err := tc.Input.WriteJSON(rec)
		if !errors.Is(err, tc.Err) {
			t.Errorf("Expected error %v, got %v", tc.Err, err)
		}
		if err != nil {
			// The write failed, so the recorded response is not inspected.
			continue
		}

		res := rec.Result()
		if res.StatusCode != tc.Code {
			t.Errorf("Expected status code %d, got %d", tc.Code, res.StatusCode)
		}

		body, err := io.ReadAll(res.Body)
		// Close explicitly instead of deferring: a defer inside this loop
		// would only run once the whole test function returns.
		if closeErr := res.Body.Close(); closeErr != nil {
			t.Errorf("Unexpected error closing response body: %v", closeErr)
		}
		if err != nil {
			t.Errorf("Unexpected error reading response body: %v", err)
			continue
		}

		if got := string(body); got != tc.JSON {
			t.Errorf("Expected body %q, got %q", tc.JSON, got)
		}
	}
}