-
Notifications
You must be signed in to change notification settings - Fork 62
/
upsert.go
210 lines (176 loc) · 5.02 KB
/
upsert.go
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
package dat
import "reflect"
// UpsertBuilder contains the clauses for an upsert statement: an UPDATE of the
// rows matched by the WHERE clause, followed by an INSERT that only takes
// effect when the UPDATE matched no rows (see ToSQL). Create one with
// NewUpsertBuilder and configure it through the chaining methods.
type UpsertBuilder struct {
// NOTE(review): embedded Execer presumably supplies the execution methods;
// it is not set by NewUpsertBuilder — confirm where it is assigned.
Execer
// isInterpolated is initialized from EnableInterpolation by NewUpsertBuilder.
isInterpolated bool
// table is the target table for both the UPDATE and the INSERT.
table string
// cols lists the columns to write. "*" means every column of record. When
// isBlacklist is true it instead lists columns of record to exclude.
cols []string
// isBlacklist is set by Blacklist; it is only valid together with record.
isBlacklist bool
// vals are positional values for cols; ignored when record is set, because
// ToSQL derives the values from record.
vals []interface{}
// record is the value whose fields supply column values (and, with "*" or
// a blacklist, the column list).
record interface{}
// returnings lists the RETURNING columns; ToSQL falls back to the written
// columns when it is empty.
returnings []string
// whereFragments select the rows to update; ToSQL requires at least one.
whereFragments []*whereFragment
}
// NewUpsertBuilder creates a new UpsertBuilder for the given table.
//
// It logs an error and returns nil when table is empty. Interpolation is
// initialized from the package-level EnableInterpolation setting.
func NewUpsertBuilder(table string) *UpsertBuilder {
	if table == "" {
		// The message previously named the wrong builder ("Insect"), which
		// sent users looking in the wrong place.
		logger.Error("Upsert requires a table name.")
		return nil
	}
	return &UpsertBuilder{table: table, isInterpolated: EnableInterpolation}
}
// Columns sets the columns written by the upsert. It is shorthand for
// Whitelist and replaces any columns set by an earlier call.
func (b *UpsertBuilder) Columns(columns ...string) *UpsertBuilder {
	chained := b.Whitelist(columns...)
	return chained
}
// Blacklist names columns to leave out of the upsert; every other column of
// the value passed to Record is written. ToSQL panics if Blacklist is used
// without Record.
func (b *UpsertBuilder) Blacklist(columns ...string) *UpsertBuilder {
	b.cols, b.isBlacklist = columns, true
	return b
}
// Whitelist defines the columns to be written by the upsert. To specify all
// columns of a record use "*" (only valid together with Record).
//
// Whitelist replaces any columns set earlier and cancels a previous
// Blacklist call, so the most recent column selection always wins.
func (b *UpsertBuilder) Whitelist(columns ...string) *UpsertBuilder {
	// Without this reset, Blacklist(...).Whitelist(...) would keep treating
	// the whitelisted columns as exclusions.
	b.isBlacklist = false
	b.cols = columns
	return b
}
// Values sets the values to write, matched positionally to the columns.
// Each call replaces the previous values; they are ignored when Record is
// set, since ToSQL then derives values from the record.
func (b *UpsertBuilder) Values(vals ...interface{}) *UpsertBuilder {
	b.vals = vals
	return b
}
// Record sets the value whose fields supply the column values (typically a
// pointer to a struct; ToSQL dereferences it with reflect.Indirect). When
// set, it takes precedence over Values, and it is required for "*" and
// Blacklist.
func (b *UpsertBuilder) Record(record interface{}) *UpsertBuilder {
	b.record = record
	return b
}
// Returning sets the columns of the RETURNING clause used by both the UPDATE
// and the INSERT parts. If no columns are given, ToSQL returns the upserted
// columns instead.
func (b *UpsertBuilder) Returning(columns ...string) *UpsertBuilder {
	b.returnings = columns
	return b
}
// ToSQL serializes the UpsertBuilder to a SQL string. It returns the string
// with placeholders and a slice of query arguments.
//
// The statement updates the rows matched by Where and inserts a new row only
// when the update matched nothing:
//
//	WITH upd AS ( UPDATE ... SET c1 = $1, ..., cn = $n WHERE ... RETURNING ... ),
//	ins AS ( INSERT INTO table(c1,...,cn) SELECT $1,...,$n
//	         WHERE NOT EXISTS (SELECT 1 FROM upd) RETURNING ... )
//	SELECT * FROM ins UNION ALL SELECT * FROM upd
//
// Columns come from Whitelist/Columns, or from Record when "*" or Blacklist
// is used. Values come from Values, or from Record when it is set. RETURNING
// defaults to the upserted columns.
//
// ToSQL panics when the builder is incomplete or inconsistent: no table, no
// columns, no values or record, "*" or Blacklist without Record, no WHERE
// clause, or a number of values that differs from the number of columns.
//
// ToSQL does not modify the builder, so calling it repeatedly yields the
// same statement.
//
// NOTE(review): this CTE-based upsert is not atomic under concurrent writers
// (two sessions can both miss in upd and both insert); PostgreSQL 9.5+ offers
// INSERT ... ON CONFLICT for that case.
//
// TODO: a multi-row form (WITH new_values (...) AS (VALUES ...), UPDATE ...
// FROM new_values, INSERT ... WHERE NOT EXISTS) was sketched as an end goal.
func (b *UpsertBuilder) ToSQL() (string, []interface{}) {
	if len(b.table) == 0 {
		panic("no table specified")
	}
	if len(b.cols) == 0 {
		panic("no columns specified")
	}
	if len(b.vals) == 0 && b.record == nil {
		panic("no values or records specified")
	}
	if b.record == nil && b.cols[0] == "*" {
		panic(`"*" can only be used in conjunction with Record`)
	}
	if b.record == nil && b.isBlacklist {
		panic(`Blacklist can only be used in conjunction with Record`)
	}
	if len(b.whereFragments) == 0 {
		panic("where clause required for upsert")
	}

	// Resolve columns/values into locals instead of writing back to b.
	// Overwriting b.cols made a second ToSQL call on a Blacklist builder
	// exclude the already-filtered columns, producing the wrong column set.
	cols := b.cols
	vals := b.vals
	if b.record != nil {
		if b.isBlacklist {
			cols = reflectExcludeColumns(b.record, cols)
		} else if cols[0] == "*" {
			cols = reflectColumns(b.record)
		}
		if len(cols) == 0 {
			panic("no columns specified")
		}
		ind := reflect.Indirect(reflect.ValueOf(b.record))
		var err error
		vals, err = valuesFor(ind.Type(), ind, cols)
		if err != nil {
			panic(err.Error())
		}
	}
	// Fewer values used to panic with an index error; more values made the
	// INSERT placeholders run into the WHERE clause arguments.
	if len(vals) != len(cols) {
		panic("number of values does not match number of columns")
	}

	returnings := b.returnings
	if len(returnings) == 0 {
		returnings = cols
	}

	// TODO refactor this, no need to call update builder, just need a few
	// more helper functions.
	ub := NewUpdateBuilder(b.table)
	for i, col := range cols {
		ub.Set(col, vals[i])
	}
	ub.whereFragments = b.whereFragments
	ub.returnings = returnings
	updateSQL, args := ub.ToSQL()

	buf := bufPool.Get()
	defer bufPool.Put(buf)

	buf.WriteString("WITH upd AS ( ")
	buf.WriteString(updateSQL)
	buf.WriteString("), ins AS (")
	buf.WriteString(" INSERT INTO ")
	writeIdentifier(buf, b.table)
	buf.WriteString("(")
	writeIdentifiers(buf, cols, ",")
	buf.WriteString(") SELECT ")
	// The INSERT re-references the UPDATE's SET placeholders ($1..$n), so no
	// extra arguments are added. NOTE(review): this assumes UpdateBuilder
	// numbers SET values before WHERE arguments — confirm.
	writePlaceholders(buf, len(cols), ",", 1)
	buf.WriteString(" WHERE NOT EXISTS (SELECT 1 FROM upd) RETURNING ")
	writeIdentifiers(buf, returnings, ",")
	buf.WriteString(") SELECT * FROM ins UNION ALL SELECT * FROM upd")
	return buf.String(), args
}
// Where adds a condition selecting the rows the UPDATE part targets.
// whereSQLOrMap is either a SQL fragment whose placeholders bind to args, or
// a map of column/value pairs. Each call appends another fragment; ToSQL
// requires at least one.
func (b *UpsertBuilder) Where(whereSQLOrMap interface{}, args ...interface{}) *UpsertBuilder {
	fragment := newWhereFragment(whereSQLOrMap, args)
	b.whereFragments = append(b.whereFragments, fragment)
	return b
}