Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
57 changes: 57 additions & 0 deletions internal/handlers/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -166,6 +166,41 @@ func (h *SetupHandler) enableUserSystem(adminConfig AdminConfig) error {
})
}

func (h *SetupHandler) isSystemInitialized() (bool, error) {
return isSystemInitialized(h.manager, h.daoManager)
}

func isSystemInitialized(manager *config.ConfigManager, daoManager *repository.RepositoryManager) (bool, error) {
if daoManager != nil && daoManager.User != nil {
count, err := daoManager.User.CountAdminUsers()
if err != nil {
return false, err
}
return count > 0, nil
}

if manager == nil {
return false, nil
}

db := manager.GetDB()
if db == nil {
return false, nil
}

repo := repository.NewRepositoryManager(db)
if repo.User == nil {
return false, nil
}

count, err := repo.User.CountAdminUsers()
if err != nil {
return false, err
}

return count > 0, nil
}

// contains 检查字符串是否包含子字符串
func contains(s, substr string) bool {
for i := 0; i <= len(s)-len(substr); i++ {
Expand Down Expand Up @@ -197,6 +232,17 @@ func InitializeNoDB(manager *config.ConfigManager) gin.HandlerFunc {
return
}
defer atomic.StoreInt32(&initInProgress, 0)

initialized, err := isSystemInitialized(manager, nil)
if err != nil {
logrus.WithError(err).Error("[InitializeNoDB] 检查系统初始化状态失败")
common.InternalServerErrorResponse(c, "检查系统初始化状态失败")
return
}
if initialized {
common.ForbiddenResponse(c, "系统已初始化,禁止重复初始化")
Copy link

Copilot AI Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The duplicate Chinese error message should be extracted to a constant to avoid repetition and facilitate easier maintenance and localization.

Copilot uses AI. Check for mistakes.
return
}
// 解析 JSON(仅接受嵌套结构),不再兼容 legacy 扁平字段
var req SetupRequest
if !utils.BindJSONWithValidation(c, &req) {
Expand Down Expand Up @@ -431,6 +477,17 @@ func (h *SetupHandler) Initialize(c *gin.Context) {
return
}

initialized, err := h.isSystemInitialized()
if err != nil {
logrus.WithError(err).Error("[SetupHandler.Initialize] 检查系统初始化状态失败")
common.InternalServerErrorResponse(c, "检查系统初始化状态失败")
return
}
if initialized {
common.ForbiddenResponse(c, "系统已初始化,禁止重复初始化")
Copy link

Copilot AI Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The duplicate Chinese error message should be extracted to a constant to avoid repetition and facilitate easier maintenance and localization.

Copilot uses AI. Check for mistakes.
return
}

var req SetupRequest
if !utils.BindJSONWithValidation(c, &req) {
return
Expand Down
109 changes: 109 additions & 0 deletions internal/middleware/setup_guard.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,109 @@
package middleware

import (
"net/http"
"strings"

"github.com/gin-gonic/gin"
"github.com/sirupsen/logrus"
)

// SetupGuardConfig controls how the setup guard middleware behaves.
type SetupGuardConfig struct {
// IsInitialized returns the current initialization status.
IsInitialized func() (bool, error)
// SetupPath denotes the setup entry path, defaults to /setup.
SetupPath string
// RedirectPath denotes the path to redirect to once initialized, defaults to /.
RedirectPath string
// AllowPaths lists exact paths that should remain accessible before initialization.
AllowPaths []string
// AllowPrefixes lists path prefixes that should remain accessible before initialization.
AllowPrefixes []string
}

// SetupGuard ensures only setup resources are accessible before initialization
// and blocks setup routes after initialization is complete.
func SetupGuard(cfg SetupGuardConfig) gin.HandlerFunc {
setupPath := cfg.SetupPath
if setupPath == "" {
setupPath = "/setup"
}
redirectPath := cfg.RedirectPath
if redirectPath == "" {
redirectPath = "/"
}

allowPaths := map[string]struct{}{
setupPath: {},
setupPath + "/": {},
}

for _, p := range cfg.AllowPaths {
allowPaths[p] = struct{}{}
}

allowPrefixes := []string{setupPath + "/"}
allowPrefixes = append(allowPrefixes, cfg.AllowPrefixes...)

return func(c *gin.Context) {
initialized := false
if cfg.IsInitialized != nil {
var err error
initialized, err = cfg.IsInitialized()
if err != nil {
logrus.WithError(err).Warn("setup guard: failed to determine initialization state")
// Fail closed on error so users can still reach setup for recovery.
initialized = false
}
}

path := c.Request.URL.Path

if initialized {
if path == setupPath || strings.HasPrefix(path, setupPath+"/") {
switch c.Request.Method {
case http.MethodGet, http.MethodHead:
c.Redirect(http.StatusFound, redirectPath)
case http.MethodOptions:
c.Status(http.StatusNoContent)
default:
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"code": http.StatusForbidden,
"message": "系统已初始化,禁止重新初始化",
Copy link

Copilot AI Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard-coded Chinese error messages should be externalized to a localization system or constants file for better maintainability and internationalization support.

Copilot uses AI. Check for mistakes.
})
}
c.Abort()
return
}

c.Next()
return
}

if _, ok := allowPaths[path]; ok {
c.Next()
return
}
for _, prefix := range allowPrefixes {
if strings.HasPrefix(path, prefix) {
c.Next()
return
}
}

switch c.Request.Method {
case http.MethodGet, http.MethodHead:
c.Redirect(http.StatusFound, setupPath)
case http.MethodOptions:
// Allow CORS preflight to complete without redirect loops.
c.Status(http.StatusNoContent)
default:
c.AbortWithStatusJSON(http.StatusForbidden, gin.H{
"code": http.StatusForbidden,
"message": "系统未初始化,请访问 /setup 完成初始化",
Copy link

Copilot AI Sep 24, 2025

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Hard-coded Chinese error messages should be externalized to a localization system or constants file for better maintainability and internationalization support.

Copilot uses AI. Check for mistakes.
})
}
c.Abort()
}
}
65 changes: 65 additions & 0 deletions internal/routes/setup.go
Original file line number Diff line number Diff line change
Expand Up @@ -90,6 +90,71 @@ func CreateAndSetupRouter(
router.Use(middleware.CORS())
router.Use(middleware.RateLimit(manager))

var cachedInitialized atomic.Bool
var cachedRepo atomic.Pointer[repository.RepositoryManager]

guardConfig := middleware.SetupGuardConfig{
SetupPath: "/setup",
RedirectPath: "/",
AllowPaths: []string{
"/setup/initialize",
"/check-init",
"/user/system-info",
"/health",
},
AllowPrefixes: []string{
"/assets/",
"/css/",
"/js/",
"/components/",
},
}

guardConfig.IsInitialized = func() (bool, error) {
if cachedInitialized.Load() {
return true, nil
}

if daoManager != nil && daoManager.User != nil {
count, err := daoManager.User.CountAdminUsers()
if err != nil {
return false, err
}
if count > 0 {
cachedInitialized.Store(true)
return true, nil
}
return false, nil
}

db := manager.GetDB()
if db == nil {
return false, nil
}

repo := cachedRepo.Load()
if repo == nil || repo.DB() != db {
repo = repository.NewRepositoryManager(db)
cachedRepo.Store(repo)
}
if repo.User == nil {
return false, nil
}

count, err := repo.User.CountAdminUsers()
if err != nil {
return false, err
}

if count > 0 {
cachedInitialized.Store(true)
return true, nil
}
return false, nil
}

router.Use(middleware.SetupGuard(guardConfig))

// 如果 daoManager 为 nil,表示尚未初始化数据库,只注册基础和初始化相关的路由
if daoManager == nil {
// 基础路由(不传 userHandler)
Expand Down
Loading