join.go 2.0 KB

123456789101112131415161718192021222324252627282930313233343536373839404142434445464748495051525354555657585960616263646566676869707172737475767778798081828384858687
  1. package qs
  2. import (
  3. "errors"
  4. "net/http"
  5. "net/url"
  6. "regexp"
  7. )
  // ErrInvalidJoin is returned when a join value read from the query string
// contains anything other than lowercase ASCII letters and digits, or is empty.
var ErrInvalidJoin = errors.New("invalid join")

// ErrTooManyJoins is returned when ReadJoinsOptions.MaxJoins is positive and
// the query string carries more join values than that limit allows.
var ErrTooManyJoins = errors.New("too many joins")
  // joinRegexp accepts a non-empty run of lowercase ASCII letters and digits;
// any other character (uppercase, punctuation, whitespace) is rejected.
var joinRegexp = regexp.MustCompile(`^[a-z0-9]+$`)
  // Joins is a set of requested join names, keyed by name; every key that is
// present maps to true. Each entry is a simplified instruction, most likely
// destined for a database query, to be read as "join entity Y onto entity X".
type Joins map[string]bool
  // ReadJoinsOptions tunes how ReadJoins extracts joins. A nil pointer or the
// zero value selects the defaults.
type ReadJoinsOptions struct {
	// Key is the query string parameter to read joins from.
	// An empty Key falls back to "join".
	Key string
	// MaxJoins, when greater than zero, caps how many values may be supplied
	// under Key. Zero or a negative value disables the cap.
	MaxJoins int
}
  22. // ReadJoins parses URL values into a slice of joins.
  23. // This function returns nil if no joins are found.
  24. func ReadJoins(values url.Values, opt *ReadJoinsOptions) (Joins, error) {
  25. opt = initJoinsOptions(opt)
  26. if !values.Has(opt.Key) {
  27. return nil, nil
  28. }
  29. if opt.MaxJoins > 0 && len(values[opt.Key]) > opt.MaxJoins {
  30. return nil, ErrTooManyJoins
  31. }
  32. joins := Joins{}
  33. for _, join := range values[opt.Key] {
  34. if !joinRegexp.MatchString(join) {
  35. return nil, ErrInvalidJoin
  36. }
  37. joins[join] = true
  38. }
  39. if len(joins) > 0 {
  40. return joins, nil
  41. }
  42. return nil, nil
  43. }
  44. // ReadRequestJoins parses a request's query string into a Joins map.
  45. // This function returns nil if no joins are found.
  46. func ReadRequestJoins(req *http.Request, opt *ReadJoinsOptions) (Joins, error) {
  47. return ReadJoins(req.URL.Query(), opt)
  48. }
  49. // ReadStringJoins parses a query string literal into a Joins map.
  50. // This function returns nil if no joins are found.
  51. func ReadStringJoins(qs string, opt *ReadJoinsOptions) (Joins, error) {
  52. values, err := url.ParseQuery(qs)
  53. if err != nil {
  54. return nil, err
  55. }
  56. return ReadJoins(values, opt)
  57. }
  58. func initJoinsOptions(opt *ReadJoinsOptions) *ReadJoinsOptions {
  59. def := &ReadJoinsOptions{
  60. Key: "join",
  61. }
  62. if opt != nil {
  63. if len(opt.Key) > 0 {
  64. def.Key = opt.Key
  65. }
  66. if opt.MaxJoins > def.MaxJoins {
  67. def.MaxJoins = opt.MaxJoins
  68. }
  69. }
  70. return def
  71. }