中间件模式
中间件(Middleware)是 Web 开发中一种强大的设计模式。它在请求到达最终处理器之前或响应返回客户端之后执行通用逻辑,如日志记录、身份认证、请求限流、错误恢复等。Go 语言的函数式特性使得中间件的实现格外优雅。
中间件概念与实现原理
中间件是一个函数,接收一个 http.Handler 作为参数,返回一个新的 http.Handler。它可以在请求处理前后插入自定义逻辑,形成“洋葱模型”的执行链。
// Middleware is the shape shared by every middleware in this chapter: it
// receives the next handler in the chain and returns a handler that wraps it.
type Middleware func(next http.Handler) http.Handler
// 示例:一个最简单的中间件
func simpleMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// ===== 请求处理前 =====
fmt.Println("请求到达中间件")
// 调用下一个处理器
next.ServeHTTP(w, r)
// ===== 响应处理后 =====
fmt.Println("响应返回后")
})
}中间件的执行流程遵循洋葱模型(洋葱圈模型):
请求 → 中间件A(前) → 中间件B(前) → 中间件C(前) → Handler
响应 ← 中间件A(后) ← 中间件B(后) ← 中间件C(后) ← Handler
package main
import (
	"context"
	"fmt"
	"net/http"
	"strings"
	"time"
)
// --- Middleware definitions ---

// loggingMiddleware prints a start line (RFC 3339 timestamp, method, path)
// when a request arrives and the total handling time once next returns.
func loggingMiddleware(next http.Handler) http.Handler {
	handle := func(w http.ResponseWriter, r *http.Request) {
		begin := time.Now()
		stamp := begin.Format(time.RFC3339)
		fmt.Printf("[%s] 开始处理: %s %s\n", stamp, r.Method, r.URL.Path)

		next.ServeHTTP(w, r)

		took := time.Since(begin)
		fmt.Printf("处理完成,耗时: %v\n", took)
	}
	return http.HandlerFunc(handle)
}
// ctxKey is an unexported context-key type. A dedicated type instead of a
// bare string prevents collisions with keys set by other packages (go vet /
// staticcheck SA1029 flag built-in key types).
type ctxKey string

// requestIDKey is the context key under which requestIDMiddleware stores the
// request ID.
const requestIDKey ctxKey = "requestID"

// requestIDMiddleware makes sure every request carries an ID. It reuses the
// client's X-Request-ID header when that is non-blank (surrounding whitespace
// removed) and otherwise generates one. The ID is echoed in the X-Request-ID
// response header and stored in the request context for downstream handlers
// (see requestIDFromContext).
func requestIDMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// A header consisting only of whitespace is treated as missing.
		requestID := strings.TrimSpace(r.Header.Get("X-Request-ID"))
		if requestID == "" {
			requestID = generateUUID()
		}
		w.Header().Set("X-Request-ID", requestID)
		ctx := context.WithValue(r.Context(), requestIDKey, requestID)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}

// requestIDFromContext returns the ID stored by requestIDMiddleware, or ""
// if ctx carries none.
func requestIDFromContext(ctx context.Context) string {
	id, _ := ctx.Value(requestIDKey).(string)
	return id
}

// generateUUID returns a simplified ID: the current time in nanoseconds since
// the Unix epoch, in decimal. It is not a real UUID and can repeat for
// requests arriving within the same clock tick; use github.com/google/uuid in
// production.
func generateUUID() string {
	return fmt.Sprintf("%d", time.Now().UnixNano())
}
// --- 注册路由 ---
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, World!")
})
// 逐层包装(洋葱模型)
handler := loggingMiddleware(requestIDMiddleware(mux))
http.ListenAndServe(":8080", handler)
}日志中间件
package middleware
import (
"log"
"net/http"
"time"
)
// responseWriter wraps an http.ResponseWriter to record the status code that
// is actually sent and the number of body bytes written, so middleware can
// report them after the handler returns.
type responseWriter struct {
	http.ResponseWriter
	statusCode  int  // final status sent to the client
	written     int  // body bytes successfully written
	wroteHeader bool // whether the final status has been committed
}

// WriteHeader forwards code and records it if it is the first final status.
// net/http ignores later calls ("superfluous WriteHeader"), so recording them
// would misreport the response.
func (rw *responseWriter) WriteHeader(code int) {
	// 1xx informational responses (other than 101 Switching Protocols) may
	// precede the final status, so they are forwarded but not recorded.
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b and counts the bytes written. A Write before any
// WriteHeader implicitly commits 200 OK, which is recorded here as well.
func (rw *responseWriter) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	n, err := rw.ResponseWriter.Write(b)
	rw.written += n
	return n, err
}

// Unwrap exposes the wrapped writer so http.ResponseController can still
// reach optional interfaces (Flush, Hijack, deadlines) hidden by the wrapper.
func (rw *responseWriter) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// RequestLogger returns a middleware that logs one line per request through
// logger: remote address, method, path, final status code, body bytes
// written and handling time.
func RequestLogger(logger *log.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			// Default to 200: that is what net/http sends when the handler
			// writes nothing at all.
			rw := &responseWriter{ResponseWriter: w, statusCode: http.StatusOK}
			next.ServeHTTP(rw, r)
			logger.Printf(
				"[%s] %s %s → %d (%d bytes) in %v",
				r.RemoteAddr,
				r.Method,
				r.URL.Path,
				rw.statusCode,
				rw.written,
				time.Since(start),
			)
		})
	}
}
// main is a usage example for RequestLogger.
// NOTE(review): it is declared in package middleware, where main is never
// run as an entry point, and it also needs "os" in the import list; move it
// into a package main to actually run it.
func main() {
	logger := log.New(os.Stdout, "[API] ", log.LstdFlags)
	mux := http.NewServeMux()
	mux.HandleFunc("/api/users", func(w http.ResponseWriter, r *http.Request) {
		// Declare the JSON type explicitly; sniffing would report text/plain.
		w.Header().Set("Content-Type", "application/json")
		w.Write([]byte(`{"users":[]}`))
	})
	// Apply the logging middleware.
	handler := RequestLogger(logger)(mux)
	// ListenAndServe only returns on failure; report it and exit non-zero.
	log.Fatal(http.ListenAndServe(":8080", handler))
}
// 日志输出示例:
// [API] 2025/01/15 10:30:00 127.0.0.1:54321 GET /api/users → 200 (12 bytes) in 1.2ms
认证与授权中间件
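JWT(JSON Web Token)是一种开放的行业标准(RFC 7519),用于在各方之间以紧凑、可校验的方式传递声明(claims)。它由三部分组成:Header(头部)、Payload(载荷)和 Signature(签名),各部分经 Base64URL 编码后用 . 分隔。注意:签名只保证内容未被篡改,Payload 本身并未加密,任何人都能解码读取,因此不要在其中存放密码等敏感信息。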
package middleware
import (
"context"
"encoding/json"
"fmt"
"net/http"
"strings"
"time"
"github.com/golang-jwt/jwt/v5"
)
var (
	// jwtSecret is the HMAC key used both to sign tokens (GenerateToken) and
	// to verify them (AuthMiddleware). It is a hard-coded placeholder for the
	// tutorial only: load the real key from configuration or a secret store
	// and never commit it to source control.
	jwtSecret = []byte("your-secret-key-change-in-production")
)
// Claims is the JWT payload used by this package: the application's own user
// fields plus the standard registered claims (exp, iat, iss, ...) embedded
// from jwt.RegisteredClaims.
type Claims struct {
// UserID is the numeric user ID, encoded as "user_id".
UserID int `json:"user_id"`
// Username is the user's login name.
Username string `json:"username"`
// Role is copied into the request context by AuthMiddleware and checked by
// RequireRole (e.g. "admin").
Role string `json:"role"`
jwt.RegisteredClaims
}
// GenerateToken returns an HS256-signed JWT for the given user. Besides the
// user's ID, name and role, the token carries the registered claims
// iss="my-app", iat=now and exp=now+24h. The error is whatever signing
// with jwtSecret returns.
func GenerateToken(userID int, username, role string) (string, error) {
	// Read the clock once so iat and exp are derived from the same instant
	// (two separate time.Now calls can straddle a second boundary).
	now := time.Now()
	claims := Claims{
		UserID:   userID,
		Username: username,
		Role:     role,
		RegisteredClaims: jwt.RegisteredClaims{
			ExpiresAt: jwt.NewNumericDate(now.Add(24 * time.Hour)),
			IssuedAt:  jwt.NewNumericDate(now),
			Issuer:    "my-app",
		},
	}
	token := jwt.NewWithClaims(jwt.SigningMethodHS256, claims)
	return token.SignedString(jwtSecret)
}
// AuthMiddleware authenticates requests that carry an
// "Authorization: Bearer <token>" header. The token must be an HMAC-signed
// JWT verified with jwtSecret whose claims parse into *Claims and are valid
// (jwt.ParseWithClaims checks expiry). On success the user's ID, username
// and role are stored in the request context under the keys "userID",
// "username" and "role" (RequireRole reads "role"). Every failure is
// answered with 401 Unauthorized and a WWW-Authenticate challenge.
//
// NOTE: string context keys are kept so existing handlers that read them
// keep working; go vet/staticcheck recommend an unexported key type.
func AuthMiddleware(next http.Handler) http.Handler {
	return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
		// RFC 7235 §3.1: a 401 response must carry a WWW-Authenticate
		// challenge naming the expected scheme.
		unauthorized := func(msg string) {
			w.Header().Set("WWW-Authenticate", "Bearer")
			writeAuthError(w, http.StatusUnauthorized, msg)
		}

		authHeader := r.Header.Get("Authorization")
		if authHeader == "" {
			unauthorized("缺少认证令牌")
			return
		}

		// The auth scheme is case-insensitive (RFC 7235 §2.1); whitespace
		// around the token is tolerated, but the token must be non-empty.
		scheme, tokenString, found := strings.Cut(authHeader, " ")
		tokenString = strings.TrimSpace(tokenString)
		if !found || !strings.EqualFold(scheme, "Bearer") || tokenString == "" {
			unauthorized("认证格式错误")
			return
		}

		token, err := jwt.ParseWithClaims(tokenString, &Claims{}, func(token *jwt.Token) (interface{}, error) {
			// Reject tokens whose header names a non-HMAC algorithm, so the
			// shared secret is only ever used as an HMAC key.
			if _, ok := token.Method.(*jwt.SigningMethodHMAC); !ok {
				return nil, fmt.Errorf("签名算法不正确")
			}
			return jwtSecret, nil
		})
		if err != nil || !token.Valid {
			unauthorized("认证令牌无效或已过期")
			return
		}

		claims, ok := token.Claims.(*Claims)
		if !ok {
			unauthorized("令牌解析失败")
			return
		}

		// Expose the authenticated identity to downstream handlers.
		ctx := context.WithValue(r.Context(), "userID", claims.UserID)
		ctx = context.WithValue(ctx, "username", claims.Username)
		ctx = context.WithValue(ctx, "role", claims.Role)
		next.ServeHTTP(w, r.WithContext(ctx))
	})
}
// RequireRole returns a middleware that lets a request through only when the
// "role" string stored in its context (by AuthMiddleware) is one of roles;
// otherwise it answers 403 Forbidden. It must therefore be placed after
// AuthMiddleware in the chain.
func RequireRole(roles ...string) func(http.Handler) http.Handler {
	allowed := make(map[string]bool, len(roles))
	for _, name := range roles {
		allowed[name] = true
	}
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if role, ok := r.Context().Value("role").(string); ok && allowed[role] {
				next.ServeHTTP(w, r)
				return
			}
			writeAuthError(w, http.StatusForbidden, "权限不足")
		})
	}
}
// writeAuthError sends statusCode with a JSON body of the form
// {"error": message} and Content-Type application/json.
func writeAuthError(w http.ResponseWriter, statusCode int, message string) {
	payload := map[string]string{"error": message}
	header := w.Header()
	header.Set("Content-Type", "application/json")
	w.WriteHeader(statusCode)
	enc := json.NewEncoder(w)
	enc.Encode(payload)
}
// --- 路由注册示例 ---
func main() {
mux := http.NewServeMux()
// 公开路由
mux.HandleFunc("POST /api/auth/login", loginHandler)
// 需要认证的路由
authMux := http.NewServeMux()
authMux.HandleFunc("GET /api/profile", profileHandler)
// 需要管理员权限的路由
adminMux := http.NewServeMux()
adminMux.HandleFunc("DELETE /api/users/{id}", deleteUserHandler)
// 应用中间件
mux.Handle("/api/profile", AuthMiddleware(authMux))
mux.Handle("/api/users/", AuthMiddleware(RequireRole("admin")(adminMux)))
http.ListenAndServe(":8080", mux)
}请求限流中间件
令牌桶算法是一种常用的限流算法。系统以固定速率向桶中放入令牌,请求需要消耗令牌才能被处理。桶满时多余的令牌被丢弃,桶空时请求被拒绝或等待。
package middleware
import (
"net/http"
"sync"
"time"
)
// TokenBucket is a token-bucket rate limiter that is safe for concurrent use.
//
// The bucket holds at most maxTokens tokens and is refilled continuously at
// rate tokens per second; every allowed request consumes one token. A full
// bucket therefore admits a burst of maxTokens requests, after which
// requests are admitted at the sustained rate.
type TokenBucket struct {
	mu         sync.Mutex
	tokens     int       // tokens currently available
	maxTokens  int       // bucket capacity (largest burst)
	rate       int       // tokens added per second
	lastRefill time.Time // instant up to which refills have been credited
}

// NewTokenBucket returns a full bucket that refills at rate tokens per second
// and holds at most maxTokens tokens.
func NewTokenBucket(rate, maxTokens int) *TokenBucket {
	return &TokenBucket{
		tokens:     maxTokens,
		maxTokens:  maxTokens,
		rate:       rate,
		lastRefill: time.Now(),
	}
}

// Allow reports whether one request may proceed now, consuming a token if so.
func (tb *TokenBucket) Allow() bool {
	tb.mu.Lock()
	defer tb.mu.Unlock()

	tb.refill(time.Now())
	if tb.tokens > 0 {
		tb.tokens--
		return true
	}
	return false
}

// refill credits the whole tokens earned since lastRefill. tb.mu must be held.
//
// Only the time actually converted into tokens is consumed, so a fractional
// remainder carries over to the next call. Counting whole seconds and then
// resetting the clock to now would silently drop up to a second of accrual
// on every refill and hand out tokens in once-per-second lumps.
func (tb *TokenBucket) refill(now time.Time) {
	if tb.rate <= 0 {
		return // a non-positive rate never refills
	}
	if tb.tokens >= tb.maxTokens {
		// A full bucket cannot accrue anything; restart the clock from now.
		tb.lastRefill = now
		return
	}
	elapsed := now.Sub(tb.lastRefill)
	if elapsed <= 0 {
		return
	}
	earned := elapsed.Seconds() * float64(tb.rate)
	if earned < 1 {
		return // not a whole token yet; keep accumulating
	}
	if space := tb.maxTokens - tb.tokens; earned >= float64(space) {
		// The bucket overflows: fill it and discard the surplus.
		tb.tokens = tb.maxTokens
		tb.lastRefill = now
		return
	}
	whole := int(earned)
	tb.tokens += whole
	consumed := time.Duration(float64(whole) / float64(tb.rate) * float64(time.Second))
	tb.lastRefill = tb.lastRefill.Add(consumed)
}

// RateLimiter returns a middleware that admits requests while bucket has
// tokens and answers 429 Too Many Requests (with Retry-After: 1) otherwise.
// All requests share the one bucket, i.e. this is a global limit.
func RateLimiter(bucket *TokenBucket) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			if !bucket.Allow() {
				w.Header().Set("Content-Type", "application/json")
				w.Header().Set("Retry-After", "1")
				w.WriteHeader(http.StatusTooManyRequests) // 429
				w.Write([]byte(`{"error":"请求过于频繁,请稍后再试"}`))
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// IPRateLimiter keeps one TokenBucket per client key (normally the client
// IP), so each client is limited independently.
//
// To keep the map from growing without bound as new addresses appear,
// buckets whose client has been idle long enough to have refilled completely
// are evicted lazily, at most once per ipSweepInterval. Evicting such a
// bucket does not change limiting behavior: a newly created bucket for the
// same key starts full as well.
type IPRateLimiter struct {
	mu        sync.Mutex
	buckets   map[string]*TokenBucket
	lastSeen  map[string]time.Time // last GetBucket call per key
	rate      int
	maxBurst  int
	lastSweep time.Time // when evictIdle last scanned the map
}

// ipSweepInterval is the minimum time between two eviction sweeps.
const ipSweepInterval = time.Minute

// NewIPRateLimiter returns a limiter whose per-key buckets refill at rate
// tokens per second and allow bursts of up to maxBurst requests.
func NewIPRateLimiter(rate, maxBurst int) *IPRateLimiter {
	return &IPRateLimiter{
		buckets:   make(map[string]*TokenBucket),
		lastSeen:  make(map[string]time.Time),
		rate:      rate,
		maxBurst:  maxBurst,
		lastSweep: time.Now(),
	}
}

// GetBucket returns the bucket for ip, creating a full one on first use.
func (rl *IPRateLimiter) GetBucket(ip string) *TokenBucket {
	rl.mu.Lock()
	defer rl.mu.Unlock()

	now := time.Now()
	rl.evictIdle(now)
	bucket, ok := rl.buckets[ip]
	if !ok {
		bucket = NewTokenBucket(rl.rate, rl.maxBurst)
		rl.buckets[ip] = bucket
	}
	rl.lastSeen[ip] = now
	return bucket
}

// evictIdle drops buckets whose key has been idle long enough for the bucket
// to be full again. rl.mu must be held.
func (rl *IPRateLimiter) evictIdle(now time.Time) {
	if now.Sub(rl.lastSweep) < ipSweepInterval {
		return
	}
	rl.lastSweep = now
	if rl.rate <= 0 {
		// Without refill a drained bucket never recovers, so evicting it
		// would reset the client's limit; keep every bucket.
		return
	}
	// Seconds needed to refill an empty bucket, rounded up, plus one second
	// of slack so the bucket is certainly full when evicted.
	refillSecs := (rl.maxBurst + rl.rate - 1) / rl.rate
	idle := time.Duration(refillSecs+1) * time.Second
	for ip, seen := range rl.lastSeen {
		if now.Sub(seen) >= idle {
			delete(rl.lastSeen, ip)
			delete(rl.buckets, ip)
		}
	}
}

// Middleware returns a middleware that limits each client IP with its own
// bucket and answers 429 Too Many Requests once that bucket is empty.
func (rl *IPRateLimiter) Middleware() func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			bucket := rl.GetBucket(remoteIP(r.RemoteAddr))
			if !bucket.Allow() {
				w.Header().Set("Content-Type", "application/json")
				w.Header().Set("Retry-After", "1")
				w.WriteHeader(http.StatusTooManyRequests)
				w.Write([]byte(`{"error":"请求过于频繁"}`))
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// remoteIP extracts the IP from an http.Request.RemoteAddr value such as
// "192.0.2.1:5678" or "[2001:db8::1]:5678". Splitting at the first ':' would
// map every IPv6 client to the same key "[". Values without a port are
// returned unchanged. (net.SplitHostPort does the same job.)
// NOTE(review): this example's import list must also include "strings".
func remoteIP(addr string) string {
	if strings.HasPrefix(addr, "[") {
		if end := strings.IndexByte(addr, ']'); end > 0 {
			return addr[1:end]
		}
		return addr
	}
	if strings.Count(addr, ":") == 1 {
		host, _, _ := strings.Cut(addr, ":")
		return host
	}
	return addr
}
// 使用示例:每秒允许10个请求,突发最多20个
func main() {
limiter := NewIPRateLimiter(10, 20)
mux := http.NewServeMux()
mux.HandleFunc("/api/data", dataHandler)
http.ListenAndServe(":8080", limiter.Middleware()(mux))
}恐慌恢复中间件
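当处理器函数中发生 panic 时,net/http 服务器虽然会自行 recover,但只会在错误日志中打印堆栈并直接中断当前连接,客户端收不到任何有意义的响应;而在处理器内部另起的 goroutine 中发生的 panic 则不会被捕获,会导致整个服务器进程崩溃。Panic Recovery 中间件使用 defer + recover 捕获处理器中的 panic,将错误记录到日志并返回 500 响应,使客户端得到明确的错误信息,服务器也能持续正常运行。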
package middleware
import (
"log"
"net/http"
"runtime/debug"
)
// Recovery returns a middleware that recovers panics raised by downstream
// handlers on the request goroutine, logs the panic value with a stack trace
// through logger, and answers 500 with a JSON error body. Panics in
// goroutines started by a handler are not covered: recover only sees panics
// on its own goroutine.
func Recovery(logger *log.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			defer func() {
				rec := recover()
				if rec == nil {
					return
				}
				// http.ErrAbortHandler is net/http's sentinel for deliberately
				// aborting a response. Re-panic so the server drops the
				// connection quietly instead of it becoming a logged 500.
				if rec == http.ErrAbortHandler {
					panic(rec)
				}
				stack := debug.Stack()
				logger.Printf("[PANIC] %s %s\n错误: %v\n堆栈:\n%s",
					r.Method, r.URL.Path, rec, string(stack))
				// NOTE: if the handler already sent headers or part of the
				// body before panicking, net/http ignores this status and the
				// JSON below is appended to the partial response.
				w.Header().Set("Content-Type", "application/json")
				w.WriteHeader(http.StatusInternalServerError)
				w.Write([]byte(`{
"code": 500,
"message": "服务器内部错误"
}`))
			}()
			next.ServeHTTP(w, r)
		})
	}
}
// unsafeHandler always panics with the string "something went wrong!". It
// exists to demonstrate Recovery: when wrapped in Recovery, the panic is
// caught there and turned into a 500 response.
func unsafeHandler(_ http.ResponseWriter, _ *http.Request) {
	const msg = "something went wrong!"
	panic(msg)
}
func main() {
logger := log.New(os.Stdout, "[APP] ", log.LstdFlags)
mux := http.NewServeMux()
mux.HandleFunc("/unsafe", unsafeHandler)
// Recovery 通常是最外层的中间件
handler := Recovery(logger)(mux)
http.ListenAndServe(":8080", handler)
}Recovery 中间件应放在中间件链的最外层,确保所有内层处理器的 panic 都能被捕获。同时注意:recover() 只能在 defer 函数中生效。
中间件链式组合
在实际项目中,我们通常需要同时使用多个中间件。Go 语言中常见的链式组合方式有逐层包装和链式调用两种模式:
逐层包装模式
package main
import (
"fmt"
"net/http"
)
// Chain wraps handler with middlewares so that the first middleware listed
// becomes the outermost layer: for Chain(h, a, b, c) a request flows
// a → b → c → h and the response unwinds c → b → a. The middlewares are
// applied from last to first, so the innermost one is constructed first.
func Chain(handler http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
	wrapped := handler
	for n := len(middlewares); n > 0; n-- {
		wrapped = middlewares[n-1](wrapped)
	}
	return wrapped
}
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) {
fmt.Fprintf(w, "Hello, World!")
})
handler := Chain(mux,
RequestLogger(logger), // 外层:最先进入,最后退出
AuthMiddleware, // 中层
Recovery(logger), // 内层(最先处理 panic)
)
http.ListenAndServe(":8080", handler)
}按路由应用中间件
func main() {
mux := http.NewServeMux()
// 全局中间件(应用于所有路由)
globalMux := Chain(mux,
Recovery(logger),
RequestLogger(logger),
)
// 公开路由
mux.HandleFunc("POST /api/auth/login", loginHandler)
mux.HandleFunc("POST /api/auth/register", registerHandler)
// 需要认证的路由组
authMux := http.NewServeMux()
authMux.HandleFunc("GET /api/profile", profileHandler)
authMux.HandleFunc("POST /api/todos", createTodoHandler)
authMux.HandleFunc("GET /api/todos", listTodoHandler)
// 对 /api/ 前缀的认证路由应用认证中间件
mux.Handle("/api/profile", AuthMiddleware(authMux))
mux.Handle("/api/todos", AuthMiddleware(RateLimiter(limiter)(authMux)))
http.ListenAndServe(":8080", globalMux)
}第三方中间件库推荐
import (
"github.com/rs/cors" // CORS 跨域处理
"github.com/justinas/alice" // 中间件链优雅组合
"github.com/prometheus/client_golang" // Prometheus 指标收集
"github.com/go-chi/chi/v5/middleware" // 常用中间件集
"compress/gzip" // 内置 gzip 压缩
)
Alice — 中间件链组合
import "github.com/justinas/alice"
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/", homeHandler)
// Alice 从左到右组合中间件
chain := alice.New(
Recovery(logger),
RequestLogger(logger),
corsMiddleware,
).Then(mux)
http.ListenAndServe(":8080", chain)
}CORS 中间件(rs/cors)
import "github.com/rs/cors"
func main() {
mux := http.NewServeMux()
mux.HandleFunc("/api/users", usersHandler)
c := cors.New(cors.Options{
AllowedOrigins: []string{"http://localhost:3000"},
AllowedMethods: []string{"GET", "POST", "PUT", "DELETE"},
AllowedHeaders: []string{"Content-Type", "Authorization"},
AllowCredentials: true,
MaxAge: 86400,
})
http.ListenAndServe(":8080", c.Handler(mux))
}Gzip 压缩中间件
import (
"compress/gzip"
"net/http"
"strings"
)
func GzipMiddleware(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// 检查客户端是否支持 gzip
if !strings.Contains(r.Header.Get("Accept-Encoding"), "gzip") {
next.ServeHTTP(w, r)
return
}
// 包装 ResponseWriter
gz := gzip.NewWriter(w)
defer gz.Close()
w.Header().Set("Content-Encoding", "gzip")
w.Header().Del("Content-Length") // 内容长度会变化
gzw := &gzipResponseWriter{Writer: gz, ResponseWriter: w}
next.ServeHTTP(gzw, r)
})
}
type gzipResponseWriter struct {
http.ResponseWriter
Writer *gzip.Writer
}
func (gzw *gzipResponseWriter) Write(b []byte) (int, error) {
return gzw.Writer.Write(b)
}练习题
练习 1:实现请求计时中间件
实现一个中间件,记录每个请求的处理耗时。如果请求处理时间超过 500ms,在日志中输出警告(WARN),否则输出信息(INFO)。
package main
import (
"fmt"
"log"
"net/http"
"os"
"time"
)
// responseWriterWrapper wraps an http.ResponseWriter to remember the final
// status code sent to the client.
type responseWriterWrapper struct {
	http.ResponseWriter
	statusCode  int
	wroteHeader bool
}

// WriteHeader forwards code and records it if it is the first final status;
// net/http ignores later calls, so recording them would misreport the
// response.
func (rw *responseWriterWrapper) WriteHeader(code int) {
	// 1xx informational responses (other than 101 Switching Protocols) may
	// precede the final status, so they are forwarded but not recorded.
	informational := code >= 100 && code < 200 && code != http.StatusSwitchingProtocols
	if !rw.wroteHeader && !informational {
		rw.statusCode = code
		rw.wroteHeader = true
	}
	rw.ResponseWriter.WriteHeader(code)
}

// Write forwards b. A body write before WriteHeader commits 200 OK, so a
// later WriteHeader must not change the recorded status.
func (rw *responseWriterWrapper) Write(b []byte) (int, error) {
	if !rw.wroteHeader {
		rw.statusCode = http.StatusOK
		rw.wroteHeader = true
	}
	return rw.ResponseWriter.Write(b)
}

// Unwrap exposes the wrapped writer to http.ResponseController.
func (rw *responseWriterWrapper) Unwrap() http.ResponseWriter {
	return rw.ResponseWriter
}

// slowRequestThreshold is the handling time above which TimingMiddleware
// logs a request at WARN instead of INFO.
const slowRequestThreshold = 500 * time.Millisecond

// TimingMiddleware logs each request's method, path, final status and
// handling time through logger: "[WARN] ... (slow: d)" when it took longer
// than slowRequestThreshold, "[INFO] ... (d)" otherwise.
func TimingMiddleware(logger *log.Logger) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			start := time.Now()
			// 200 is what net/http sends if the handler writes nothing.
			rw := &responseWriterWrapper{ResponseWriter: w, statusCode: http.StatusOK}
			next.ServeHTTP(rw, r)
			duration := time.Since(start)
			if duration > slowRequestThreshold {
				logger.Printf("[WARN] %s %s → %d (slow: %v)",
					r.Method, r.URL.Path, rw.statusCode, duration)
				return
			}
			logger.Printf("[INFO] %s %s → %d (%v)",
				r.Method, r.URL.Path, rw.statusCode, duration)
		})
	}
}
func main() {
logger := log.New(os.Stdout, "", log.LstdFlags)
mux := http.NewServeMux()
mux.HandleFunc("/fast", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(50 * time.Millisecond)
w.Write([]byte("fast"))
})
mux.HandleFunc("/slow", func(w http.ResponseWriter, r *http.Request) {
time.Sleep(600 * time.Millisecond)
w.Write([]byte("slow"))
})
http.ListenAndServe(":8080", TimingMiddleware(logger)(mux))
}练习 2:实现滑动窗口限流中间件
实现一个基于滑动窗口的 IP 限流中间件:每个 IP 在 60 秒内最多允许 100 次请求。
package main
import (
"net/http"
"sync"
"time"
)
// SlidingWindow is a sliding-log rate limiter. For every key it remembers
// the times of the requests admitted within the last windowSize, and it
// admits a new request only while fewer than maxReqs are recorded. It is
// safe for concurrent use. NewSlidingWindow starts a background cleanup
// goroutine; call Stop when the limiter is no longer needed.
type SlidingWindow struct {
	mu         sync.Mutex
	windows    map[string][]time.Time // admitted request times per key, oldest first
	maxReqs    int
	windowSize time.Duration
	stop       chan struct{} // closed by Stop to end cleanup
	stopOnce   sync.Once
}

// NewSlidingWindow returns a limiter that admits at most maxReqs requests
// per key within any windowSize-long interval.
func NewSlidingWindow(maxReqs int, windowSize time.Duration) *SlidingWindow {
	sw := &SlidingWindow{
		windows:    make(map[string][]time.Time),
		maxReqs:    maxReqs,
		windowSize: windowSize,
		stop:       make(chan struct{}),
	}
	// Purge expired records in the background so idle keys don't pile up.
	go sw.cleanup()
	return sw
}

// Stop ends the background cleanup goroutine. It is safe to call more than
// once; Allow keeps working afterwards, but idle keys are no longer purged.
func (sw *SlidingWindow) Stop() {
	sw.stopOnce.Do(func() { close(sw.stop) })
}

// Allow reports whether a request for key is admitted now and, if so,
// records it. Rejected requests are not recorded.
func (sw *SlidingWindow) Allow(key string) bool {
	sw.mu.Lock()
	defer sw.mu.Unlock()

	now := time.Now()
	valid := pruneBefore(sw.windows[key], now.Add(-sw.windowSize))
	if len(valid) >= sw.maxReqs {
		sw.windows[key] = valid
		return false
	}
	sw.windows[key] = append(valid, now)
	return true
}

// pruneBefore drops the timestamps that are not after cutoff. It compacts ts
// in place, reusing its backing array instead of allocating a new slice.
func pruneBefore(ts []time.Time, cutoff time.Time) []time.Time {
	kept := ts[:0]
	for _, t := range ts {
		if t.After(cutoff) {
			kept = append(kept, t)
		}
	}
	return kept
}

// cleanup runs until Stop is called, purging expired records once a minute.
func (sw *SlidingWindow) cleanup() {
	ticker := time.NewTicker(1 * time.Minute)
	defer ticker.Stop()
	for {
		select {
		case <-sw.stop:
			return
		case <-ticker.C:
			sw.purge(time.Now())
		}
	}
}

// purge drops expired timestamps for every key and deletes keys left empty.
func (sw *SlidingWindow) purge(now time.Time) {
	sw.mu.Lock()
	defer sw.mu.Unlock()

	cutoff := now.Add(-sw.windowSize)
	for key, ts := range sw.windows {
		if kept := pruneBefore(ts, cutoff); len(kept) == 0 {
			delete(sw.windows, key)
		} else {
			sw.windows[key] = kept
		}
	}
}

// SlidingWindowMiddleware limits each client IP with sw and answers 429 Too
// Many Requests once the IP's window is full.
// NOTE(review): Retry-After is fixed at 60, matching the 60 s window used in
// main; it does not follow sw.windowSize.
func SlidingWindowMiddleware(sw *SlidingWindow) func(http.Handler) http.Handler {
	return func(next http.Handler) http.Handler {
		return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
			// Key on the IP only: RemoteAddr also carries the client's source
			// port, which changes with every new connection and would let a
			// client bypass the limit simply by reconnecting.
			ip := hostOnly(r.RemoteAddr)
			if !sw.Allow(ip) {
				w.Header().Set("Content-Type", "application/json")
				w.Header().Set("Retry-After", "60")
				w.WriteHeader(http.StatusTooManyRequests)
				w.Write([]byte(`{"error":"rate limit exceeded"}`))
				return
			}
			next.ServeHTTP(w, r)
		})
	}
}

// hostOnly strips the port from a RemoteAddr value: "192.0.2.1:5678" becomes
// "192.0.2.1" and "[2001:db8::1]:5678" becomes "2001:db8::1". Values without
// a port are returned unchanged. net.SplitHostPort does the same job; this
// version avoids an extra import in the example.
func hostOnly(addr string) string {
	if len(addr) > 0 && addr[0] == '[' {
		for i := 1; i < len(addr); i++ {
			if addr[i] == ']' {
				return addr[1:i]
			}
		}
		return addr
	}
	colon := -1
	for i := 0; i < len(addr); i++ {
		if addr[i] != ':' {
			continue
		}
		if colon >= 0 {
			return addr // several colons: a bare IPv6 address without a port
		}
		colon = i
	}
	if colon < 0 {
		return addr
	}
	return addr[:colon]
}
func main() {
// 每IP每60秒最多100次请求
sw := NewSlidingWindow(100, 60*time.Second)
mux := http.NewServeMux()
mux.HandleFunc("/api/data", func(w http.ResponseWriter, r *http.Request) {
w.Write([]byte(`{"data":"ok"}`))
})
http.ListenAndServe(":8080", SlidingWindowMiddleware(sw)(mux))
}练习 3:组合多个中间件构建安全路由
使用 Recovery、Logging、Auth、RateLimiter 四个中间件(另加 RequireRole 角色校验,即下文的 AdminRole),构建如下路由:
- POST /api/auth/login — 仅 Recovery + Logging
- GET /api/profile — Recovery + Logging + Auth
- DELETE /api/users/{id} — Recovery + Logging + Auth + AdminRole + RateLimiter
package main
import (
"fmt"
"log"
"net/http"
"os"
)
// Chain applies middlewares to handler so that middlewares[0] is the
// outermost layer (the first to see a request and the last to see the
// response). With no middlewares, handler is returned unchanged. Inner
// layers are constructed before outer ones.
func Chain(handler http.Handler, middlewares ...func(http.Handler) http.Handler) http.Handler {
	if len(middlewares) == 0 {
		return handler
	}
	inner := Chain(handler, middlewares[1:]...)
	return middlewares[0](inner)
}
// main builds three routes with increasing protection:
//
//   - POST /api/auth/login: Recovery + Logging
//   - GET /api/profile: Recovery + Logging + Auth
//   - DELETE /api/users/{id}: Recovery + Logging + Auth + admin role + rate limit
//
// Chain lists middleware outermost first, so Recovery guards every layer.
// NOTE(review): loginHandler, profileHandler and deleteUserHandler are
// assumed to be defined elsewhere.
func main() {
	logger := log.New(os.Stdout, "[APP] ", log.LstdFlags)
	// Build the middleware.
	recovery := Recovery(logger)
	logging := RequestLogger(logger)
	auth := AuthMiddleware
	adminRole := RequireRole("admin")
	// One shared bucket: 5 requests per second (burst 10) across all callers
	// that reach it, i.e. only requests that passed auth and the role check.
	rateLimiter := RateLimiter(NewTokenBucket(5, 10))
	mux := http.NewServeMux()
	// Public route: Recovery + Logging.
	mux.Handle("POST /api/auth/login", Chain(
		http.HandlerFunc(loginHandler),
		recovery, logging,
	))
	// User route: Recovery + Logging + Auth.
	mux.Handle("GET /api/profile", Chain(
		http.HandlerFunc(profileHandler),
		recovery, logging, auth,
	))
	// Admin route: Recovery + Logging + Auth + AdminRole + RateLimiter.
	mux.Handle("DELETE /api/users/{id}", Chain(
		http.HandlerFunc(deleteUserHandler),
		recovery, logging, auth, adminRole, rateLimiter,
	))
	fmt.Println("Server running on :8080")
	// ListenAndServe only returns on failure; report it and exit non-zero.
	log.Fatal(http.ListenAndServe(":8080", mux))
}