Skip to content

Commit 46a724a

Browse files
committed
Build SET clause from columns
1 parent 148319b commit 46a724a

File tree

2 files changed

+51
-23
lines changed

2 files changed

+51
-23
lines changed

database/query_builder.go

Lines changed: 21 additions & 11 deletions
Original file line numberDiff line numberDiff line change
@@ -184,41 +184,51 @@ func (qb *queryBuilder) SelectStatement(stmt SelectStatement) string {
184184
}
185185

186186
func (qb *queryBuilder) UpdateStatement(stmt UpdateStatement) (string, error) {
187+
columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns())
188+
187189
table := stmt.Table()
188190
if table == "" {
189191
table = TableName(stmt.Entity())
190192
}
191-
set := stmt.Set()
192-
if set == "" {
193-
return "", errors.New("set cannot be empty")
194-
}
193+
195194
where := stmt.Where()
196195
if where == "" {
197-
return "", errors.New("cannot use UpdateStatement() without where statement - use UpdateAllStatement() instead")
196+
return "", fmt.Errorf("%w: %s", ErrMissingStatementPart, "where statement - use UpdateAllStatement() instead")
197+
}
198+
199+
var set []string
200+
201+
for _, col := range columns {
202+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
198203
}
199204

200205
return fmt.Sprintf(
201-
`UPDATE "%s" SET %s%s`,
206+
`UPDATE "%s" SET %s WHERE %s`,
202207
table,
203-
set,
208+
strings.Join(set, ", "),
204209
where,
205210
), nil
206211
}
207212

208213
func (qb *queryBuilder) UpdateAllStatement(stmt UpdateStatement) (string, error) {
214+
columns := qb.BuildColumns(stmt.Entity(), stmt.Columns(), stmt.ExcludedColumns())
215+
209216
table := stmt.Table()
210217
if table == "" {
211218
table = TableName(stmt.Entity())
212219
}
213-
set := stmt.Set()
214-
if set == "" {
215-
return "", errors.New("set cannot be empty")
216-
}
220+
217221
where := stmt.Where()
218222
if where != "" {
219223
return "", errors.New("cannot use UpdateAllStatement() with where statement - use UpdateStatement() instead")
220224
}
221225

226+
var set []string
227+
228+
for _, col := range columns {
229+
set = append(set, fmt.Sprintf(`"%[1]s" = :%[1]s`, col))
230+
}
231+
222232
return fmt.Sprintf(
223233
`UPDATE "%s" SET %s`,
224234
table,

database/update.go

Lines changed: 30 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -8,8 +8,12 @@ type UpdateStatement interface {
88
// Overrides the table name provided by the entity.
99
SetTable(table string) UpdateStatement
1010

11-
// SetSet sets the set clause for the UPDATE statement.
12-
SetSet(set string) UpdateStatement
11+
// SetColumns sets the columns to be updated.
12+
SetColumns(columns ...string) UpdateStatement
13+
14+
// SetExcludedColumns sets the columns to be excluded from the UPDATE statement.
15+
// Excludes also columns set by SetColumns.
16+
SetExcludedColumns(columns ...string) UpdateStatement
1317

1418
// SetWhere sets the where clause for the UPDATE statement.
1519
SetWhere(where string) UpdateStatement
@@ -20,8 +24,11 @@ type UpdateStatement interface {
2024
// Table returns the table name for the UPDATE statement.
2125
Table() string
2226

23-
// Set returns the set clause for the UPDATE statement.
24-
Set() string
27+
// Columns returns the columns to be updated.
28+
Columns() []string
29+
30+
// ExcludedColumns returns the columns to be excluded from the UPDATE statement.
31+
ExcludedColumns() []string
2532

2633
// Where returns the where clause for the UPDATE statement.
2734
Where() string
@@ -39,10 +46,11 @@ func NewUpdateStatement(entity Entity) UpdateStatement {
3946

4047
// updateStatement is the default implementation of the UpdateStatement interface.
4148
type updateStatement struct {
42-
entity Entity
43-
table string
44-
set string
45-
where string
49+
entity Entity
50+
table string
51+
columns []string
52+
excludedColumns []string
53+
where string
4654
}
4755

4856
func (u *updateStatement) SetTable(table string) UpdateStatement {
@@ -51,8 +59,14 @@ func (u *updateStatement) SetTable(table string) UpdateStatement {
5159
return u
5260
}
5361

54-
func (u *updateStatement) SetSet(set string) UpdateStatement {
55-
u.set = set
62+
func (u *updateStatement) SetColumns(columns ...string) UpdateStatement {
63+
u.columns = columns
64+
65+
return u
66+
}
67+
68+
func (u *updateStatement) SetExcludedColumns(columns ...string) UpdateStatement {
69+
u.excludedColumns = columns
5670

5771
return u
5872
}
@@ -71,8 +85,12 @@ func (u *updateStatement) Table() string {
7185
return u.table
7286
}
7387

74-
func (u *updateStatement) Set() string {
75-
return u.set
88+
func (u *updateStatement) Columns() []string {
89+
return u.columns
90+
}
91+
92+
func (u *updateStatement) ExcludedColumns() []string {
93+
return u.excludedColumns
7694
}
7795

7896
func (u *updateStatement) Where() string {

0 commit comments

Comments
 (0)