实现 Grpc 及 Auto2.0
This commit is contained in:
parent
f738088a9e
commit
bc740d6e92
45
client/main.go
Normal file
45
client/main.go
Normal file
@ -0,0 +1,45 @@
|
||||
package main
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
"log"
|
||||
|
||||
"ly-user-center/proto/pb"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"google.golang.org/grpc/credentials/insecure"
|
||||
)
|
||||
|
||||
func main() {
|
||||
conn, err := grpc.Dial("localhost:50051", grpc.WithTransportCredentials(insecure.NewCredentials()))
|
||||
if err != nil {
|
||||
log.Fatalf("did not connect: %v", err)
|
||||
}
|
||||
defer conn.Close()
|
||||
|
||||
client := pb.NewUserServiceClient(conn)
|
||||
|
||||
// 测试注册
|
||||
registerResp, err := client.Register(context.Background(), &pb.RegisterRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
Email: "test@example.com",
|
||||
Phone: "1234567890",
|
||||
Sex: pb.Gender_MALE,
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Register failed: %v", err)
|
||||
}
|
||||
fmt.Printf("Register response: %v\n", registerResp)
|
||||
|
||||
// 测试登录
|
||||
loginResp, err := client.Login(context.Background(), &pb.LoginRequest{
|
||||
Username: "testuser",
|
||||
Password: "password123",
|
||||
})
|
||||
if err != nil {
|
||||
log.Fatalf("Login failed: %v", err)
|
||||
}
|
||||
fmt.Printf("Login response: %v\n", loginResp)
|
||||
}
|
25
config/oauth.go
Normal file
25
config/oauth.go
Normal file
@ -0,0 +1,25 @@
|
||||
package config
|
||||
|
||||
// OAuth2Config OAuth 2.0 配置
|
||||
type OAuth2Config struct {
|
||||
ClientID string
|
||||
ClientSecret string
|
||||
RedirectURI string
|
||||
Scopes []string
|
||||
}
|
||||
|
||||
// OAuth2Configs 支持的 OAuth 2.0 提供商配置
|
||||
var OAuth2Configs = map[string]OAuth2Config{
|
||||
"github": {
|
||||
ClientID: "your_github_client_id",
|
||||
ClientSecret: "your_github_client_secret",
|
||||
RedirectURI: "http://localhost:8999/oauth/callback/github",
|
||||
Scopes: []string{"user:email"},
|
||||
},
|
||||
"google": {
|
||||
ClientID: "your_google_client_id",
|
||||
ClientSecret: "your_google_client_secret",
|
||||
RedirectURI: "http://localhost:8999/oauth/callback/google",
|
||||
Scopes: []string{"profile", "email"},
|
||||
},
|
||||
}
|
209
controllers/oauth.go
Normal file
209
controllers/oauth.go
Normal file
@ -0,0 +1,209 @@
|
||||
package controllers
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"net/http"
|
||||
"time"
|
||||
|
||||
"ly-user-center/models"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OAuthAuthorizeInput 授权请求参数
|
||||
type OAuthAuthorizeInput struct {
|
||||
ResponseType string `form:"response_type" binding:"required,oneof=code token"`
|
||||
ClientID string `form:"client_id" binding:"required"`
|
||||
RedirectURI string `form:"redirect_uri" binding:"required,url"`
|
||||
Scope string `form:"scope"`
|
||||
State string `form:"state"`
|
||||
}
|
||||
|
||||
// OAuthTokenInput 令牌请求参数
|
||||
type OAuthTokenInput struct {
|
||||
GrantType string `form:"grant_type" binding:"required,oneof=authorization_code refresh_token client_credentials"`
|
||||
Code string `form:"code"`
|
||||
RefreshToken string `form:"refresh_token"`
|
||||
ClientID string `form:"client_id" binding:"required"`
|
||||
ClientSecret string `form:"client_secret" binding:"required"`
|
||||
RedirectURI string `form:"redirect_uri"`
|
||||
}
|
||||
|
||||
// Authorize OAuth 2.0 授权端点
|
||||
func Authorize(db *gorm.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var input OAuthAuthorizeInput
|
||||
if err := c.ShouldBindQuery(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
|
||||
// 验证客户端
|
||||
var client models.OAuthClient
|
||||
if err := db.Where("client_id = ?", input.ClientID).First(&client).Error; err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_client"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成授权码
|
||||
code := generateRandomString(32)
|
||||
authCode := models.OAuthAuthorizationCode{
|
||||
UserID: c.GetString("user_id"),
|
||||
ClientID: client.ID,
|
||||
Code: code,
|
||||
Scope: input.Scope,
|
||||
ExpiresAt: time.Now().Add(10 * time.Minute),
|
||||
}
|
||||
|
||||
if err := db.Create(&authCode).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 重定向回客户端
|
||||
redirectURI := input.RedirectURI + "?code=" + code
|
||||
if input.State != "" {
|
||||
redirectURI += "&state=" + input.State
|
||||
}
|
||||
c.Redirect(http.StatusFound, redirectURI)
|
||||
}
|
||||
}
|
||||
|
||||
// Token OAuth 2.0 令牌端点
|
||||
func Token(db *gorm.DB) gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
var input OAuthTokenInput
|
||||
if err := c.ShouldBind(&input); err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_request"})
|
||||
return
|
||||
}
|
||||
|
||||
switch input.GrantType {
|
||||
case "authorization_code":
|
||||
handleAuthorizationCodeGrant(c, db, input)
|
||||
case "refresh_token":
|
||||
handleRefreshTokenGrant(c, db, input)
|
||||
case "client_credentials":
|
||||
handleClientCredentialsGrant(c, db, input)
|
||||
default:
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "unsupported_grant_type"})
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// 生成随机字符串
|
||||
func generateRandomString(length int) string {
|
||||
b := make([]byte, length)
|
||||
rand.Read(b)
|
||||
return base64.URLEncoding.EncodeToString(b)[:length]
|
||||
}
|
||||
|
||||
// 处理授权码授权方式
|
||||
func handleAuthorizationCodeGrant(c *gin.Context, db *gorm.DB, input OAuthTokenInput) {
|
||||
var authCode models.OAuthAuthorizationCode
|
||||
if err := db.Where("code = ? AND client_id = ?", input.Code, input.ClientID).
|
||||
Where("expires_at > ?", time.Now()).
|
||||
First(&authCode).Error; err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成访问令牌
|
||||
accessToken := generateRandomString(32)
|
||||
refreshToken := generateRandomString(32)
|
||||
|
||||
token := models.OAuthAccessToken{
|
||||
UserID: authCode.UserID,
|
||||
ClientID: authCode.ClientID,
|
||||
Token: accessToken,
|
||||
Scope: authCode.Scope,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
RefreshToken: refreshToken,
|
||||
}
|
||||
|
||||
if err := db.Create(&token).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 删除已使用的授权码
|
||||
db.Delete(&authCode)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
"refresh_token": refreshToken,
|
||||
"scope": authCode.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
func handleClientCredentialsGrant(c *gin.Context, db *gorm.DB, input OAuthTokenInput) {
|
||||
// 验证客户端凭据
|
||||
var client models.OAuthClient
|
||||
if err := db.Where("client_id = ? AND client_secret = ?",
|
||||
input.ClientID, input.ClientSecret).First(&client).Error; err != nil {
|
||||
c.JSON(http.StatusUnauthorized, gin.H{"error": "invalid_client"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成访问令牌
|
||||
accessToken := generateRandomString(32)
|
||||
token := models.OAuthAccessToken{
|
||||
ClientID: client.ID,
|
||||
Token: accessToken,
|
||||
Scope: input.Scope,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
}
|
||||
|
||||
if err := db.Create(&token).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": accessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
"scope": input.Scope,
|
||||
})
|
||||
}
|
||||
|
||||
func handleRefreshTokenGrant(c *gin.Context, db *gorm.DB, input OAuthTokenInput) {
|
||||
var oldToken models.OAuthAccessToken
|
||||
if err := db.Where("refresh_token = ?", input.RefreshToken).First(&oldToken).Error; err != nil {
|
||||
c.JSON(http.StatusBadRequest, gin.H{"error": "invalid_grant"})
|
||||
return
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newAccessToken := generateRandomString(32)
|
||||
newRefreshToken := generateRandomString(32)
|
||||
|
||||
token := models.OAuthAccessToken{
|
||||
UserID: oldToken.UserID,
|
||||
ClientID: oldToken.ClientID,
|
||||
Token: newAccessToken,
|
||||
Scope: oldToken.Scope,
|
||||
ExpiresAt: time.Now().Add(24 * time.Hour),
|
||||
RefreshToken: newRefreshToken,
|
||||
}
|
||||
|
||||
if err := db.Create(&token).Error; err != nil {
|
||||
c.JSON(http.StatusInternalServerError, gin.H{"error": "server_error"})
|
||||
return
|
||||
}
|
||||
|
||||
// 删除旧令牌
|
||||
db.Delete(&oldToken)
|
||||
|
||||
c.JSON(http.StatusOK, gin.H{
|
||||
"access_token": newAccessToken,
|
||||
"token_type": "Bearer",
|
||||
"expires_in": 86400,
|
||||
"refresh_token": newRefreshToken,
|
||||
"scope": oldToken.Scope,
|
||||
})
|
||||
}
|
@ -2,8 +2,8 @@ package controllers
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ly.user_center/models"
|
||||
"ly.user_center/utils"
|
||||
"ly-user-center/models"
|
||||
"ly-user-center/utils"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
11
go.mod
11
go.mod
@ -1,12 +1,15 @@
|
||||
module ly.user_center
|
||||
module ly-user-center
|
||||
|
||||
go 1.23.2
|
||||
go 1.22
|
||||
|
||||
toolchain go1.23.2
|
||||
|
||||
require (
|
||||
github.com/gin-gonic/gin v1.10.0
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1
|
||||
github.com/google/uuid v1.6.0
|
||||
golang.org/x/crypto v0.32.0
|
||||
google.golang.org/grpc v1.62.1
|
||||
gopkg.in/yaml.v3 v3.0.1
|
||||
gorm.io/driver/postgres v1.5.11
|
||||
gorm.io/gorm v1.25.12
|
||||
@ -22,6 +25,7 @@ require (
|
||||
github.com/go-playground/universal-translator v0.18.1 // indirect
|
||||
github.com/go-playground/validator/v10 v10.23.0 // indirect
|
||||
github.com/goccy/go-json v0.10.4 // indirect
|
||||
github.com/golang/protobuf v1.5.3 // indirect
|
||||
github.com/jackc/pgpassfile v1.0.0 // indirect
|
||||
github.com/jackc/pgservicefile v0.0.0-20221227161230-091c0ba34f0a // indirect
|
||||
github.com/jackc/pgx/v5 v5.5.5 // indirect
|
||||
@ -44,5 +48,6 @@ require (
|
||||
golang.org/x/sync v0.10.0 // indirect
|
||||
golang.org/x/sys v0.29.0 // indirect
|
||||
golang.org/x/text v0.21.0 // indirect
|
||||
google.golang.org/protobuf v1.36.2 // indirect
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 // indirect
|
||||
google.golang.org/protobuf v1.34.1 // indirect
|
||||
)
|
||||
|
17
go.sum
17
go.sum
@ -28,8 +28,12 @@ github.com/goccy/go-json v0.10.4 h1:JSwxQzIqKfmFX1swYPpUThQZp/Ka4wzJdK0LWVytLPM=
|
||||
github.com/goccy/go-json v0.10.4/go.mod h1:oq7eo15ShAhp70Anwd5lgX2pLfOS3QCiwU/PULtXL6M=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1 h1:JdqV9zKUdtaa9gdPlywC3aeoEsR681PlKC+4F5gQgeo=
|
||||
github.com/golang-jwt/jwt/v4 v4.5.1/go.mod h1:m21LjoU+eqJr34lmDMbreY2eSTRJ1cv77w39/MY0Ch0=
|
||||
github.com/google/go-cmp v0.5.5 h1:Khx7svrCpmxxtHBq5j2mp/xVjsi8hQMfNLvJFAlrGgU=
|
||||
github.com/golang/protobuf v1.5.0/go.mod h1:FsONVRAS9T7sI+LIUmWTfcYkHO4aIWwzhcaSAoJOfIk=
|
||||
github.com/golang/protobuf v1.5.3 h1:KhyjKVUg7Usr/dYsdSqoFveMYd5ko72D+zANwlG1mmg=
|
||||
github.com/golang/protobuf v1.5.3/go.mod h1:XVQd3VNwM+JqD3oG2Ue2ip4fOMUkwXdXDdiuN0vRsmY=
|
||||
github.com/google/go-cmp v0.5.5/go.mod h1:v8dTdLbMG2kIc/vJvl+f65V22dbkXbowE6jgT/gNBxE=
|
||||
github.com/google/go-cmp v0.6.0 h1:ofyhxvXcZhMsU5ulbFiLKl/XBFqE1GSq7atu8tAmTRI=
|
||||
github.com/google/go-cmp v0.6.0/go.mod h1:17dUlkBOakJ0+DkrSSNjCkIjxS6bF9zb3elmeNGIjoY=
|
||||
github.com/google/gofuzz v1.0.0/go.mod h1:dBl0BpW6vV/+mYPU4Po3pmUjxk6FQPldtuIdl/M65Eg=
|
||||
github.com/google/uuid v1.6.0 h1:NIvaJDMOsjHA8n1jAhLSgzrAzy1Hgr+hNrb57e+94F0=
|
||||
github.com/google/uuid v1.6.0/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo=
|
||||
@ -99,10 +103,15 @@ golang.org/x/sys v0.29.0 h1:TPYlXGxvx1MGTn2GiZDhnjPA9wZzZeGKHHmKhHYvgaU=
|
||||
golang.org/x/sys v0.29.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
|
||||
golang.org/x/text v0.21.0 h1:zyQAAkrwaneQ066sspRyJaG9VNi/YJ1NfzcGB3hZ/qo=
|
||||
golang.org/x/text v0.21.0/go.mod h1:4IBbMaMmOPCJ8SecivzSH54+73PCFmPWxNTLm+vZkEQ=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543 h1:E7g+9GITq07hpfrRu66IVDexMakfv52eLZ2CXBWiKr4=
|
||||
golang.org/x/xerrors v0.0.0-20191204190536-9bdfabe68543/go.mod h1:I/5z698sn9Ka8TeJc9MKroUUfqBBauWjQqLJ2OPfmY0=
|
||||
google.golang.org/protobuf v1.36.2 h1:R8FeyR1/eLmkutZOM5CWghmo5itiG9z0ktFlTVLuTmU=
|
||||
google.golang.org/protobuf v1.36.2/go.mod h1:9fA7Ob0pmnwhb644+1+CVWFRbNajQ6iRojtC/QF5bRE=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80 h1:AjyfHzEPEFp/NpvfN5g+KDla3EMojjhRVZc1i7cj+oM=
|
||||
google.golang.org/genproto/googleapis/rpc v0.0.0-20240123012728-ef4313101c80/go.mod h1:PAREbraiVEVGVdTZsVWjSbbTtSyGbAgIIvni8a8CD5s=
|
||||
google.golang.org/grpc v1.62.1 h1:B4n+nfKzOICUXMgyrNd19h/I9oH0L1pizfk1d4zSgTk=
|
||||
google.golang.org/grpc v1.62.1/go.mod h1:IWTG0VlJLCh1SkC58F7np9ka9mx/WNkjl4PGJaiq+QE=
|
||||
google.golang.org/protobuf v1.26.0-rc.1/go.mod h1:jlhhOSvTdKEhbULTjvd4ARK9grFBp09yW+WbY/TyQbw=
|
||||
google.golang.org/protobuf v1.26.0/go.mod h1:9q0QmTI4eRPtz6boOQmLYwt+qCgq0jsYwAQnmE0givc=
|
||||
google.golang.org/protobuf v1.34.1 h1:9ddQBjfCyZPOHPUiPxpYESBLc+T8P3E+Vo4IbKZgFWg=
|
||||
google.golang.org/protobuf v1.34.1/go.mod h1:c6P6GXX6sHbq/GpV6MGZEdwhWPcYBgnhAHhKbcUYpos=
|
||||
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c h1:Hei/4ADfdWqJk1ZMxUNpqntNwaWcugrBjAiHlqqRiVk=
|
||||
gopkg.in/check.v1 v1.0.0-20201130134442-10cb98267c6c/go.mod h1:JHkPIbrfpd72SG/EVd6muEfDQjcINNoR0C8j2r3qZ4Q=
|
||||
|
Binary file not shown.
33
main.go
33
main.go
@ -2,10 +2,14 @@ package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"ly.user_center/config"
|
||||
"ly.user_center/models"
|
||||
"ly.user_center/routes"
|
||||
"net"
|
||||
|
||||
"ly-user-center/config"
|
||||
"ly-user-center/models"
|
||||
|
||||
"ly-user-center/services"
|
||||
|
||||
"google.golang.org/grpc"
|
||||
"gorm.io/driver/postgres"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
@ -14,8 +18,6 @@ func main() {
|
||||
// 加载配置文件
|
||||
dsn := config.LoadDBConfig()
|
||||
|
||||
fmt.Println("dsn======", dsn)
|
||||
|
||||
db, err := gorm.Open(postgres.Open(dsn), &gorm.Config{})
|
||||
if err != nil {
|
||||
panic("数据库初始化连接失败: " + err.Error())
|
||||
@ -23,10 +25,23 @@ func main() {
|
||||
|
||||
// 自动迁移
|
||||
models.MigrateUser(db)
|
||||
models.MigrateOAuth(db)
|
||||
|
||||
// 初始化路由
|
||||
router := routes.SetupRoutes(db)
|
||||
// 创建 gRPC 服务器
|
||||
grpcServer := grpc.NewServer()
|
||||
|
||||
// 启动服务器
|
||||
router.Run(":8999")
|
||||
// 注册用户服务
|
||||
userService := services.NewUserService(db)
|
||||
pb.RegisterUserServiceServer(grpcServer, userService)
|
||||
|
||||
// 启动 gRPC 服务器
|
||||
lis, err := net.Listen("tcp", ":50051")
|
||||
if err != nil {
|
||||
panic(fmt.Sprintf("failed to listen: %v", err))
|
||||
}
|
||||
|
||||
fmt.Println("gRPC server is running on :50051")
|
||||
if err := grpcServer.Serve(lis); err != nil {
|
||||
panic(fmt.Sprintf("failed to serve: %v", err))
|
||||
}
|
||||
}
|
||||
|
@ -1,7 +1,7 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"ly.user_center/utils"
|
||||
"ly-user-center/utils"
|
||||
"net/http"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
|
35
middlewares/recovery.go
Normal file
35
middlewares/recovery.go
Normal file
@ -0,0 +1,35 @@
|
||||
package middlewares
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/http"
|
||||
"runtime/debug"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
)
|
||||
|
||||
// Recovery 全局异常处理中间件
|
||||
func Recovery() gin.HandlerFunc {
|
||||
return func(c *gin.Context) {
|
||||
defer func() {
|
||||
if err := recover(); err != nil {
|
||||
// 打印错误堆栈信息
|
||||
fmt.Printf("panic: %v\n", err)
|
||||
fmt.Printf("stack: %s\n", debug.Stack())
|
||||
|
||||
// 返回 500 错误
|
||||
c.JSON(http.StatusInternalServerError, gin.H{
|
||||
"code": http.StatusInternalServerError,
|
||||
"message": "服务器内部错误",
|
||||
"error": fmt.Sprint(err),
|
||||
})
|
||||
|
||||
// 终止后续中间件
|
||||
c.Abort()
|
||||
}
|
||||
}()
|
||||
|
||||
// 继续处理请求
|
||||
c.Next()
|
||||
}
|
||||
}
|
50
models/oauth.go
Normal file
50
models/oauth.go
Normal file
@ -0,0 +1,50 @@
|
||||
package models
|
||||
|
||||
import (
|
||||
"time"
|
||||
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
// OAuthClient OAuth 2.0 客户端
|
||||
type OAuthClient struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
ClientID string `gorm:"unique;not null"`
|
||||
ClientSecret string `gorm:"not null"`
|
||||
RedirectURIs string `gorm:"not null"` // 以逗号分隔的重定向URI列表
|
||||
Scopes string `gorm:"not null"` // 以逗号分隔的授权范围列表
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// OAuthAccessToken 访问令牌
|
||||
type OAuthAccessToken struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
UserID string `gorm:"not null"`
|
||||
ClientID string `gorm:"not null"`
|
||||
Token string `gorm:"unique;not null"`
|
||||
Scope string
|
||||
ExpiresAt time.Time
|
||||
RefreshToken string `gorm:"unique"`
|
||||
CreatedAt time.Time
|
||||
UpdatedAt time.Time
|
||||
}
|
||||
|
||||
// OAuthAuthorizationCode 授权码
|
||||
type OAuthAuthorizationCode struct {
|
||||
ID string `gorm:"primaryKey"`
|
||||
UserID string `gorm:"not null"`
|
||||
ClientID string `gorm:"not null"`
|
||||
Code string `gorm:"unique;not null"`
|
||||
Scope string
|
||||
ExpiresAt time.Time
|
||||
CreatedAt time.Time
|
||||
}
|
||||
|
||||
// MigrateOAuth 初始化 OAuth 相关表
|
||||
func MigrateOAuth(db *gorm.DB) {
|
||||
err := db.AutoMigrate(&OAuthClient{}, &OAuthAccessToken{}, &OAuthAuthorizationCode{})
|
||||
if err != nil {
|
||||
panic("failed to migrate OAuth models: " + err.Error())
|
||||
}
|
||||
}
|
61
proto/user.proto
Normal file
61
proto/user.proto
Normal file
@ -0,0 +1,61 @@
|
||||
syntax = "proto3";
|
||||
|
||||
package pb;
|
||||
|
||||
option go_package = "ly-user-center/proto/pb";
|
||||
|
||||
// 用户服务定义
|
||||
service UserService {
|
||||
// 用户注册
|
||||
rpc Register(RegisterRequest) returns (RegisterResponse) {}
|
||||
// 用户登录
|
||||
rpc Login(LoginRequest) returns (LoginResponse) {}
|
||||
// 刷新令牌
|
||||
rpc RefreshToken(RefreshTokenRequest) returns (RefreshTokenResponse) {}
|
||||
}
|
||||
|
||||
// 性别枚举
|
||||
enum Gender {
|
||||
FEMALE = 0;
|
||||
MALE = 1;
|
||||
}
|
||||
|
||||
// 注册请求
|
||||
message RegisterRequest {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
string email = 3;
|
||||
string phone = 4;
|
||||
Gender sex = 5;
|
||||
}
|
||||
|
||||
// 注册响应
|
||||
message RegisterResponse {
|
||||
bool success = 1;
|
||||
string message = 2;
|
||||
}
|
||||
|
||||
// 登录请求
|
||||
message LoginRequest {
|
||||
string username = 1;
|
||||
string password = 2;
|
||||
}
|
||||
|
||||
// 登录响应
|
||||
message LoginResponse {
|
||||
string token = 1;
|
||||
int64 expiration_time = 2;
|
||||
string message = 3;
|
||||
}
|
||||
|
||||
// 刷新令牌请求
|
||||
message RefreshTokenRequest {
|
||||
string token = 1;
|
||||
}
|
||||
|
||||
// 刷新令牌响应
|
||||
message RefreshTokenResponse {
|
||||
string token = 1;
|
||||
int64 expiration_time = 2;
|
||||
string message = 3;
|
||||
}
|
@ -1,8 +1,8 @@
|
||||
package routes
|
||||
|
||||
import (
|
||||
"ly.user_center/controllers"
|
||||
"ly.user_center/middlewares"
|
||||
"ly-user-center/controllers"
|
||||
"ly-user-center/middlewares"
|
||||
|
||||
"github.com/gin-gonic/gin"
|
||||
"gorm.io/gorm"
|
||||
@ -11,28 +11,25 @@ import (
|
||||
func SetupRoutes(db *gorm.DB) *gin.Engine {
|
||||
router := gin.Default()
|
||||
|
||||
// 公共路由组
|
||||
publicRouter := router.Group("/users")
|
||||
// 添加全局异常处理中间件
|
||||
router.Use(middlewares.Recovery())
|
||||
|
||||
// OAuth 2.0 端点
|
||||
oauth := router.Group("/oauth")
|
||||
{
|
||||
// 用户注册和登录
|
||||
publicRouter.POST("/register", controllers.Register(db))
|
||||
publicRouter.POST("/login", controllers.Login(db))
|
||||
publicRouter.POST("/refreshToken", controllers.RefreshToken(db))
|
||||
oauth.GET("/authorize", middlewares.AuthMiddleware(), controllers.Authorize(db))
|
||||
oauth.POST("/token", controllers.Token(db))
|
||||
}
|
||||
|
||||
// 受保护的路由组
|
||||
protected := router.Group("/api")
|
||||
protected.Use(middlewares.AuthMiddleware())
|
||||
// 注册路由
|
||||
router.POST("/register", controllers.Register(db))
|
||||
router.POST("/login", controllers.Login(db))
|
||||
|
||||
// 需要认证的路由组
|
||||
authorized := router.Group("/")
|
||||
authorized.Use(middlewares.AuthMiddleware())
|
||||
{
|
||||
protected.GET("/profile", func(c *gin.Context) {
|
||||
// 从中间件获取用户ID
|
||||
userID, exists := c.Get("user_id")
|
||||
if !exists {
|
||||
c.JSON(401, gin.H{"error": "未授权"})
|
||||
return
|
||||
}
|
||||
c.JSON(200, gin.H{"message": "访问成功", "user_id": userID})
|
||||
})
|
||||
authorized.POST("/refresh-token", controllers.RefreshToken(db))
|
||||
}
|
||||
|
||||
return router
|
||||
|
124
services/user_service.go
Normal file
124
services/user_service.go
Normal file
@ -0,0 +1,124 @@
|
||||
package services
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"ly-user-center/models"
|
||||
"ly-user-center/proto/pb"
|
||||
"ly-user-center/utils"
|
||||
|
||||
"golang.org/x/crypto/bcrypt"
|
||||
"google.golang.org/grpc/codes"
|
||||
"google.golang.org/grpc/status"
|
||||
"gorm.io/gorm"
|
||||
)
|
||||
|
||||
type UserService struct {
|
||||
pb.UnimplementedUserServiceServer
|
||||
db *gorm.DB
|
||||
}
|
||||
|
||||
func NewUserService(db *gorm.DB) *UserService {
|
||||
return &UserService{db: db}
|
||||
}
|
||||
|
||||
// Register 实现用户注册
|
||||
func (s *UserService) Register(ctx context.Context, req *pb.RegisterRequest) (*pb.RegisterResponse, error) {
|
||||
// 开始事务
|
||||
tx := s.db.Begin()
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
tx.Rollback()
|
||||
}
|
||||
}()
|
||||
|
||||
// 检查用户名、邮箱和手机号是否重复
|
||||
var existingUser models.User
|
||||
if err := tx.Where("user_name = ? OR email = ? OR phone = ?",
|
||||
req.Username, req.Email, req.Phone).First(&existingUser).Error; err == nil {
|
||||
tx.Rollback()
|
||||
return nil, status.Error(codes.AlreadyExists, "用户名、邮箱或手机号已存在")
|
||||
}
|
||||
|
||||
// 密码哈希
|
||||
hashedPassword, err := bcrypt.GenerateFromPassword([]byte(req.Password), bcrypt.DefaultCost)
|
||||
if err != nil {
|
||||
tx.Rollback()
|
||||
return nil, status.Error(codes.Internal, "密码加密失败")
|
||||
}
|
||||
|
||||
// 创建用户
|
||||
user := models.User{
|
||||
Username: req.Username,
|
||||
Password: string(hashedPassword),
|
||||
Email: req.Email,
|
||||
Phone: req.Phone,
|
||||
Sex: models.Gender(req.Sex),
|
||||
}
|
||||
|
||||
if err := tx.Create(&user).Error; err != nil {
|
||||
tx.Rollback()
|
||||
return nil, status.Error(codes.Internal, "创建用户失败")
|
||||
}
|
||||
|
||||
// 提交事务
|
||||
if err := tx.Commit().Error; err != nil {
|
||||
return nil, status.Error(codes.Internal, "提交事务失败")
|
||||
}
|
||||
|
||||
return &pb.RegisterResponse{
|
||||
Success: true,
|
||||
Message: "用户创建成功",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Login 实现用户登录
|
||||
func (s *UserService) Login(ctx context.Context, req *pb.LoginRequest) (*pb.LoginResponse, error) {
|
||||
var user models.User
|
||||
if err := s.db.Where("user_name = ?", req.Username).First(&user).Error; err != nil {
|
||||
return nil, status.Error(codes.NotFound, "用户名或密码错误")
|
||||
}
|
||||
|
||||
if err := bcrypt.CompareHashAndPassword([]byte(user.Password), []byte(req.Password)); err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "用户名或密码错误")
|
||||
}
|
||||
|
||||
// 生成 JWT
|
||||
token, expirationTime, err := utils.GenerateToken(user.ID)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "生成令牌失败")
|
||||
}
|
||||
|
||||
return &pb.LoginResponse{
|
||||
Token: token,
|
||||
ExpirationTime: expirationTime,
|
||||
Message: "登录成功",
|
||||
}, nil
|
||||
}
|
||||
|
||||
// RefreshToken 实现令牌刷新
|
||||
func (s *UserService) RefreshToken(ctx context.Context, req *pb.RefreshTokenRequest) (*pb.RefreshTokenResponse, error) {
|
||||
// 验证旧令牌
|
||||
claims, err := utils.ParseToken(req.Token)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Unauthenticated, "令牌无效或已过期")
|
||||
}
|
||||
|
||||
// 查询用户是否存在
|
||||
var user models.User
|
||||
if err := s.db.Where("id = ?", claims.UserID).First(&user).Error; err != nil {
|
||||
return nil, status.Error(codes.NotFound, "用户不存在")
|
||||
}
|
||||
|
||||
// 生成新令牌
|
||||
newToken, newExpirationTime, err := utils.GenerateToken(user.ID)
|
||||
if err != nil {
|
||||
return nil, status.Error(codes.Internal, "生成新令牌失败")
|
||||
}
|
||||
|
||||
return &pb.RefreshTokenResponse{
|
||||
Token: newToken,
|
||||
ExpirationTime: newExpirationTime,
|
||||
Message: "令牌刷新成功",
|
||||
}, nil
|
||||
}
|
Loading…
Reference in New Issue
Block a user