Aneurin Barker Snook 1 anno fa
parent
commit
d94f155714
4 ha cambiato i file con 175 aggiunte e 0 eliminazioni
  1. 87 0
      join.go
  2. 52 0
      join_test.go
  3. 8 0
      page.go
  4. 28 0
      page_test.go

+ 87 - 0
join.go

@@ -0,0 +1,87 @@
+package qs
+
+import (
+	"errors"
+	"net/http"
+	"net/url"
+	"regexp"
+)
+
+// Query error.
+var (
+	ErrInvalidJoin  = errors.New("invalid join")
+	ErrTooManyJoins = errors.New("too many joins")
+)
+
+var joinRegexp = regexp.MustCompile("^[a-z0-9]+$")
+
+// Joins represents joins as used in, most likely, a database query.
+// This is a simplified instruction that should generally be interpreted as "join Y entity onto X entity".
+type Joins map[string]bool
+
+// ReadJoinsOptions configures the behaviour of ReadJoins.
+type ReadJoinsOptions struct {
+	Key      string // Query string key. The default value is "join"
+	MaxJoins int    // If this is > 0, a maximum number of joins is imposed
+}
+
+// ReadJoins parses URL values into a slice of joins.
+// This function returns nil if no joins are found.
+func ReadJoins(values url.Values, opt *ReadJoinsOptions) (Joins, error) {
+	opt = initJoinsOptions(opt)
+
+	if !values.Has(opt.Key) {
+		return nil, nil
+	}
+
+	if opt.MaxJoins > 0 && len(values[opt.Key]) > opt.MaxJoins {
+		return nil, ErrTooManyJoins
+	}
+
+	joins := Joins{}
+	for _, join := range values[opt.Key] {
+		if !joinRegexp.MatchString(join) {
+			return nil, ErrInvalidJoin
+		}
+		joins[join] = true
+	}
+
+	if len(joins) > 0 {
+		return joins, nil
+	}
+	return nil, nil
+}
+
+// ReadRequestJoins parses a request's query string into a Joins map.
+// This function returns nil if no joins are found.
+func ReadRequestJoins(req *http.Request, opt *ReadJoinsOptions) (Joins, error) {
+	return ReadJoins(req.URL.Query(), opt)
+}
+
+// ReadStringJoins parses a query string literal into a Joins map.
+// This function returns nil if no joins are found.
+func ReadStringJoins(qs string, opt *ReadJoinsOptions) (Joins, error) {
+	values, err := url.ParseQuery(qs)
+	if err != nil {
+		return nil, err
+	}
+	return ReadJoins(values, opt)
+}
+
+func initJoinsOptions(opt *ReadJoinsOptions) *ReadJoinsOptions {
+	def := &ReadJoinsOptions{
+		Key: "join",
+	}
+
+	if opt != nil {
+		if len(opt.Key) > 0 {
+			def.Key = opt.Key
+		}
+
+		if opt.MaxJoins > def.MaxJoins {
+			def.MaxJoins = opt.MaxJoins
+		}
+	}
+
+	return def
+}

+ 52 - 0
join_test.go

@@ -0,0 +1,52 @@
+package qs
+
+import "testing"
+
+func TestReadJoins(t *testing.T) {
+	type TestCase struct {
+		Input  string
+		Opt    *ReadJoinsOptions
+		Output Joins
+		Err    error
+	}
+
+	testCases := []TestCase{
+		{Input: ""},
+		{
+			Input:  "join=ingredient",
+			Output: Joins{"ingredient": true},
+		},
+		{
+			Input:  "join=author&join=ingredient",
+			Output: Joins{"author": true, "ingredient": true},
+		},
+	}
+
+	for n, tc := range testCases {
+		t.Logf("(%d) Testing %q with options %+v", n, tc.Input, tc.Opt)
+
+		joins, err := ReadStringJoins(tc.Input, nil)
+
+		if err != tc.Err {
+			t.Errorf("Expected error %v, got %v", tc.Err, err)
+		}
+		if tc.Err != nil {
+			continue
+		}
+
+		if tc.Output == nil && joins != nil {
+			t.Error("Expected nil")
+			continue
+		}
+
+		if len(joins) != len(tc.Output) {
+			t.Errorf("Expected %d joins, got %d", len(tc.Output), len(joins))
+		}
+
+		for name, join := range tc.Output {
+			if join != joins[name] {
+				t.Errorf("Expected %t for join %s, got %t", join, name, joins[name])
+			}
+		}
+	}
+}

+ 8 - 0
page.go

@@ -10,6 +10,7 @@ type Page struct {
 	Pagination *Pagination `json:"pagination"`
 	Filters    Filters     `json:"filters,omitempty"`
 	Sorts      Sorts       `json:"sorts,omitempty"`
+	Joins      Joins       `json:"joins,omitempty"`
 }
 
 // ReadPageOptions configures the behaviour of ReadPage.
@@ -17,6 +18,7 @@ type ReadPageOptions struct {
 	Pagination *ReadPaginationOptions
 	Filter     *ReadFiltersOptions
 	Sort       *ReadSortsOptions
+	Join       *ReadJoinsOptions
 }
 
 // ReadPage parses URL values into a convenient Page struct.
@@ -38,10 +40,16 @@ func ReadPage(values url.Values, opt *ReadPageOptions) (*Page, error) {
 		return nil, err
 	}
 
+	joins, err := ReadJoins(values, opt.Join)
+	if err != nil {
+		return nil, err
+	}
+
 	page := &Page{
 		Pagination: pag,
 		Filters:    filters,
 		Sorts:      sorts,
+		Joins:      joins,
 	}
 	return page, nil
 }

+ 28 - 0
page_test.go

@@ -44,6 +44,19 @@ func TestReadPage(t *testing.T) {
 				},
 			},
 		},
+		{
+			Input: "limit=10&page=2&filter=title eq Spaghetti&sort=serves desc&join=author",
+			Output: &Page{
+				Pagination: &Pagination{Limit: 10, Offset: 10, Page: 2},
+				Filters: []Filter{
+					{Field: "title", Operator: "eq", Value: "Spaghetti"},
+				},
+				Sorts: []Sort{
+					{Field: "serves", Direction: "desc"},
+				},
+				Joins: Joins{"author": true},
+			},
+		},
 	}
 
 	for n, tc := range testCases {
@@ -103,5 +116,20 @@ func TestReadPage(t *testing.T) {
 				t.Errorf("Expected %+v for sort %d, got %+v", sort, i, page.Sorts[i])
 			}
 		}
+
+		// Compare joins (see join_test.go)
+		if tc.Output.Joins == nil && page.Joins != nil {
+			t.Error("Expected nil sorts")
+		}
+
+		if len(page.Joins) != len(tc.Output.Joins) {
+			t.Errorf("Expected %d joins, got %d", len(tc.Output.Joins), len(page.Joins))
+		}
+
+		for name, join := range tc.Output.Joins {
+			if join != page.Joins[name] {
+				t.Errorf("Expected %t for join %s, got %t", join, name, page.Joins[name])
+			}
+		}
 	}
 }