Skip to content

Commit cd7181f

Browse files
authored
fix enum list and alias name (#277)
1 parent b4e8f00 commit cd7181f

File tree

5 files changed

+167
-127
lines changed

5 files changed

+167
-127
lines changed

generator/code_generator.go

Lines changed: 39 additions & 28 deletions
Original file line numberDiff line numberDiff line change
@@ -28,25 +28,25 @@ func NewCodeGenerator() *CodeGenerator {
2828
return &CodeGenerator{}
2929
}
3030

31-
func (g *CodeGenerator) Generate(file *resolver.File, enums []*resolver.Enum) ([]byte, error) {
31+
func (g *CodeGenerator) Generate(file *resolver.File) ([]byte, error) {
3232
tmpl, err := loadTemplate()
3333
if err != nil {
3434
return nil, err
3535
}
36-
return generateGoContent(tmpl, NewFile(file, enums))
36+
return generateGoContent(tmpl, NewFile(file))
3737
}
3838

3939
type File struct {
4040
*resolver.File
41-
enums []*resolver.Enum
42-
pkgMap map[*resolver.GoPackage]struct{}
41+
pkgMap map[*resolver.GoPackage]struct{}
42+
aliasMap map[*resolver.GoPackage]string
4343
}
4444

45-
func NewFile(file *resolver.File, enums []*resolver.Enum) *File {
45+
func NewFile(file *resolver.File) *File {
4646
return &File{
47-
File: file,
48-
pkgMap: make(map[*resolver.GoPackage]struct{}),
49-
enums: enums,
47+
File: file,
48+
pkgMap: make(map[*resolver.GoPackage]struct{}),
49+
aliasMap: make(map[*resolver.GoPackage]string),
5050
}
5151
}
5252

@@ -608,21 +608,22 @@ func (f *File) Imports() []*Import {
608608
if _, exists := importPathMap[pkg.ImportPath]; exists {
609609
return
610610
}
611+
alias := pkg.Name
611612
if _, exists := importAliasMap[pkg.Name]; exists {
612613
// conflict alias name
613614
suffixIndex := 1 // start from 1.
614615
for {
615-
alias := pkg.Name + fmt.Sprint(suffixIndex)
616+
alias = pkg.Name + fmt.Sprint(suffixIndex)
616617
if _, exists := importAliasMap[alias]; !exists {
617-
pkg.Name = alias
618618
break
619619
}
620620
suffixIndex++
621621
}
622622
}
623+
f.aliasMap[pkg] = alias
623624
imports = append(imports, &Import{
624625
Path: pkg.ImportPath,
625-
Alias: pkg.Name,
626+
Alias: alias,
626627
Used: true,
627628
})
628629
importPathMap[pkg.ImportPath] = struct{}{}
@@ -644,15 +645,23 @@ func (f *File) Imports() []*Import {
644645
return imports
645646
}
646647

648+
func (f *File) getAlias(pkg *resolver.GoPackage) string {
649+
alias, exists := f.aliasMap[pkg]
650+
if exists {
651+
return alias
652+
}
653+
return pkg.Name
654+
}
655+
647656
type Enum struct {
648657
ProtoName string
649658
GoName string
650659
EnumAttribute *EnumAttribute
651660
}
652661

653662
func (f *File) Enums() []*Enum {
654-
ret := make([]*Enum, 0, len(f.enums))
655-
for _, enum := range f.enums {
663+
var enums []*Enum
664+
for _, enum := range f.AllEnumsIncludeDeps() {
656665
protoName := enum.FQDN()
657666
// ignore standard library's enum.
658667
if strings.HasPrefix(protoName, "google.") {
@@ -661,7 +670,6 @@ func (f *File) Enums() []*Enum {
661670
if strings.HasPrefix(protoName, "grpc.federation.") {
662671
continue
663672
}
664-
// f.enums contain all the enums defined in the package
665673
// Currently Enums are used only from a File contains Services.
666674
if len(f.File.Services) != 0 {
667675
f.pkgMap[enum.GoPackage()] = struct{}{}
@@ -680,14 +688,13 @@ func (f *File) Enums() []*Enum {
680688
})
681689
}
682690
}
683-
684-
ret = append(ret, &Enum{
691+
enums = append(enums, &Enum{
685692
ProtoName: protoName,
686693
GoName: f.enumTypeToText(enum),
687694
EnumAttribute: enumAttr,
688695
})
689696
}
690-
return ret
697+
return enums
691698
}
692699

693700
type EnumAttribute struct {
@@ -895,11 +902,12 @@ func (s *Service) ServiceName() string {
895902
}
896903

897904
func (s *Service) PackageName() string {
898-
return s.GoPackage().Name
905+
return s.file.getAlias(s.GoPackage())
899906
}
900907

901908
type ServiceDependency struct {
902909
*resolver.ServiceDependency
910+
file *File
903911
}
904912

905913
func (dep *ServiceDependency) ServiceName() string {
@@ -923,15 +931,15 @@ func (dep *ServiceDependency) PrivateClientName() string {
923931
func (dep *ServiceDependency) ClientType() string {
924932
return fmt.Sprintf(
925933
"%s.%sClient",
926-
dep.Service.GoPackage().Name,
934+
dep.file.getAlias(dep.Service.GoPackage()),
927935
dep.Service.Name,
928936
)
929937
}
930938

931939
func (dep *ServiceDependency) ClientConstructor() string {
932940
return fmt.Sprintf(
933941
"%s.New%sClient",
934-
dep.Service.GoPackage().Name,
942+
dep.file.getAlias(dep.Service.GoPackage()),
935943
dep.Service.Name,
936944
)
937945
}
@@ -940,7 +948,10 @@ func (s *Service) ServiceDependencies() []*ServiceDependency {
940948
deps := s.Service.ServiceDependencies()
941949
ret := make([]*ServiceDependency, 0, len(deps))
942950
for _, dep := range deps {
943-
ret = append(ret, &ServiceDependency{dep})
951+
ret = append(ret, &ServiceDependency{
952+
ServiceDependency: dep,
953+
file: s.file,
954+
})
944955
}
945956
return ret
946957
}
@@ -1198,7 +1209,7 @@ func (f *File) messageTypeToText(msg *resolver.Message) string {
11981209
if f.GoPackage.ImportPath == msg.GoPackage().ImportPath {
11991210
return fmt.Sprintf("*%s", name)
12001211
}
1201-
return fmt.Sprintf("*%s.%s", msg.GoPackage().Name, name)
1212+
return fmt.Sprintf("*%s.%s", f.getAlias(msg.GoPackage()), name)
12021213
}
12031214

12041215
func (f *File) enumTypeToText(enum *resolver.Enum) string {
@@ -1215,7 +1226,7 @@ func (f *File) enumTypeToText(enum *resolver.Enum) string {
12151226
if f.GoPackage.ImportPath == enum.GoPackage().ImportPath {
12161227
return name
12171228
}
1218-
return fmt.Sprintf("%s.%s", enum.GoPackage().Name, name)
1229+
return fmt.Sprintf("%s.%s", f.getAlias(enum.GoPackage()), name)
12191230
}
12201231

12211232
type LogValue struct {
@@ -2258,7 +2269,7 @@ func (f *CastField) ToStruct() *CastStruct {
22582269
}
22592270
name := strings.Join(names, "_")
22602271
if f.service.GoPackage().ImportPath != toMsg.GoPackage().ImportPath {
2261-
name = fmt.Sprintf("%s.%s", toMsg.GoPackage().Name, name)
2272+
name = fmt.Sprintf("%s.%s", f.file.getAlias(toMsg.GoPackage()), name)
22622273
}
22632274

22642275
return &CastStruct{
@@ -2328,7 +2339,7 @@ func (f *CastField) ToOneof() *CastOneof {
23282339
name += "_"
23292340
}
23302341
if f.service.GoPackage().ImportPath != msg.GoPackage().ImportPath {
2331-
name = fmt.Sprintf("%s.%s", msg.GoPackage().Name, name)
2342+
name = fmt.Sprintf("%s.%s", f.file.getAlias(msg.GoPackage()), name)
23322343
}
23332344
return &CastOneof{
23342345
Name: name,
@@ -3075,7 +3086,7 @@ func (d *VariableDefinition) RequestType() string {
30753086
case expr.Call != nil:
30763087
request := expr.Call.Request
30773088
return fmt.Sprintf("%s.%s",
3078-
request.Type.GoPackage().Name,
3089+
d.file.getAlias(request.Type.GoPackage()),
30793090
util.ToPublicGoVariable(request.Type.Name),
30803091
)
30813092
case expr.Message != nil:
@@ -3099,7 +3110,7 @@ func (d *VariableDefinition) ReturnType() string {
30993110
case expr.Call != nil:
31003111
response := expr.Call.Method.Response
31013112
return fmt.Sprintf("%s.%s",
3102-
response.GoPackage().Name,
3113+
d.file.getAlias(response.GoPackage()),
31033114
response.Name,
31043115
)
31053116
case expr.Message != nil:
@@ -3673,7 +3684,7 @@ func toEnumValuePrefix(file *File, typ *resolver.Type) string {
36733684
if file.GoPackage.ImportPath == enum.GoPackage().ImportPath {
36743685
return name
36753686
}
3676-
return fmt.Sprintf("%s.%s", enum.GoPackage().Name, name)
3687+
return fmt.Sprintf("%s.%s", file.getAlias(enum.GoPackage()), name)
36773688
}
36783689

36793690
func toEnumValueText(enumValuePrefix string, value string) string {

generator/code_generator_test.go

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -71,7 +71,7 @@ func TestCodeGenerate(t *testing.T) {
7171
t.Fatalf("failed to get files. expected 1 but got %d", len(result.Files))
7272
}
7373
result.Files[0].Name = filepath.Base(result.Files[0].Name)
74-
out, err := generator.NewCodeGenerator().Generate(result.Files[0], result.Enums)
74+
out, err := generator.NewCodeGenerator().Generate(result.Files[0])
7575
if err != nil {
7676
t.Fatal(err)
7777
}

generator/generator.go

Lines changed: 2 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -603,7 +603,7 @@ func (g *Generator) generateByGRPCFederation(r *PluginRequest) (*pluginpb.CodeGe
603603

604604
var resp pluginpb.CodeGeneratorResponse
605605
for _, file := range result.Files {
606-
out, err := NewCodeGenerator().Generate(file, result.Enums)
606+
out, err := NewCodeGenerator().Generate(file)
607607
if err != nil {
608608
return nil, err
609609
}
@@ -667,7 +667,7 @@ func CreateCodeGeneratorResponse(ctx context.Context, req *pluginpb.CodeGenerato
667667

668668
var resp pluginpb.CodeGeneratorResponse
669669
for _, file := range result.Files {
670-
out, err := NewCodeGenerator().Generate(file, result.Enums)
670+
out, err := NewCodeGenerator().Generate(file)
671671
if err != nil {
672672
return nil, err
673673
}

resolver/file.go

Lines changed: 31 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -3,6 +3,7 @@ package resolver
33
import (
44
"fmt"
55
"path/filepath"
6+
"sort"
67
"strings"
78

89
"github.com/mercari/grpc-federation/grpc/federation"
@@ -33,6 +34,7 @@ func (f *File) Message(name string) *Message {
3334
return nil
3435
}
3536

37+
// AllEnums returns a list that includes the enums defined in the file itself.
3638
func (f *File) AllEnums() []*Enum {
3739
enums := f.Enums
3840
for _, msg := range f.Messages {
@@ -41,6 +43,35 @@ func (f *File) AllEnums() []*Enum {
4143
return enums
4244
}
4345

46+
// AllEnumsIncludeDeps recursively searches for the imported file and returns a list of all enums.
47+
func (f *File) AllEnumsIncludeDeps() []*Enum {
48+
return f.allEnumsIncludeDeps(make(map[string][]*Enum))
49+
}
50+
51+
func (f *File) allEnumsIncludeDeps(cacheMap map[string][]*Enum) []*Enum {
52+
if enums, exists := cacheMap[f.Name]; exists {
53+
return enums
54+
}
55+
enumMap := make(map[string]*Enum)
56+
for _, enum := range f.AllEnums() {
57+
enumMap[enum.FQDN()] = enum
58+
}
59+
for _, importFile := range f.ImportFiles {
60+
for _, enum := range importFile.allEnumsIncludeDeps(cacheMap) {
61+
enumMap[enum.FQDN()] = enum
62+
}
63+
}
64+
enums := make([]*Enum, 0, len(enumMap))
65+
for _, enum := range enumMap {
66+
enums = append(enums, enum)
67+
}
68+
sort.Slice(enums, func(i, j int) bool {
69+
return enums[i].FQDN() < enums[j].FQDN()
70+
})
71+
cacheMap[f.Name] = enums
72+
return enums
73+
}
74+
4475
func (f *File) PrivatePackageName() string {
4576
return federation.PrivatePackageName + "." + f.PackageName()
4677
}

0 commit comments

Comments
 (0)