Skip to content

Commit f8fa86c

Browse files
authored
refactor(merger): 使用更宽松的比较机制来比较两个列的相等性 (#227)
1 parent 2f1d981 commit f8fa86c

File tree

3 files changed

+28
-19
lines changed

3 files changed

+28
-19
lines changed

.CHANGELOG.md

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -42,6 +42,7 @@
4242
- [feat(merger): 新增Distinct Merger](https://github.com/ecodeclub/eorm/pull/224)
4343
- [refactor(merger): 去掉无用代码及过期注释,整理代码](https://github.com/ecodeclub/eorm/pull/225)
4444
- [eorm: 结果集处理--聚合函数支持nullable类型的数据](https://github.com/ecodeclub/eorm/pull/226)
45+
- [refactor(merger): 使用更宽松的比较机制来比较两个列的相等性](https://github.com/ecodeclub/eorm/pull/227)
4546

4647
## v0.0.1:
4748
- [Init Project](https://github.com/ecodeclub/eorm/pull/1)

internal/merger/factory/factory.go

Lines changed: 21 additions & 19 deletions
Original file line numberDiff line numberDiff line change
@@ -117,34 +117,40 @@ func (q QuerySpec) validateGroupBy() error {
117117
}
118118
for _, c := range q.GroupBy {
119119
if !c.Validate() {
120-
return fmt.Errorf("%w: groupby %v", ErrInvalidColumnInfo, c.Name)
120+
return fmt.Errorf("%w: groupby %#v", ErrInvalidColumnInfo, c)
121121
}
122-
// 清除ASC
123-
c.Order = merger.OrderDESC
124-
if !slice.Contains(q.Select, c) {
125-
return fmt.Errorf("%w: groupby %v", ErrColumnNotFoundInSelectList, c.Name)
122+
if !slice.ContainsFunc(q.Select, func(src merger.ColumnInfo) bool { return equals(src, c) }) {
123+
return fmt.Errorf("%w: groupby %#v", ErrColumnNotFoundInSelectList, c)
126124
}
127125
}
128126
for _, c := range q.Select {
129-
if c.AggregateFunc == "" && !slice.Contains(q.GroupBy, c) {
130-
return fmt.Errorf("%w: 非聚合列 %v 必须出现在groupby列表中", ErrInvalidColumnInfo, c.Name)
127+
isInGroupByList := slice.ContainsFunc(q.GroupBy, func(src merger.ColumnInfo) bool { return equals(src, c) })
128+
if c.AggregateFunc == "" && !isInGroupByList {
129+
return fmt.Errorf("%w: 非聚合列 %#v 必须出现在groupby列表中", ErrInvalidColumnInfo, c)
131130
}
132-
if c.AggregateFunc != "" && slice.Contains(q.GroupBy, c) {
133-
return fmt.Errorf("%w: 聚合列 %v 不能出现在groupby列表中", ErrInvalidColumnInfo, c.Name)
131+
if c.AggregateFunc != "" && isInGroupByList {
132+
return fmt.Errorf("%w: 聚合列 %#v 不能出现在groupby列表中", ErrInvalidColumnInfo, c)
134133
}
135134
}
136135
return nil
137136
}
138137

138+
func equals(a, b merger.ColumnInfo) bool {
139+
// 这里忽略Order和Distinct字段的比较
140+
return a.Index == b.Index &&
141+
strings.Trim(a.Name, "`") == strings.Trim(b.Name, "`") &&
142+
strings.EqualFold(a.AggregateFunc, b.AggregateFunc) &&
143+
strings.Trim(a.Alias, "`") == strings.Trim(b.Alias, "`")
144+
}
145+
139146
func (q QuerySpec) validateDistinct() error {
140147
if !slice.Contains(q.Features, query.Distinct) {
141148
return nil
142149
}
143-
// 程序走到这q.Select的长度至少为1
150+
// 注意: 程序走到这q.Select的长度至少为1
144151
for _, c := range q.Select {
145-
// case2,3
146152
if !c.Distinct || !c.Validate() {
147-
return fmt.Errorf("%w: distinct %v", ErrInvalidColumnInfo, c.Name)
153+
return fmt.Errorf("%w: distinct %#v", ErrInvalidColumnInfo, c)
148154
}
149155
}
150156
return nil
@@ -158,15 +164,11 @@ func (q QuerySpec) validateOrderBy() error {
158164
return fmt.Errorf("%w: orderby", ErrEmptyColumnList)
159165
}
160166
for _, c := range q.OrderBy {
161-
162167
if !c.Validate() {
163-
return fmt.Errorf("%w: orderby %v", ErrInvalidColumnInfo, c.Name)
168+
return fmt.Errorf("%w: orderby %#v", ErrInvalidColumnInfo, c)
164169
}
165-
_, ok := slice.Find(q.Select, func(src merger.ColumnInfo) bool {
166-
return src.Index == c.Index && src.SelectName() == c.SelectName()
167-
})
168-
if !ok {
169-
return fmt.Errorf("%w: orderby %v", ErrColumnNotFoundInSelectList, c.Name)
170+
if !slice.ContainsFunc(q.Select, func(src merger.ColumnInfo) bool { return equals(src, c) }) {
171+
return fmt.Errorf("%w: orderby %#v", ErrColumnNotFoundInSelectList, c)
170172
}
171173
}
172174
return nil

internal/merger/type.go

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -22,6 +22,7 @@ import (
2222
"strings"
2323
"time"
2424

25+
"github.com/ecodeclub/ekit/slice"
2526
"github.com/ecodeclub/eorm/internal/merger/internal/errs"
2627
"github.com/ecodeclub/eorm/internal/rows"
2728
)
@@ -67,6 +68,11 @@ func (c ColumnInfo) SelectName() string {
6768
func (c ColumnInfo) Validate() bool {
6869
// ColumnInfo.Name中不能包含括号,也就是聚合函数, name = `id`, 而不是name = count(`id`)
6970
// 聚合函数需要写在aggregateFunc字段中
71+
aggregateFuncs := []string{"MAX", "MIN", "AVG", "SUM", "COUNT"}
72+
if c.AggregateFunc != "" &&
73+
!slice.ContainsFunc(aggregateFuncs, func(src string) bool { return strings.EqualFold(src, c.AggregateFunc) }) {
74+
return false
75+
}
7076
return !strings.Contains(c.Name, "(")
7177
}
7278

0 commit comments

Comments
 (0)