实现 Grpc 及 Auto2.0

This commit is contained in:
xiangning 2025-01-15 16:22:20 +08:00
parent f738088a9e
commit bc740d6e92
14 changed files with 614 additions and 39 deletions

45
client/main.go Normal file
View File

@ -0,0 +1,45 @@
package main
import (
"context"
"fmt"
"log"
"ly-user-center/proto/pb"
"google.golang.org/grpc"
"google.golang.org/grpc/credentials/insecure"
)
func main() {
conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
if err != nil {
log.Fatalf("did not connect: %v", err)
}
defer conn.Close()
client := pb.NewUserServiceClient(conn)
// 测试注册
registerResp, err := client.Register(context.Background(), &pb.RegisterRequest{
Username: "testuser",
Password: "password123",
Email: "test@example.com",
Phone: "1234567890",
Sex: pb.Gender_MALE,
})
if err != nil {
log.Fatalf("Register failed: %v", err)
}
fmt.Printf("Register response: %v\n", registerResp)
// 测试登录
loginResp, err := client.Login(context.Background(), &pb.LoginRequest{
Username: "testuser",
Password: "password123",
})
if err != nil {
log.Fatalf("Login failed: %v", err)
}
fmt.Printf("Login response: %v\n", loginResp)
}

25
config/oauth.go Normal file
View File

@ -0,0 +1,25 @@
package config
// OAuth2Config OAuth 2.0 配置
type OAuth2Config struct {
ClientID string
ClientSecret string
RedirectURI string
Scopes []string
}
// OAuth2Configs 支持的 OAuth 2.0 提供商配置
var OAuth2Configs = map[string]OAuth2Config{
"github": {
ClientID: "your_github_client_id",
ClientSecret: "your_github_client_secret",
RedirectURI: "http://localhost:8999/oauth/callback/github",
Scopes: []string{"user:email"},
},
"google": {
ClientID: "your_google_client_id",
ClientSecret: "your_google_client_secret",
RedirectURI: "http://localhost:8999/oauth/callback/google",
Scopes: []string{"profile", "email"},
},
}

209
controllers/oauth.go Normal file
View File

@ -0,0 +1,209 @@
package controllers
import (
"crypto/rand"
"encoding/base64"
"net/http"
"time"
"ly-user-center/models"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
)
// OAuthAuthorizeInput 授权请求参数
type OAuthAuthorizeInput struct {
ResponseType string `form:"response_type" binding:"required,oneof=code token"`
ClientID string `form:"client_id" binding:"required"`
RedirectURI string `form:"redirect_uri" binding:"required,url"`
Scope string `form:"scope"`
State string `form:"state"`
}
// OAuthTokenInput 令牌请求参数
type OAuthTokenInput struct {
GrantType string `form:"grant_type" binding:"required,oneof=authorization_code refresh_token client_credentials"`
Code string `form:"code"`
RefreshToken string `form:"refresh_token"`
ClientID string `form:"client_id" binding:"required"`
ClientSecret string `form:"client_secret" binding:"required"`
RedirectURI string `form:"redirect_uri"`
}
// Authorize OAuth 2.0 授权端点
func Authorize(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
var input OAuthAuthorizeInput
if err := c.ShouldBindQuery(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
return
}
// 验证客户端
var client models.OAuthClient
if err := db.Where("client_id = ?", input.ClientID).First(&client).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"})
return
}
// 生成授权码
code := generateRandomString(32)
authCode := models.OAuthAuthorizationCode{
UserID: c.GetString("user_id"),
ClientID: client.ID,
Code: code,
Scope: input.Scope,
ExpiresAt: time.Now().Add(10 * time.Minute),
}
if err := db.Create(&authCode).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
return
}
// 重定向回客户端
redirectURI := input.RedirectURI + "?code=" + code
if input.State != "" {
redirectURI += "&state=" + input.State
}
c.Redirect(http.StatusFound, redirectURI)
}
}
// Token OAuth 2.0 令牌端点
func Token(db *gorm.DB) gin.HandlerFunc {
return func(c *gin.Context) {
var input OAuthTokenInput
if err := c.ShouldBind(&input); err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
return
}
switch input.GrantType {
case "authorization_code":
handleAuthorizationCodeGrant(c, db, input)
case "refresh_token":
handleRefreshTokenGrant(c, db, input)
case "client_credentials":
handleClientCredentialsGrant(c, db, input)
default:
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
}
}
}
// 生成随机字符串
func generateRandomString(length int) string {
b := make([]byte, length)
rand.Read(b)
return base64.URLEncoding.EncodeToString(b)[:length]
}
// 处理授权码授权方式
func handleAuthorizationCodeGrant(c *gin.Context, db *gorm.DB, input OAuthTokenInput) {
var authCode models.OAuthAuthorizationCode
if err := db.Where("code = ? AND client_id = ?", input.Code, input.ClientID).
Where("expires_at > ?", time.Now()).
First(&authCode).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
return
}
// 生成访问令牌
accessToken := generateRandomString(32)
refreshToken := generateRandomString(32)
token := models.OAuthAccessToken{
UserID: authCode.UserID,
ClientID: authCode.ClientID,
Token: accessToken,
Scope: authCode.Scope,
ExpiresAt: time.Now().Add(24 * time.Hour),
RefreshToken: refreshToken,
}
if err := db.Create(&token).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
return
}
// 删除已使用的授权码
db.Delete(&authCode)
c.JSON(http.StatusOK, gin.H{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": 86400,
"refresh_token": refreshToken,
"scope": authCode.Scope,
})
}
func handleClientCredentialsGrant(c *gin.Context, db *gorm.DB, input OAuthTokenInput) {
// 验证客户端凭据
var client models.OAuthClient
if err := db.Where("client_id = ? AND client_secret = ?",
input.ClientID, input.ClientSecret).First(&client).Error; err != nil {
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_client"})
return
}
// 生成访问令牌
accessToken := generateRandomString(32)
token := models.OAuthAccessToken{
ClientID: client.ID,
Token: accessToken,
Scope: input.Scope,
ExpiresAt: time.Now().Add(24 * time.Hour),
}
if err := db.Create(&token).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
return
}
c.JSON(http.StatusOK, gin.H{
"access_token": accessToken,
"token_type": "Bearer",
"expires_in": 86400,
"scope": input.Scope,
})
}
func handleRefreshTokenGrant(c *gin.Context, db *gorm.DB, input OAuthTokenInput) {
var oldToken models.OAuthAccessToken
if err := db.Where("refresh_token = ?", input.RefreshToken).First(&oldToken).Error; err != nil {
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
return
}
// 生成新令牌
newAccessToken := generateRandomString(32)
newRefreshToken := generateRandomString(32)
token := models.OAuthAccessToken{
UserID: oldToken.UserID,
ClientID: oldToken.ClientID,
Token: newAccessToken,
Scope: oldToken.Scope,
ExpiresAt: time.Now().Add(24 * time.Hour),
RefreshToken: newRefreshToken,
}
if err := db.Create(&token).Error; err != nil {
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
return
}
// 删除旧令牌
db.Delete(&oldToken)
c.JSON(http.StatusOK, gin.H{
"access_token": newAccessToken,
"token_type": "Bearer",
"expires_in": 86400,
"refresh_token": newRefreshToken,
"scope": oldToken.Scope,
})
}

View File

@ -2,8 +2,8 @@ package controllers
import (
"fmt"
"ly.user_center/models"
"ly.user_center/utils"
"ly-user-center/models"
"ly-user-center/utils"
"net/http"
"github.com/gin-gonic/gin"

11
go.mod
View File

@ -1,12 +1,15 @@
module ly.user_center
module ly-user-center
go 1.23.2
go 1.22
toolchain go1.23.2
require (
github.com/gin-gonic/gin v1.10.0
github.com/golang-jwt/jwt/v4 v4.5.1
github.com/google/uuid v1.6.0
golang.org/x/crypto v0.32.0
google.golang.org/grpc v1.62.1
gopkg.in/yaml.v3 v3.0.1
gorm.io/driver/postgres v1.5.11
gorm.io/gorm v1.25.12
@ -22,6 +25,7 @@ require (
github.com/go-playground/universal-translator v0.18.1 // indirect
github.com/go-playground/validator/v10 v10.23.0 // indirect
github.com/goccy/go-json v0.10.4 // indirect
github.com/golang/protobuf v1.5.3 // indirect
github.com/jackc/pgpassfile v1.0.0 // indirect
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
github.com/jackc/pgx/v5 v5.5.5 // indirect
@ -44,5 +48,6 @@ require (
golang.org/x/sync v0.10.0 // indirect
golang.org/x/sys v0.29.0 // indirect
golang.org/x/text v0.21.0 // indirect
google.golang.org/protobuf v1.36.2 // indirect
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 // indirect
google.golang.org/protobuf v1.34.1 // indirect
)

17
go.sum
View File

@ -28,8 +28,12 @@ github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM=
github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
@ -99,10 +103,15 @@ golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
google.golang.org/protobuf v1.36.2 h1:R8FeyR1/eLmkutZOM5CWghmo5itiG9z0ktFlTVLuTmU=
google.golang.org/protobuf v1.36.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 h1:AjyfHzEPEFp/NpvfN5g+KDla3EMojjhRVZc1i7cj+oM=
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80/go.mod h1:PAREbraiVEVGVdTZsVWjSbbTtSyGbAgIIvni8a8CD5s=
google.golang.org/grpc v1.62.1 h1:B4n+nfKzOICUXMgyrNd19h/I9oH0L1pizfk1d4zSgTk=
google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE=
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=

Binary file not shown.

33
main.go
View File

@ -2,10 +2,14 @@ package main
import (
"fmt"
"ly.user_center/config"
"ly.user_center/models"
"ly.user_center/routes"
"net"
"ly-user-center/config"
"ly-user-center/models"
"ly-user-center/services"
"google.golang.org/grpc"
"gorm.io/driver/postgres"
"gorm.io/gorm"
)
@ -14,8 +18,6 @@ func main() {
// 加载配置文件
dsn := config.LoadDBConfig()
fmt.Println("dsn======", dsn)
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
if err != nil {
panic("数据库初始化连接失败: " + err.Error())
@ -23,10 +25,23 @@ func main() {
// 自动迁移
models.MigrateUser(db)
models.MigrateOAuth(db)
// 初始化路由
router := routes.SetupRoutes(db)
// 创建 gRPC 服务器
grpcServer := grpc.NewServer()
// 启动服务器
router.Run(":8999")
// 注册用户服务
userService := services.NewUserService(db)
pb.RegisterUserServiceServer(grpcServer, userService)
// 启动 gRPC 服务器
lis, err := net.Listen("tcp", ":50051")
if err != nil {
panic(fmt.Sprintf("failed to listen: %v", err))
}
fmt.Println("gRPC server is running on :50051")
if err := grpcServer.Serve(lis); err != nil {
panic(fmt.Sprintf("failed to serve: %v", err))
}
}

View File

@ -1,7 +1,7 @@
package middlewares
import (
"ly.user_center/utils"
"ly-user-center/utils"
"net/http"
"github.com/gin-gonic/gin"

35
middlewares/recovery.go Normal file
View File

@ -0,0 +1,35 @@
package middlewares
import (
"fmt"
"net/http"
"runtime/debug"
"github.com/gin-gonic/gin"
)
// Recovery 全局异常处理中间件
func Recovery() gin.HandlerFunc {
return func(c *gin.Context) {
defer func() {
if err := recover(); err != nil {
// 打印错误堆栈信息
fmt.Printf("panic: %v\n", err)
fmt.Printf("stack: %s\n", debug.Stack())
// 返回 500 错误
c.JSON(http.StatusInternalServerError, gin.H{
"code": http.StatusInternalServerError,
"message": "服务器内部错误",
"error": fmt.Sprint(err),
})
// 终止后续中间件
c.Abort()
}
}()
// 继续处理请求
c.Next()
}
}

50
models/oauth.go Normal file
View File

@ -0,0 +1,50 @@
package models
import (
"time"
"gorm.io/gorm"
)
// OAuthClient OAuth 2.0 客户端
type OAuthClient struct {
ID string `gorm:"primaryKey"`
ClientID string `gorm:"unique;not null"`
ClientSecret string `gorm:"not null"`
RedirectURIs string `gorm:"not null"` // 以逗号分隔的重定向URI列表
Scopes string `gorm:"not null"` // 以逗号分隔的授权范围列表
CreatedAt time.Time
UpdatedAt time.Time
}
// OAuthAccessToken 访问令牌
type OAuthAccessToken struct {
ID string `gorm:"primaryKey"`
UserID string `gorm:"not null"`
ClientID string `gorm:"not null"`
Token string `gorm:"unique;not null"`
Scope string
ExpiresAt time.Time
RefreshToken string `gorm:"unique"`
CreatedAt time.Time
UpdatedAt time.Time
}
// OAuthAuthorizationCode 授权码
type OAuthAuthorizationCode struct {
ID string `gorm:"primaryKey"`
UserID string `gorm:"not null"`
ClientID string `gorm:"not null"`
Code string `gorm:"unique;not null"`
Scope string
ExpiresAt time.Time
CreatedAt time.Time
}
// MigrateOAuth 初始化 OAuth 相关表
func MigrateOAuth(db *gorm.DB) {
err := db.AutoMigrate(&OAuthClient{}, &OAuthAccessToken{}, &OAuthAuthorizationCode{})
if err != nil {
panic("failed to migrate OAuth models: " + err.Error())
}
}

61
proto/user.proto Normal file
View File

@ -0,0 +1,61 @@
syntax = "proto3";
package pb;
option go_package = "ly-user-center/proto/pb";
//
service UserService {
//
rpc Register(RegisterRequest) returns (RegisterResponse) {}
//
rpc Login(LoginRequest) returns (LoginResponse) {}
//
rpc RefreshToken(RefreshTokenRequest) returns (RefreshTokenResponse) {}
}
//
enum Gender {
FEMALE = 0;
MALE = 1;
}
//
message RegisterRequest {
string username = 1;
string password = 2;
string email = 3;
string phone = 4;
Gender sex = 5;
}
//
message RegisterResponse {
bool success = 1;
string message = 2;
}
//
message LoginRequest {
string username = 1;
string password = 2;
}
//
message LoginResponse {
string token = 1;
int64 expiration_time = 2;
string message = 3;
}
//
message RefreshTokenRequest {
string token = 1;
}
//
message RefreshTokenResponse {
string token = 1;
int64 expiration_time = 2;
string message = 3;
}

View File

@ -1,8 +1,8 @@
package routes
import (
"ly.user_center/controllers"
"ly.user_center/middlewares"
"ly-user-center/controllers"
"ly-user-center/middlewares"
"github.com/gin-gonic/gin"
"gorm.io/gorm"
@ -11,28 +11,25 @@ import (
func SetupRoutes(db *gorm.DB) *gin.Engine {
router := gin.Default()
// 公共路由组
publicRouter := router.Group("/users")
// 添加全局异常处理中间件
router.Use(middlewares.Recovery())
// OAuth 2.0 端点
oauth := router.Group("/oauth")
{
// 用户注册和登录
publicRouter.POST("/register", controllers.Register(db))
publicRouter.POST("/login", controllers.Login(db))
publicRouter.POST("/refreshToken", controllers.RefreshToken(db))
oauth.GET("/authorize", middlewares.AuthMiddleware(), controllers.Authorize(db))
oauth.POST("/token", controllers.Token(db))
}
// 受保护的路由组
protected := router.Group("/api")
protected.Use(middlewares.AuthMiddleware())
// 注册路由
router.POST("/register", controllers.Register(db))
router.POST("/login", controllers.Login(db))
// 需要认证的路由组
authorized := router.Group("/")
authorized.Use(middlewares.AuthMiddleware())
{
protected.GET("/profile", func(c *gin.Context) {
// 从中间件获取用户ID
userID, exists := c.Get("user_id")
if !exists {
c.JSON(401, gin.H{"error": "未授权"})
return
}
c.JSON(200, gin.H{"message": "访问成功", "user_id": userID})
})
authorized.POST("/refresh-token", controllers.RefreshToken(db))
}
return router

124
services/user_service.go Normal file
View File

@ -0,0 +1,124 @@
package services
import (
"context"
"ly-user-center/models"
"ly-user-center/proto/pb"
"ly-user-center/utils"
"golang.org/x/crypto/bcrypt"
"google.golang.org/grpc/codes"
"google.golang.org/grpc/status"
"gorm.io/gorm"
)
type UserService struct {
pb.UnimplementedUserServiceServer
db *gorm.DB
}
func NewUserService(db *gorm.DB) *UserService {
return &UserService{db: db}
}
// Register 实现用户注册
func (s *UserService) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
// 开始事务
tx := s.db.Begin()
defer func() {
if r := recover(); r != nil {
tx.Rollback()
}
}()
// 检查用户名、邮箱和手机号是否重复
var existingUser models.User
if err := tx.Where("user_name = ? OR email = ? OR phone = ?",
req.Username, req.Email, req.Phone).First(&existingUser).Error; err == nil {
tx.Rollback()
return nil, status.Error(codes.AlreadyExists, "用户名、邮箱或手机号已存在")
}
// 密码哈希
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
if err != nil {
tx.Rollback()
return nil, status.Error(codes.Internal, "密码加密失败")
}
// 创建用户
user := models.User{
Username: req.Username,
Password: string(hashedPassword),
Email: req.Email,
Phone: req.Phone,
Sex: models.Gender(req.Sex),
}
if err := tx.Create(&user).Error; err != nil {
tx.Rollback()
return nil, status.Error(codes.Internal, "创建用户失败")
}
// 提交事务
if err := tx.Commit().Error; err != nil {
return nil, status.Error(codes.Internal, "提交事务失败")
}
return &pb.RegisterResponse{
Success: true,
Message: "用户创建成功",
}, nil
}
// Login 实现用户登录
func (s *UserService) Login(ctx context.Context, req *pb.LoginRequest) (*pb.LoginResponse, error) {
var user models.User
if err := s.db.Where("user_name = ?", req.Username).First(&user).Error; err != nil {
return nil, status.Error(codes.NotFound, "用户名或密码错误")
}
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
return nil, status.Error(codes.Unauthenticated, "用户名或密码错误")
}
// 生成 JWT
token, expirationTime, err := utils.GenerateToken(user.ID)
if err != nil {
return nil, status.Error(codes.Internal, "生成令牌失败")
}
return &pb.LoginResponse{
Token: token,
ExpirationTime: expirationTime,
Message: "登录成功",
}, nil
}
// RefreshToken 实现令牌刷新
func (s *UserService) RefreshToken(ctx context.Context, req *pb.RefreshTokenRequest) (*pb.RefreshTokenResponse, error) {
// 验证旧令牌
claims, err := utils.ParseToken(req.Token)
if err != nil {
return nil, status.Error(codes.Unauthenticated, "令牌无效或已过期")
}
// 查询用户是否存在
var user models.User
if err := s.db.Where("id = ?", claims.UserID).First(&user).Error; err != nil {
return nil, status.Error(codes.NotFound, "用户不存在")
}
// 生成新令牌
newToken, newExpirationTime, err := utils.GenerateToken(user.ID)
if err != nil {
return nil, status.Error(codes.Internal, "生成新令牌失败")
}
return &pb.RefreshTokenResponse{
Token: newToken,
ExpirationTime: newExpirationTime,
Message: "令牌刷新成功",
}, nil
}