forked from scylladb/gocqlx
-
Notifications
You must be signed in to change notification settings - Fork 0
/
gocqlx.go
61 lines (49 loc) · 1.86 KB
/
gocqlx.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
// Copyright (C) 2017 ScyllaDB
// Use of this source code is governed by a ALv2-style
// license that can be found in the LICENSE file.
package gocqlx
import (
"errors"
"fmt"
"reflect"
"github.com/gocql/gocql"
"github.com/scylladb/go-reflectx"
)
// structOnlyError builds the error reported when a plain struct with exported
// fields was required but t does not qualify. The checks run in a fixed order
// and the first one that matches decides the message:
//
//  1. t is not a struct at all;
//  2. *t implements gocql.Unmarshaler;
//  3. *t implements gocql.UDTUnmarshaler;
//  4. *t implements the package's UDT interface;
//  5. otherwise t is reported as a struct without exported fields.
//
// The returned error is never nil.
func structOnlyError(t reflect.Type) error {
	if kind := t.Kind(); kind != reflect.Struct {
		return fmt.Errorf("expected a struct but got %s", kind)
	}

	// The interface checks are made against the pointer type, so methods
	// declared on either value or pointer receivers are taken into account.
	ptr := reflect.PtrTo(t)
	name := t.Name()

	switch {
	case ptr.Implements(unmarshallerInterface):
		return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.Unmarshaler", name)
	case ptr.Implements(udtUnmarshallerInterface):
		return fmt.Errorf("expected a struct but the provided struct type %s implements gocql.UDTUnmarshaler", name)
	case ptr.Implements(autoUDTInterface):
		return fmt.Errorf("expected a struct but the provided struct type %s implements gocqlx.UDT", name)
	default:
		return fmt.Errorf("expected a struct, but struct %s has no exported fields", name)
	}
}
// Reflect helpers.
//
// Each variable holds the reflect.Type of an interface so that it can be
// passed to reflect.Type.Implements (see structOnlyError). An interface type
// cannot be obtained from a value directly, so each one is taken from a nil
// pointer to the interface and unwrapped with Elem.
var (
	// unmarshallerInterface is the reflect.Type of gocql.Unmarshaler.
	unmarshallerInterface = reflect.TypeOf((*gocql.Unmarshaler)(nil)).Elem()
	// udtUnmarshallerInterface is the reflect.Type of gocql.UDTUnmarshaler.
	udtUnmarshallerInterface = reflect.TypeOf((*gocql.UDTUnmarshaler)(nil)).Elem()
	// autoUDTInterface is the reflect.Type of this package's UDT interface.
	autoUDTInterface = reflect.TypeOf((*UDT)(nil)).Elem()
)
// baseType passes t through reflectx.Deref and verifies that the resulting
// type has the expected kind.
//
// On a match the dereferenced type is returned. Otherwise the type result is
// nil and the error names both the expected and the actual kind.
func baseType(t reflect.Type, expected reflect.Kind) (reflect.Type, error) {
	base := reflectx.Deref(t)
	if got := base.Kind(); got == expected {
		return base, nil
	} else {
		return nil, fmt.Errorf("expected %s but got %s", expected, got)
	}
}
// missingFields scans transversals for the first entry that is empty (nil or
// zero length).
//
// If one is found, its index is returned together with a "missing field"
// error. If every entry is non-empty, or transversals itself is empty, the
// result is 0 and a nil error, so callers must check err before using field.
func missingFields(transversals [][]int) (field int, err error) {
	for idx := 0; idx < len(transversals); idx++ {
		if len(transversals[idx]) != 0 {
			continue
		}
		return idx, errors.New("missing field")
	}
	return 0, nil
}