Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
46 changes: 46 additions & 0 deletions KimmyXYC/cmd/server/main.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,46 @@
package main

import (
"log"
"os"

"github.com/joho/godotenv"

"AIBackend/internal/db"
"AIBackend/internal/httpserver"
"AIBackend/internal/provider"
)

// main 是程序的入口点。它可选加载 `.env`,使用 `DATABASE_URL` 建立并迁移数据库,
// 从环境创建 LLM 提供者,并使用 `ADDR`(默认 ":8080")启动 HTTP 服务器;在连接、迁移或启动失败时记录致命错误,在缺少 `DATABASE_URL` 时记录警告。
func main() {
// Load .env if present (dev convenience)
_ = godotenv.Load()

// Initialize DB
pgURL := os.Getenv("DATABASE_URL")
if pgURL == "" {
log.Println("WARNING: DATABASE_URL is not set. The server may fail to start when DB is required.")
}
gormDB, err := db.Connect(pgURL)
if err != nil {
log.Fatalf("failed to connect database: %v", err)
}
if err := db.AutoMigrate(gormDB); err != nil {
log.Fatalf("failed to migrate database: %v", err)
}

// Initialize LLM provider (Mock by default)
llm := provider.NewProviderFromEnv()

// Start HTTP server
r := httpserver.NewRouter(gormDB, llm)
addr := os.Getenv("ADDR")
if addr == "" {
addr = ":8080"
}
log.Printf("Server listening on %s", addr)
if err := r.Run(addr); err != nil {
log.Fatalf("server error: %v", err)
}
}
37 changes: 37 additions & 0 deletions KimmyXYC/internal/db/db.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,37 @@
package db

import (
"fmt"

"gorm.io/driver/postgres"
"gorm.io/gorm"

"AIBackend/internal/models"
)

// Connect 使用提供的数据库 URL 打开一个 PostgreSQL 连接。
// 如果传入的 databaseURL 为空,则使用本地默认 DSN:
// postgres://postgres:postgres@localhost:5432/aibackend?sslmode=disable。
// 返回已打开的 *gorm.DB;在无法建立连接时返回非 nil 错误。
func Connect(databaseURL string) (*gorm.DB, error) {
if databaseURL == "" {
// Provide a friendly default to help first run; it will still fail if DB not available.
databaseURL = "postgres://postgres:postgres@localhost:5432/aibackend?sslmode=disable"
}
dsn := databaseURL
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
return nil, fmt.Errorf("connect postgres: %w", err)
}
return db, nil
}

// AutoMigrate 在数据库上应用 User、Conversation 和 Message 模型的自动迁移。
// 如果迁移失败,返回相应的错误。
func AutoMigrate(db *gorm.DB) error {
return db.AutoMigrate(
&models.User{},
&models.Conversation{},
&models.Message{},
)
}
200 changes: 200 additions & 0 deletions KimmyXYC/internal/httpserver/router.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,200 @@
package httpserver

import (
"net/http"
"strconv"
"time"

"github.com/gin-gonic/gin"
"gorm.io/gorm"

"AIBackend/internal/provider"
"AIBackend/internal/services"
"AIBackend/pkg/middleware"
)

type Server struct {
Auth *services.AuthService
Chat *services.ChatService
}

// NewRouter 创建并返回已配置的 Gin 引擎,注册健康检查、认证相关路由、带鉴权的会话与聊天 API(包含可选的 SSE 流式聊天)并提供前端静态文件服务。
func NewRouter(db *gorm.DB, llm provider.LLMProvider) *gin.Engine {
g := gin.Default()

server := &Server{
Auth: services.NewAuthService(db),
Chat: services.NewChatService(db, llm),
}

g.GET("/health", func(c *gin.Context) { c.JSON(http.StatusOK, gin.H{"status": "ok"}) })

api := g.Group("/api")
{
auth := api.Group("/auth")
auth.POST("/register", server.handleRegister)
auth.POST("/login", server.handleLogin)
}

protected := api.Group("")
protected.Use(middleware.AuthRequired())
{
protected.GET("/me", server.handleMe)
protected.GET("/conversations", server.handleListConversations)
protected.GET("/conversations/:id/messages", server.handleGetMessages)
protected.POST("/chat", middleware.ModelAccess(), server.handleChat)
}

// Serve static frontend files without conflicting wildcard
g.StaticFile("/", "./web/index.html")
g.Static("/css", "./web/css")
g.Static("/js", "./web/js")

return g
}

type registerReq struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
Role string `json:"role"`
}

type loginReq struct {
Email string `json:"email" binding:"required"`
Password string `json:"password" binding:"required"`
}

func (s *Server) handleRegister(c *gin.Context) {
var req registerReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, token, err := s.Auth.Register(req.Email, req.Password, req.Role)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"user": user, "token": token})
}

func (s *Server) handleLogin(c *gin.Context) {
var req loginReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
user, token, err := s.Auth.Login(req.Email, req.Password)
if err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"user": user, "token": token})
}

func (s *Server) handleMe(c *gin.Context) {
c.JSON(http.StatusOK, gin.H{
"user_id": c.GetUint("user_id"),
"user_email": c.GetString("user_email"),
"user_role": c.GetString("user_role"),
})
}

func (s *Server) handleListConversations(c *gin.Context) {
uid := c.GetUint("user_id")
convs, err := s.Chat.ListConversations(uid)
if err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"conversations": convs})
}

func (s *Server) handleGetMessages(c *gin.Context) {
uid := c.GetUint("user_id")
idStr := c.Param("id")
id64, _ := strconv.ParseUint(idStr, 10, 64)
msgs, err := s.Chat.GetMessages(uid, uint(id64))
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"messages": msgs})
}

type chatReq struct {
ConversationID uint `json:"conversation_id"`
Model string `json:"model"`
Message string `json:"message" binding:"required"`
Stream *bool `json:"stream"`
}

func (s *Server) handleChat(c *gin.Context) {
uid := c.GetUint("user_id")
var req chatReq
if err := c.ShouldBindJSON(&req); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
// Fallback to query param model for middleware check compatibility
if req.Model == "" {
req.Model = c.Query("model")
}
// Enforce model access if provided in body
role := c.GetString("user_role")
if !middleware.CheckModelAccess(role, req.Model) {
c.JSON(http.StatusForbidden, gin.H{"error": "model access denied for role"})
return
}
streaming := false
if req.Stream != nil {
streaming = *req.Stream
}
if c.Query("stream") == "1" || c.Query("stream") == "true" {
streaming = true
}
if !streaming {
convID, reply, err := s.Chat.SendMessage(c.Request.Context(), uid, req.ConversationID, req.Model, req.Message, nil)
if err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": err.Error()})
return
}
c.JSON(http.StatusOK, gin.H{"conversation_id": convID, "reply": reply})
return
}
// Streaming via SSE
w := c.Writer
c.Header("Content-Type", "text/event-stream")
c.Header("Cache-Control", "no-cache")
c.Header("Connection", "keep-alive")
c.Status(http.StatusOK)
flusher, _ := w.(http.Flusher)
sentAny := false
convID, _, err := s.Chat.SendMessage(c.Request.Context(), uid, req.ConversationID, req.Model, req.Message, func(chunk string) error {
sentAny = true
_, err := w.Write([]byte("data: " + chunk + "\n\n"))
if err == nil && flusher != nil {
flusher.Flush()
}
return err
})
if err != nil {
// send error as SSE comment and 0-length event end
_, _ = w.Write([]byte(": error: " + err.Error() + "\n\n"))
if flusher != nil {
flusher.Flush()
}
return
}
if !sentAny {
// send at least one empty event to keep clients happy
_, _ = w.Write([]byte("data: \n\n"))
}
// end event
_, _ = w.Write([]byte("event: done\n" + "data: {\"conversation_id\": " + strconv.FormatUint(uint64(convID), 10) + "}\n\n"))
if flusher != nil {
flusher.Flush()
}
// allow connection to close shortly after
time.Sleep(50 * time.Millisecond)
}
Loading
Loading