212 lines
7.2 KiB
Go
212 lines
7.2 KiB
Go
package codegen
|
||
|
||
import (
|
||
"fmt"
|
||
"os"
|
||
"path/filepath"
|
||
"strings"
|
||
|
||
codegenModel "sundynix-go/model/codegen"
|
||
)
|
||
|
||
// autoRegister 在代码生成后自动修改项目的注册文件
|
||
// 包括:initialize/gorm.go, initialize/router.go, api/v1/enter.go, router/enter.go, service/enter.go
|
||
// 以及子模块的 enter.go 文件(增量追加新 feature)
|
||
func (s *CodegenService) autoRegister(outputDir string, config codegenModel.GenConfig) error {
|
||
for _, module := range config.Modules {
|
||
if strings.TrimSpace(module.PackageName) == "" || strings.TrimSpace(module.Name) == "" {
|
||
continue
|
||
}
|
||
|
||
pkg := module.PackageName
|
||
pascalModule := module.Name
|
||
|
||
// 判断是否新模块:检查 router/{pkg}/enter.go 是否已存在
|
||
routerEnterPath := filepath.Join(outputDir, "router", pkg, "enter.go")
|
||
isNewModule := !fileExistsOnDisk(routerEnterPath)
|
||
|
||
if isNewModule {
|
||
// 新模块:注册到 5 个顶级文件
|
||
if err := s.registerGorm(outputDir, module); err != nil {
|
||
return fmt.Errorf("注册 gorm 迁移失败: %w", err)
|
||
}
|
||
if err := s.registerTopEnter(outputDir, pkg, pascalModule); err != nil {
|
||
return fmt.Errorf("注册顶级 enter 失败: %w", err)
|
||
}
|
||
if err := s.registerRouter(outputDir, pkg, pascalModule, module.Features); err != nil {
|
||
return fmt.Errorf("注册路由初始化失败: %w", err)
|
||
}
|
||
} else {
|
||
// 已有模块:增量注册新 feature
|
||
if err := s.registerGormIncremental(outputDir, module); err != nil {
|
||
return fmt.Errorf("增量注册 gorm 迁移失败: %w", err)
|
||
}
|
||
if err := s.registerRouterIncremental(outputDir, pkg, pascalModule, module.Features); err != nil {
|
||
return fmt.Errorf("增量注册路由失败: %w", err)
|
||
}
|
||
}
|
||
}
|
||
return nil
|
||
}
|
||
|
||
// registerGorm 向 initialize/gorm.go 的 AutoMigrate 中追加所有 feature
|
||
func (s *CodegenService) registerGorm(outputDir string, module codegenModel.Module) error {
|
||
gormPath := filepath.Join(outputDir, "initialize", "gorm.go")
|
||
content, err := os.ReadFile(gormPath)
|
||
if err != nil {
|
||
return fmt.Errorf("读取 %s 失败: %w", gormPath, err)
|
||
}
|
||
text := string(content)
|
||
pkg := module.PackageName
|
||
|
||
// 1. 添加 import
|
||
importLine := fmt.Sprintf("\t\"sundynix-go/model/%s\"", pkg)
|
||
if !strings.Contains(text, importLine) {
|
||
text = strings.Replace(text, "\"go.uber.org/zap\"", importLine+"\n\n\t\"go.uber.org/zap\"", 1)
|
||
}
|
||
|
||
// 2. 在 AutoMigrate 闭合 ) 前插入 model
|
||
for _, feature := range module.Features {
|
||
if strings.TrimSpace(feature.Name) == "" {
|
||
continue
|
||
}
|
||
modelEntry := fmt.Sprintf("\t\t%s.%s{},", pkg, feature.Name)
|
||
if !strings.Contains(text, modelEntry) {
|
||
// 找最后一个 ) 之前的位置(AutoMigrate 结束)
|
||
closeIdx := strings.LastIndex(text, "\t)")
|
||
if closeIdx > 0 {
|
||
text = text[:closeIdx] + "\n" + modelEntry + "\n" + text[closeIdx:]
|
||
}
|
||
}
|
||
}
|
||
|
||
return os.WriteFile(gormPath, []byte(text), 0644)
|
||
}
|
||
|
||
// registerGormIncremental 增量模式:只追加新 feature 到 AutoMigrate
|
||
func (s *CodegenService) registerGormIncremental(outputDir string, module codegenModel.Module) error {
|
||
return s.registerGorm(outputDir, module) // 逻辑相同,Contains 检查保证幂等
|
||
}
|
||
|
||
// registerTopEnter 向 api/v1/enter.go、router/enter.go、service/enter.go 追加新模块
|
||
func (s *CodegenService) registerTopEnter(outputDir string, pkg, pascalModule string) error {
|
||
// --- api/v1/enter.go ---
|
||
apiEnterPath := filepath.Join(outputDir, "api", "v1", "enter.go")
|
||
if err := appendToEnterFile(apiEnterPath,
|
||
fmt.Sprintf("\t\"sundynix-go/api/v1/%s\"", pkg),
|
||
fmt.Sprintf("\t%sApiGroup %s.ApiGroup", pascalModule, pkg),
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
// --- router/enter.go ---
|
||
routerEnterPath := filepath.Join(outputDir, "router", "enter.go")
|
||
if err := appendToEnterFile(routerEnterPath,
|
||
fmt.Sprintf("\t\"sundynix-go/router/%s\"", pkg),
|
||
fmt.Sprintf("\t%s %s.RouterGroup", pascalModule, pkg),
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
// --- service/enter.go ---
|
||
svcEnterPath := filepath.Join(outputDir, "service", "enter.go")
|
||
if err := appendToEnterFile(svcEnterPath,
|
||
fmt.Sprintf("\t\"sundynix-go/service/%s\"", pkg),
|
||
fmt.Sprintf("\t%sServiceGroup %s.ServiceGroup", pascalModule, pkg),
|
||
); err != nil {
|
||
return err
|
||
}
|
||
|
||
return nil
|
||
}
|
||
|
||
// appendToEnterFile adds importLine and structField to the Go "enter.go"
// aggregator file at filePath. Anything already present is skipped.
//
//   - importLine (e.g. "\t\"sundynix-go/api/v1/sys\"") becomes the last entry
//     of the first `import (` block. If the file has no grouped import, a new
//     `import ( ... )` declaration is added right after the package clause.
//     Presence is checked with a plain substring match.
//   - structField (e.g. "\tSysApiGroup sys.ApiGroup") becomes the last field
//     of the first `struct {` in the file. Presence is checked token by token,
//     so a field that gofmt has column-aligned ("\tSysApiGroup     sys.ApiGroup"),
//     or one with a trailing // comment, still counts and is not duplicated.
//
// The file is rewritten only when something was added. If a required anchor
// cannot be found, an error is returned and the file is left untouched. This
// avoids writing an unused import or a field of an unimported package,
// neither of which compiles.
func appendToEnterFile(filePath, importLine, structField string) error {
	content, err := os.ReadFile(filePath)
	if err != nil {
		return fmt.Errorf("读取 %s 失败: %w", filePath, err)
	}
	original := string(content)
	text := original

	// fieldPresent reports whether some line of src declares structField,
	// ignoring how whitespace between the tokens was formatted.
	fieldPresent := func(src string) bool {
		want := strings.Fields(structField)
		for _, line := range strings.Split(src, "\n") {
			got := strings.Fields(line)
			if len(got) < len(want) {
				continue
			}
			match := true
			for i := range want {
				if got[i] != want[i] {
					match = false
					break
				}
			}
			if match && (len(got) == len(want) || strings.HasPrefix(got[len(want)], "//")) {
				return true
			}
		}
		return false
	}

	// Import: append inside the grouped import block, or add a new declaration.
	if !strings.Contains(text, importLine) {
		if blockIdx := strings.Index(text, "import ("); blockIdx >= 0 {
			closeRel := strings.Index(text[blockIdx:], "\n)")
			if closeRel < 0 {
				return fmt.Errorf("%s 中的 import 块未闭合", filePath)
			}
			pos := blockIdx + closeRel + 1 // start of the ")" line
			text = text[:pos] + importLine + "\n" + text[pos:]
		} else {
			pkgStart := -1
			if strings.HasPrefix(text, "package ") {
				pkgStart = 0
			} else if i := strings.Index(text, "\npackage "); i >= 0 {
				pkgStart = i + 1
			}
			if pkgStart < 0 {
				return fmt.Errorf("%s 中未找到 package 声明", filePath)
			}
			pos := len(text)
			if eol := strings.IndexByte(text[pkgStart:], '\n'); eol >= 0 {
				pos = pkgStart + eol
			}
			// Go permits several import declarations, so a fresh one is safe even
			// if the file already has single-line imports.
			text = text[:pos] + "\n\nimport (\n" + importLine + "\n)" + text[pos:]
		}
	}

	// Struct field: insert before the closing brace of the first struct.
	if !fieldPresent(text) {
		from := 0
		if i := strings.Index(text, "struct {"); i >= 0 {
			from = i
		}
		rel := strings.Index(text[from:], "\n}")
		if rel < 0 {
			return fmt.Errorf("%s 中未找到 struct 的结束 }", filePath)
		}
		closeBrace := from + rel
		text = text[:closeBrace] + "\n" + structField + text[closeBrace:]
	}

	if text == original {
		return nil
	}
	return os.WriteFile(filePath, []byte(text), 0644)
}
|
||
|
||
// registerRouter 向 initialize/router.go 追加路由注册代码
|
||
func (s *CodegenService) registerRouter(outputDir, pkg, pascalModule string, features []codegenModel.Feature) error {
|
||
routerPath := filepath.Join(outputDir, "initialize", "router.go")
|
||
content, err := os.ReadFile(routerPath)
|
||
if err != nil {
|
||
return fmt.Errorf("读取 %s 失败: %w", routerPath, err)
|
||
}
|
||
text := string(content)
|
||
|
||
// 1. 添加变量声明:xxxRouter := router.GroupApp.Xxx
|
||
varLine := fmt.Sprintf("\t%sRouter := router.GroupApp.%s", lowerFirst(pkg), pascalModule)
|
||
if !strings.Contains(text, varLine) {
|
||
// 在 NeedAuthGroup 行之前插入
|
||
needAuthIdx := strings.Index(text, "NeedAuthGroup := Router.Group")
|
||
if needAuthIdx > 0 {
|
||
text = text[:needAuthIdx] + varLine + "\n\n\t" + text[needAuthIdx:]
|
||
}
|
||
}
|
||
|
||
// 2. 添加路由注册:xxxRouter.InitXxxRouter(NeedAuthGroup)
|
||
for _, feature := range features {
|
||
if strings.TrimSpace(feature.Name) == "" {
|
||
continue
|
||
}
|
||
routerCall := fmt.Sprintf("\t\t%sRouter.Init%sRouter(NeedAuthGroup)", lowerFirst(pkg), feature.Name)
|
||
if !strings.Contains(text, routerCall) {
|
||
// 在最后一个 Init...Router 行之后插入
|
||
lastInitIdx := strings.LastIndex(text, "Router(NeedAuthGroup)")
|
||
if lastInitIdx > 0 {
|
||
nlIdx := strings.Index(text[lastInitIdx:], "\n")
|
||
if nlIdx > 0 {
|
||
insertPos := lastInitIdx + nlIdx
|
||
comment := fmt.Sprintf(" //%s", feature.Comment)
|
||
if feature.Comment == "" {
|
||
comment = ""
|
||
}
|
||
text = text[:insertPos] + "\n" + routerCall + comment + text[insertPos:]
|
||
}
|
||
}
|
||
}
|
||
}
|
||
|
||
return os.WriteFile(routerPath, []byte(text), 0644)
|
||
}
|
||
|
||
// registerRouterIncremental 增量模式:只追加新 feature 的路由注册
|
||
func (s *CodegenService) registerRouterIncremental(outputDir, pkg, pascalModule string, features []codegenModel.Feature) error {
|
||
return s.registerRouter(outputDir, pkg, pascalModule, features) // Contains 检查保证幂等
|
||
}
|
||
|
||
// fileExistsOnDisk reports whether os.Stat succeeds for path.
//
// Any Stat error, not only "does not exist" (e.g. permission denied), yields
// false. A directory at path yields true.
func fileExistsOnDisk(path string) bool {
	if _, statErr := os.Stat(path); statErr != nil {
		return false
	}
	return true
}
|