package qs

import (
	"errors"
	"net/http"
	"net/url"
	"regexp"
)

// Query error.
var (
	ErrInvalidJoin  = errors.New("invalid join")
	ErrTooManyJoins = errors.New("too many joins")
)

var joinRegexp = regexp.MustCompile("^[a-z0-9]+$")

// Joins represents joins as used in, most likely, a database query.
// This is a simplified instruction that should generally be interpreted as "join Y entity onto X entity".
type Joins map[string]bool

// ReadJoinsOptions configures the behaviour of ReadJoins.
type ReadJoinsOptions struct {
	Key      string // Query string key. The default value is "join"
	MaxJoins int    // If this is > 0, a maximum number of joins is imposed
}

// ReadJoins parses URL values into a slice of joins.
// This function returns nil if no joins are found.
func ReadJoins(values url.Values, opt *ReadJoinsOptions) (Joins, error) {
	opt = initJoinsOptions(opt)

	if !values.Has(opt.Key) {
		return nil, nil
	}

	if opt.MaxJoins > 0 && len(values[opt.Key]) > opt.MaxJoins {
		return nil, ErrTooManyJoins
	}

	joins := Joins{}
	for _, join := range values[opt.Key] {
		if !joinRegexp.MatchString(join) {
			return nil, ErrInvalidJoin
		}
		joins[join] = true
	}

	if len(joins) > 0 {
		return joins, nil
	}
	return nil, nil
}

// ReadRequestJoins parses a request's query string into a Joins map.
// This function returns nil if no joins are found.
func ReadRequestJoins(req *http.Request, opt *ReadJoinsOptions) (Joins, error) {
	return ReadJoins(req.URL.Query(), opt)
}

// ReadStringJoins parses a query string literal into a Joins map.
// This function returns nil if no joins are found.
func ReadStringJoins(qs string, opt *ReadJoinsOptions) (Joins, error) {
	values, err := url.ParseQuery(qs)
	if err != nil {
		return nil, err
	}
	return ReadJoins(values, opt)
}

func initJoinsOptions(opt *ReadJoinsOptions) *ReadJoinsOptions {
	def := &ReadJoinsOptions{
		Key: "join",
	}

	if opt != nil {
		if len(opt.Key) > 0 {
			def.Key = opt.Key
		}

		if opt.MaxJoins > def.MaxJoins {
			def.MaxJoins = opt.MaxJoins
		}
	}

	return def
}