导航菜单

RESTful API 设计

RESTful(Representational State Transfer)是一种软件架构风格,为 Web API 的设计提供了一组约束和最佳实践。遵循 RESTful 原则设计的 API 具有清晰、一致、可预测的特点,便于前端消费和第三方集成。

RESTful 设计原则

RESTful 架构原则

RESTful API 围绕资源(Resource)进行设计,每个资源用一个 URL(URI)标识,通过 HTTP 动词(GET/POST/PUT/DELETE/PATCH)对资源执行操作,且每次请求必须包含理解该请求所需的所有信息(无状态)。

核心原则

  1. 资源用名词表示:URL 中只使用名词,不用动词
  2. HTTP 动词表示操作:用 GET/POST/PUT/DELETE/PATCH 表达意图
  3. 无状态:每个请求都包含处理所需的全部信息
  4. 统一接口:一致的 URL 结构和响应格式
✅ 正确示例:
GET    /users          # 获取用户列表
GET    /users/123      # 获取单个用户
POST   /users          # 创建用户
PUT    /users/123      # 全量更新用户
PATCH  /users/123      # 部分更新用户
DELETE /users/123      # 删除用户

❌ 错误示例:
GET    /getUsers
POST   /createUser
POST   /deleteUser/123
GET    /user/list

资源路由设计

标准 CRUD 路由

package main

import (
    "encoding/json"
    "net/http"
    "strconv"
    "strings"
    "sync"
)

// User is the resource managed by the CRUD example below; the json tags set
// the field names used in request and response bodies.
type User struct {
    ID    int    `json:"id"`
    Name  string `json:"name"`
    Email string `json:"email"`
    Age   int    `json:"age"`
}

// In-memory user store for the example.
var (
    users  = []User{}   // non-nil empty slice: encodes as [] (a nil slice would encode as null)
    nextID = 1          // ID handed to the next user created by POST
    mu     sync.Mutex   // locked around reads/writes of users and nextID
)

// userHandler serves the user collection and individual users:
//
//	GET    /users       list all users
//	GET    /users/{id}  fetch one user
//	POST   /users       create a user (the server assigns the ID)
//	PUT    /users/{id}  replace a user
//	DELETE /users/{id}  delete a user
//
// IDs must be positive integers. Every read and write of users/nextID happens
// while mu is held. findUser is deliberately not used: it unlocks mu before
// returning a pointer into the slice, and a concurrent append or delete can
// move the elements behind that pointer.
func userHandler(w http.ResponseWriter, r *http.Request) {
    // Strip "/users" and then an optional "/", so both "/users" and "/users/"
    // address the collection. Trimming "/users/" alone left the bare "/users"
    // path intact, which then failed to parse as an ID (400).
    rest := strings.TrimPrefix(r.URL.Path, "/users")
    rest = strings.TrimPrefix(rest, "/")

    var id int // 0 means "the collection"
    if rest != "" {
        n, err := strconv.Atoi(rest)
        if err != nil || n <= 0 {
            writeError(w, http.StatusBadRequest, "无效的用户 ID")
            return
        }
        id = n
    }

    switch r.Method {
    case http.MethodGet:
        if id == 0 {
            // Copy under the lock so JSON encoding does not race with writers.
            mu.Lock()
            list := make([]User, len(users))
            copy(list, users)
            mu.Unlock()
            writeJSON(w, http.StatusOK, list)
            return
        }
        mu.Lock()
        i := userIndexLocked(id)
        var user User
        if i >= 0 {
            user = users[i]
        }
        mu.Unlock()
        if i < 0 {
            writeError(w, http.StatusNotFound, "用户不存在")
            return
        }
        writeJSON(w, http.StatusOK, user)

    case http.MethodPost:
        if id != 0 {
            writeError(w, http.StatusBadRequest, "POST 请求不应包含 ID")
            return
        }
        var user User
        if err := json.NewDecoder(r.Body).Decode(&user); err != nil {
            writeError(w, http.StatusBadRequest, "无效的请求体")
            return
        }
        mu.Lock()
        user.ID = nextID
        nextID++
        users = append(users, user)
        mu.Unlock()
        writeJSON(w, http.StatusCreated, user)

    case http.MethodPut:
        // Report a missing user before looking at the body (as before), but
        // do not hold mu while the body is read from the network.
        mu.Lock()
        exists := userIndexLocked(id) >= 0
        mu.Unlock()
        if !exists {
            writeError(w, http.StatusNotFound, "用户不存在")
            return
        }
        var updated User
        if err := json.NewDecoder(r.Body).Decode(&updated); err != nil {
            writeError(w, http.StatusBadRequest, "无效的请求体")
            return
        }
        updated.ID = id // the path ID wins over any ID in the body
        mu.Lock()
        i := userIndexLocked(id)
        if i >= 0 {
            users[i] = updated
        }
        mu.Unlock()
        if i < 0 {
            // Deleted concurrently while the body was being read.
            writeError(w, http.StatusNotFound, "用户不存在")
            return
        }
        writeJSON(w, http.StatusOK, updated)

    case http.MethodDelete:
        mu.Lock()
        i := userIndexLocked(id)
        if i >= 0 {
            users = append(users[:i], users[i+1:]...)
        }
        mu.Unlock()
        if i < 0 {
            writeError(w, http.StatusNotFound, "用户不存在")
            return
        }
        writeJSON(w, http.StatusOK, map[string]string{"message": "删除成功"})

    default:
        // A 405 response must list the supported methods (RFC 9110 §15.5.6).
        w.Header().Set("Allow", "GET, POST, PUT, DELETE")
        writeError(w, http.StatusMethodNotAllowed, "不支持的请求方法")
    }
}

// userIndexLocked returns the index in users of the user with the given ID,
// or -1 if there is none. The caller must hold mu.
func userIndexLocked(id int) int {
    for i := range users {
        if users[i].ID == id {
            return i
        }
    }
    return -1
}

// findUser returns a pointer to the stored user whose ID equals id, or nil
// when no such user exists. mu is held only while searching.
//
// NOTE(review): the returned pointer refers into the users slice and is used
// after mu has been released; a later append (reallocation) or delete can
// move the elements, leaving the pointer aimed at a stale copy.
func findUser(id int) *User {
    mu.Lock()
    defer mu.Unlock()

    for idx, n := 0, len(users); idx < n; idx++ {
        if candidate := &users[idx]; candidate.ID == id {
            return candidate
        }
    }
    return nil
}

请求参数验证

结构体 Binding Tag

标准库 encoding/json 不会根据结构体标签(struct tag)做参数验证;Gin 等框架中的 binding tag 依赖第三方 validator 库。下面的示例不引入第三方依赖:结构体标签只负责 JSON 字段映射,解码完成后再调用 Validate 方法显式校验,确保请求参数满足业务要求。

package main

import (
    "encoding/json"
    "net/http"
    "regexp"
    "strings"
)

// CreateUserRequest is the JSON body accepted by createUserHandler. The
// struct tags only map JSON field names; validation is performed explicitly
// by calling Validate after decoding (there is no custom UnmarshalJSON).
type CreateUserRequest struct {
    Username string `json:"username"`
    Email    string `json:"email"`
    Password string `json:"password"`
    Age      int    `json:"age"`
}

// Validate checks the request and returns a map from JSON field name to a
// human-readable message; an empty map means the request is valid.
//
// Lengths are counted in characters (runes), not bytes: with byte counting a
// Chinese username such as "张三" (6 bytes) passed the 3-character minimum,
// while a 7-character Chinese name (21 bytes) failed the 20-character maximum.
// A username made only of whitespace is treated as empty.
func (r *CreateUserRequest) Validate() map[string]string {
    errs := make(map[string]string)

    username := strings.TrimSpace(r.Username)
    switch n := len([]rune(username)); {
    case n == 0:
        errs["username"] = "用户名不能为空"
    case n < 3 || n > 20:
        errs["username"] = "用户名长度需在3-20之间"
    }

    switch {
    case r.Email == "":
        errs["email"] = "邮箱不能为空"
    case !isValidEmail(r.Email):
        errs["email"] = "邮箱格式不正确"
    }

    switch {
    case r.Password == "":
        errs["password"] = "密码不能为空"
    case len([]rune(r.Password)) < 8:
        errs["password"] = "密码长度至少8位"
    }

    if r.Age < 0 || r.Age > 150 {
        errs["age"] = "年龄不合法"
    }

    return errs
}

// emailPattern is compiled once at package initialization rather than on
// every isValidEmail call.
var emailPattern = regexp.MustCompile(`^[a-zA-Z0-9._%+\-]+@[a-zA-Z0-9.\-]+\.[a-zA-Z]{2,}$`)

// isValidEmail reports whether email has the shape local@domain.tld. It is a
// pragmatic format check, not a full RFC 5322 address parser.
func isValidEmail(email string) bool {
    return emailPattern.MatchString(email)
}

// createUserHandler decodes a CreateUserRequest from the request body and
// validates it. Malformed JSON yields a 400 error; validation failures yield
// a 400 whose "errors" field maps each bad field to its message; otherwise it
// answers 201.
func createUserHandler(w http.ResponseWriter, r *http.Request) {
    req := new(CreateUserRequest)
    decodeErr := json.NewDecoder(r.Body).Decode(req)
    if decodeErr != nil {
        writeError(w, http.StatusBadRequest, "无效的请求数据")
        return
    }

    fieldErrs := req.Validate()
    if len(fieldErrs) == 0 {
        // Validation passed; persisting the user is out of scope for this example.
        writeJSON(w, http.StatusCreated, map[string]string{"message": "用户创建成功"})
        return
    }

    payload := map[string]interface{}{
        "code":    "VALIDATION_ERROR",
        "message": "参数验证失败",
        "errors":  fieldErrs,
    }
    writeJSON(w, http.StatusBadRequest, payload)
}

统一响应格式

统一响应格式

所有 API 返回相同结构的 JSON 响应,方便前端统一处理。成功和失败使用不同的字段组合,但保持顶层结构一致。

package response

import (
    "encoding/json"
    "net/http"
)

// Response is the envelope for successful replies. The helpers in this
// package set Code to 0; Message is a short status text and Data, omitted
// from the JSON when nil, carries the payload.
type Response struct {
    Code    int         `json:"code"`
    Message string      `json:"message"`
    Data    interface{} `json:"data,omitempty"`
}

// ErrorResponse is the envelope for error replies. Code is an application
// error code chosen by the caller; Error and Details are omitted from the
// JSON when empty/nil, and Details can carry structured data such as
// per-field validation errors.
// NOTE(review): the Error and ErrorWithDetails helpers shown here never set
// the Error field.
type ErrorResponse struct {
    Code    int         `json:"code"`
    Message string      `json:"message"`
    Error   string      `json:"error,omitempty"`
    Details interface{} `json:"details,omitempty"`
}

// 成功响应
func Success(w http.ResponseWriter, data interface{}) {
    writeJSON(w, http.StatusOK, Response{
        Code:    0,
        Message: "success",
        Data:    data,
    })
}

// 成功响应(自定义消息)
func SuccessWithMessage(w http.ResponseWriter, message string, data interface{}) {
    writeJSON(w, http.StatusOK, Response{
        Code:    0,
        Message: message,
        Data:    data,
    })
}

// 创建成功响应
func Created(w http.ResponseWriter, data interface{}) {
    writeJSON(w, http.StatusCreated, Response{
        Code:    0,
        Message: "created",
        Data:    data,
    })
}

// 错误响应
func Error(w http.ResponseWriter, statusCode int, code int, message string) {
    writeJSON(w, statusCode, ErrorResponse{
        Code:    code,
        Message: message,
    })
}

// 带详情的错误响应
func ErrorWithDetails(w http.ResponseWriter, statusCode int, code int, message string, details interface{}) {
    writeJSON(w, statusCode, ErrorResponse{
        Code:    code,
        Message: message,
        Details: details,
    })
}

// writeJSON sends data as a JSON body with the given status code and an
// application/json content type. The body ends with a newline, matching
// json.Encoder output.
//
// The value is marshaled before anything is written. Previously the status
// was sent first and the Encode error ignored, so a value that cannot be
// encoded (e.g. a channel or func) produced the requested status with an
// empty body. Now it produces a 500 with a JSON error body instead.
func writeJSON(w http.ResponseWriter, statusCode int, data interface{}) {
    body, err := json.Marshal(data)
    w.Header().Set("Content-Type", "application/json; charset=utf-8")
    if err != nil {
        w.WriteHeader(http.StatusInternalServerError)
        _, _ = w.Write([]byte(`{"code":500,"message":"internal server error"}` + "\n"))
        return
    }
    w.WriteHeader(statusCode)
    // A write error here means the client went away; there is nobody left to report it to.
    _, _ = w.Write(append(body, '\n'))
}

使用示例:

// getUsersHandler demonstrates response.Success with a fixed two-user list.
func getUsersHandler(w http.ResponseWriter, r *http.Request) {
    demo := make([]User, 0, 2)
    demo = append(demo, User{ID: 1, Name: "张三"}, User{ID: 2, Name: "李四"})
    response.Success(w, demo)
}

// 成功响应示例:
// {
//   "code": 0,
//   "message": "success",
//   "data": [{"id": 1, "name": "张三"}, {"id": 2, "name": "李四"}]
// }

// getUserHandler demonstrates response.Error for a missing user: both the
// HTTP status and the application error code are 404.
func getUserHandler(w http.ResponseWriter, r *http.Request) {
    const notFound = http.StatusNotFound
    response.Error(w, notFound, notFound, "用户不存在")
}

// 错误响应示例:
// {
//   "code": 404,
//   "message": "用户不存在"
// }

分页

// Pagination is the paging metadata returned alongside a page of items.
type Pagination struct {
    Page     int   `json:"page"`      // current page number, starting at 1
    PageSize int   `json:"page_size"` // number of items per page
    Total    int64 `json:"total"`     // total number of matching records
}

// PageResult pairs one page of items with its Pagination metadata.
type PageResult struct {
    Items []interface{} `json:"items"`
    Page  Pagination    `json:"page"`
}

// getPagination reads the "page" and "page_size" query parameters.
// Missing or non-numeric values parse as 0 and fall back to defaults: page
// becomes 1 when below 1, and page_size becomes 20 when outside 1..100.
func getPagination(r *http.Request) (page, pageSize int) {
    query := r.URL.Query()

    // Parse errors are ignored on purpose: the bounds checks below replace
    // the resulting zero values with defaults.
    p, _ := strconv.Atoi(query.Get("page"))
    size, _ := strconv.Atoi(query.Get("page_size"))

    page = 1
    if p >= 1 {
        page = p
    }

    switch {
    case size >= 1 && size <= 100:
        pageSize = size
    default:
        pageSize = 20
    }
    return page, pageSize
}

// listUsersHandler returns one page of users in the standard envelope:
//
//	{"code":0,"message":"success","data":{"items":[...],"page":{...}}}
//
// Paging comes from getPagination (page >= 1, 1 <= page_size <= 100). A page
// past the end yields an empty items list rather than an error.
func listUsersHandler(w http.ResponseWriter, r *http.Request) {
    page, pageSize := getPagination(r)

    // Simulated database query.
    allUsers := getUsersFromDB()
    total := len(allUsers)

    // Clamp the window without multiplying first: page is client-controlled,
    // and (page-1)*pageSize can overflow to a negative offset for huge values,
    // which would make the slice expression below panic.
    offset := total
    if page-1 <= total/pageSize {
        offset = (page - 1) * pageSize
    }
    end := total
    if pageSize < total-offset {
        end = offset + pageSize
    }
    items := allUsers[offset:end]

    writeJSON(w, http.StatusOK, map[string]interface{}{
        "code":    0,
        "message": "success",
        "data": map[string]interface{}{
            "items": items,
            "page": Pagination{
                Page:     page,
                PageSize: pageSize,
                Total:    int64(total),
            },
        },
    })
}

// 请求示例: GET /users?page=2&page_size=10
// 响应示例:
// {
//   "code": 0,
//   "message": "success",
//   "data": {
//     "items": [...],
//     "page": {
//       "page": 2,
//       "page_size": 10,
//       "total": 57
//     }
//   }
// }

过滤与排序

// QueryParams holds the list-query options parsed from a request URL.
type QueryParams struct {
    Page      int
    PageSize  int
    Filters   map[string]string // every non-reserved query parameter, matched by equality
    SortBy    string
    SortOrder string // "asc" or "desc"
}

// parseQueryParams extracts paging, sorting and filter options from the query
// string of r. page defaults to 1, page_size to 20 (when outside 1..100) and
// sort_order to "asc" (unless it is "asc"/"desc", case-insensitively). Every
// other parameter becomes a filter, keeping only its first value.
//
// Filter keys and sort_by are passed through unvalidated; buildQuery must not
// trust them.
func parseQueryParams(r *http.Request) QueryParams {
    q := r.URL.Query()
    page, _ := strconv.Atoi(q.Get("page"))
    pageSize, _ := strconv.Atoi(q.Get("page_size"))
    sortBy := q.Get("sort_by")
    sortOrder := strings.ToLower(q.Get("sort_order"))

    if page < 1 {
        page = 1
    }
    if pageSize < 1 || pageSize > 100 {
        pageSize = 20
    }
    if sortOrder != "asc" && sortOrder != "desc" {
        sortOrder = "asc"
    }

    // Collect filter parameters, skipping the reserved paging/sorting keys.
    filters := make(map[string]string)
    for key, values := range q {
        switch key {
        case "page", "page_size", "sort_by", "sort_order":
            continue
        }
        if len(values) > 0 {
            filters[key] = values[0]
        }
    }

    return QueryParams{
        Page:      page,
        PageSize:  pageSize,
        Filters:   filters,
        SortBy:    sortBy,
        SortOrder: sortOrder,
    }
}

// buildQuery builds a parameterized SELECT over table from params and returns
// the SQL text together with its positional arguments.
//
// SQL placeholders can bind values but not identifiers, and filter keys and
// SortBy come straight from the URL. Interpolating them unchecked allowed SQL
// injection (e.g. ?sort_by=id;DROP TABLE users). Filter keys and the sort
// column are now used only if they are plain identifiers
// ([A-Za-z_][A-Za-z0-9_]*) and are ignored otherwise. Callers should still
// restrict them to a per-table allowlist so clients cannot filter or sort on
// columns they must not see.
//
// table is interpolated as-is and must come from trusted code, never from the
// request. Because Go map iteration order is random, the order of the AND
// conditions varies between calls when several filters are present; the args
// always line up with their placeholders.
func buildQuery(table string, params QueryParams) (string, []interface{}) {
    var conditions []string
    var args []interface{}

    // Filter conditions.
    for field, value := range params.Filters {
        if !isSQLIdentifier(field) {
            continue
        }
        conditions = append(conditions, field+" = ?")
        args = append(args, value)
    }

    query := "SELECT * FROM " + table
    if len(conditions) > 0 {
        query += " WHERE " + strings.Join(conditions, " AND ")
    }

    // Ordering: only "DESC" or "ASC" is ever emitted, defaulting to ASC.
    if isSQLIdentifier(params.SortBy) {
        order := "ASC"
        if strings.EqualFold(params.SortOrder, "desc") {
            order = "DESC"
        }
        query += " ORDER BY " + params.SortBy + " " + order
    }

    // Paging: clamp to sane values and avoid int overflow of the offset.
    page, pageSize := params.Page, params.PageSize
    if page < 1 {
        page = 1
    }
    if pageSize < 1 {
        pageSize = 20
    }
    const maxInt = int(^uint(0) >> 1)
    offset := maxInt
    if page-1 <= maxInt/pageSize {
        offset = (page - 1) * pageSize
    }
    query += " LIMIT ? OFFSET ?"
    args = append(args, pageSize, offset)

    return query, args
}

// isSQLIdentifier reports whether s is a bare ASCII SQL identifier: a letter
// or underscore followed by letters, digits or underscores.
func isSQLIdentifier(s string) bool {
    if s == "" {
        return false
    }
    for i, c := range s {
        switch {
        case c == '_', 'a' <= c && c <= 'z', 'A' <= c && c <= 'Z':
            // allowed anywhere
        case i > 0 && '0' <= c && c <= '9':
            // digits allowed after the first character
        default:
            return false
        }
    }
    return true
}

// 请求示例: GET /users?role=admin&age_gt=18&sort_by=created_at&sort_order=desc&page=1&page_size=10

API 版本管理

URL 路径版本(推荐)

// main registers each API version under its own path prefix so v1 and v2
// handlers can evolve independently. The "METHOD /path" patterns require
// Go 1.22 or later.
func main() {
    mux := http.NewServeMux()

    // v1 API
    mux.HandleFunc("GET /api/v1/users", listUsersV1)
    mux.HandleFunc("POST /api/v1/users", createUserV1)

    // v2 API (e.g. when the user structure changes)
    mux.HandleFunc("GET /api/v2/users", listUsersV2)
    mux.HandleFunc("POST /api/v2/users", createUserV2)

    // ListenAndServe only returns on failure (e.g. the port is already in
    // use); surface it instead of exiting silently with status 0.
    if err := http.ListenAndServe(":8080", mux); err != nil {
        panic(err)
    }
}

Header 版本

// versionMiddleware reads the requested API version from the X-API-Version
// header, defaulting to "1" when the header is missing or empty. It stores
// the version in the request context under the key "apiVersion" and then
// calls next.
//
// NOTE(review): staticcheck (SA1029) flags built-in string context keys
// because unrelated packages can collide on them. An unexported key type
// would be safer, but any code reading the "apiVersion" string key would have
// to change along with it.
func versionMiddleware(next http.Handler) http.Handler {
    attachVersion := func(w http.ResponseWriter, r *http.Request) {
        version := "1" // default version
        if requested := r.Header.Get("X-API-Version"); requested != "" {
            version = requested
        }
        withVersion := context.WithValue(r.Context(), "apiVersion", version)
        next.ServeHTTP(w, r.WithContext(withVersion))
    }
    return http.HandlerFunc(attachVersion)
}

// 请求示例:
// GET /api/users
// X-API-Version: 2

Accept Header 版本

// versionFromAccept maps an Accept header to an API version string. It
// returns "v2" when any comma-separated media range either is the vendor type
// application/vnd.myapi.v2+json or carries a version=2 parameter (e.g.
// `application/json; version=2`, optionally quoted). Otherwise it returns "v1".
//
// Parameters are matched exactly rather than by substring, so "version=20"
// or "subversion=2" no longer select v2.
func versionFromAccept(accept string) string {
    for _, mediaRange := range strings.Split(accept, ",") {
        parts := strings.Split(mediaRange, ";")
        if strings.EqualFold(strings.TrimSpace(parts[0]), "application/vnd.myapi.v2+json") {
            return "v2"
        }
        for _, param := range parts[1:] {
            key, value, found := strings.Cut(param, "=")
            if !found {
                continue
            }
            key = strings.TrimSpace(key)
            value = strings.Trim(strings.TrimSpace(value), `"`)
            if strings.EqualFold(key, "version") && value == "2" {
                return "v2"
            }
        }
    }
    return "v1"
}

CORS 跨域处理

CORS(Cross-Origin Resource Sharing)

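CORS 是一种浏览器安全机制,允许服务器声明哪些外部源(协议 + 域名 + 端口)可以访问其资源。对于非简单请求,浏览器会先发送 OPTIONS 预检请求(Preflight)。例如使用 PUT/DELETE/PATCH 方法、携带 Authorization 等自定义请求头,或 Content-Type 为 application/json 的请求,都属于非简单请求。简单请求则直接发出,由浏览器根据响应中的 Access-Control-Allow-Origin 决定是否允许页面读取结果。两种情况下服务器都需返回适当的响应头。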

package main

import (
    "net/http"
    "strconv"
    "strings"
)

// corsAllowedOrigins is the set of origins allowed to make cross-origin
// requests (production deployments should list their real domains). It is
// built once instead of on every request.
var corsAllowedOrigins = map[string]bool{
    "http://localhost:3000": true,
    "https://example.com":   true,
}

// corsMiddleware adds CORS response headers for requests whose Origin is in
// corsAllowedOrigins, and answers every OPTIONS request (preflight) itself
// with 204 No Content. Other requests are passed to next.
//
// The origin is echoed back instead of "*" because credentials are allowed.
// "Vary: Origin" is always set because the response depends on the Origin
// header; without it a shared cache could replay one origin's
// Access-Control-Allow-Origin to a different origin. Disallowed origins get no
// CORS headers at all, so the browser blocks the cross-origin read.
func corsMiddleware(next http.Handler) http.Handler {
    return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
        origin := r.Header.Get("Origin")
        h := w.Header()
        h.Add("Vary", "Origin")

        if corsAllowedOrigins[origin] {
            h.Set("Access-Control-Allow-Origin", origin)
            // Allowed request methods and headers.
            h.Set("Access-Control-Allow-Methods", "GET, POST, PUT, PATCH, DELETE, OPTIONS")
            h.Set("Access-Control-Allow-Headers", "Content-Type, Authorization, X-API-Version")
            // Let the browser cache preflight results (seconds).
            h.Set("Access-Control-Max-Age", "86400")
            // Allow credentials (cookies).
            h.Set("Access-Control-Allow-Credentials", "true")
            // Response headers exposed to front-end code.
            h.Set("Access-Control-Expose-Headers", "X-Total-Count, X-Request-ID")
        }

        // Answer preflight requests directly.
        if r.Method == http.MethodOptions {
            w.WriteHeader(http.StatusNoContent)
            return
        }

        next.ServeHTTP(w, r)
    })
}

// main serves the API on :8080 with the CORS middleware wrapped around the mux.
func main() {
    mux := http.NewServeMux()
    mux.HandleFunc("/api/users", usersHandler)

    // Apply the CORS middleware to every route.
    handler := corsMiddleware(mux)

    // ListenAndServe only returns on failure (e.g. the port is already in
    // use); surface it instead of exiting silently with status 0.
    if err := http.ListenAndServe(":8080", handler); err != nil {
        panic(err)
    }
}

HATEOAS(了解)

HATEOAS

HATEOAS(Hypermedia As The Engine Of Application State)是 REST 架构的一项约束,对应 Richardson 成熟度模型的最高级别(Level 3)。响应中包含指向相关资源的超链接,客户端无需硬编码 URL,而是通过响应中的链接发现可用的操作。

// UserResource is the HATEOAS representation of a user: the user's fields
// plus a "_links" object pointing at related resources.
type UserResource struct {
    ID    int    `json:"id"`
    Name  string `json:"name"`
    Email string `json:"email"`
    Links Links  `json:"_links"`
}

// Link is a single hypermedia link; Method is omitted from the JSON when empty.
type Link struct {
    Href   string `json:"href"`
    Method string `json:"method,omitempty"`
    Rel    string `json:"rel"`
}

// Links groups the links attached to a UserResource. Todos, a slice, is
// omitted from the JSON when empty.
// NOTE(review): `omitempty` has no effect on the struct-typed Users field —
// encoding/json never treats a struct value as empty, so "users" is always
// emitted (as {"href":"","rel":""} when unset). Use *Link, or `omitzero` on
// Go 1.24+, if it should be optional.
type Links struct {
    Self  Link   `json:"self"`
    Users Link   `json:"users,omitempty"`
    Todos []Link `json:"todos,omitempty"`
}

// getUserHATEOAS writes user as a UserResource whose _links point to the user
// itself, the user collection and the user's todos. The request r is unused.
func getUserHATEOAS(w http.ResponseWriter, r *http.Request, user *User) {
    const collection = "/api/v1/users"
    selfHref := fmt.Sprintf(collection+"/%d", user.ID)

    var resource UserResource
    resource.ID, resource.Name, resource.Email = user.ID, user.Name, user.Email
    resource.Links.Self = Link{Href: selfHref, Rel: "self"}
    resource.Links.Users = Link{Href: collection, Rel: "collection"}
    resource.Links.Todos = []Link{{Href: selfHref + "/todos", Rel: "todos"}}

    writeJSON(w, http.StatusOK, resource)
}

// 响应示例:
// {
//   "id": 1,
//   "name": "张三",
//   "email": "zhangsan@example.com",
//   "_links": {
//     "self": { "href": "/api/v1/users/1", "rel": "self" },
//     "users": { "href": "/api/v1/users", "rel": "collection" },
//     "todos": [
//       { "href": "/api/v1/users/1/todos", "rel": "todos" }
//     ]
//   }
// }

练习题

练习 1:实现带参数验证的 Todo API

实现一个 Todo 任务的 RESTful API,包含创建、查询列表、查询单条、更新、删除五个接口。创建和更新时需验证:标题不为空且不超过 100 个字符,优先级只能是 low、medium、high 之一。

参考答案
package main

import (
    "encoding/json"
    "fmt"
    "net/http"
    "strconv"
    "strings"
    "sync"
    "time"
)

// Todo is a task stored by the in-memory Todo API; the json tags set the
// field names used in request and response bodies.
type Todo struct {
    ID          int       `json:"id"`
    Title       string    `json:"title"`
    Description string    `json:"description"`
    Priority    string    `json:"priority"`
    Completed   bool      `json:"completed"`
    CreatedAt   time.Time `json:"created_at"`
    UpdatedAt   time.Time `json:"updated_at"`
}

// In-memory store for the Todo API.
var (
    todos  = make(map[int]Todo) // todos keyed by ID
    nextID = 1                  // ID assigned to the next created todo
    mu     sync.Mutex           // locked around access to todos and nextID
)

// CreateTodoRequest is the JSON body accepted when creating or replacing a todo.
type CreateTodoRequest struct {
    Title       string `json:"title"`
    Description string `json:"description"`
    Priority    string `json:"priority"`
}

// Validate checks the request and returns a map from JSON field name to an
// error message; an empty map means the request is valid.
//
// The title must contain non-whitespace text and be at most 100 characters.
// Length is counted in runes, so a Chinese title may have 100 characters;
// with byte counting it was capped near 33, since each CJK character takes
// 3 bytes in UTF-8. Priority must be exactly "low", "medium" or "high".
func (r *CreateTodoRequest) Validate() map[string]string {
    errs := make(map[string]string)

    switch {
    case strings.TrimSpace(r.Title) == "":
        errs["title"] = "标题不能为空"
    case len([]rune(r.Title)) > 100:
        errs["title"] = "标题不能超过100字符"
    }

    switch r.Priority {
    case "low", "medium", "high":
        // valid
    default:
        errs["priority"] = "优先级必须是 low、medium 或 high"
    }
    return errs
}

func todosHandler(w http.ResponseWriter, r *http.Request) {
    path := strings.TrimPrefix(r.URL.Path, "/api/todos")
    path = strings.TrimPrefix(path, "/")

    var id int
    if path != "" {
        var err error
        id, err = strconv.Atoi(path)
        if err != nil {
            writeError(w, http.StatusBadRequest, "无效的 ID")
            return
        }
    }

    switch r.Method {
    case http.MethodGet:
        if id == 0 {
            // 列表
            var list []Todo
            mu.Lock()
            for _, t := range todos {
                list = append(list, t)
            }
            mu.Unlock()
            writeSuccess(w, list)
        } else {
            mu.Lock()
            t, ok := todos[id]
            mu.Unlock()
            if !ok {
                writeError(w, http.StatusNotFound, "Todo 不存在")
                return
            }
            writeSuccess(w, t)
        }
    case http.MethodPost:
        var req CreateTodoRequest
        json.NewDecoder(r.Body).Decode(&req)
        if errs := req.Validate(); len(errs) > 0 {
            writeJSON(w, http.StatusBadRequest, map[string]interface{}{
                "code": 400, "message": "验证失败", "errors": errs,
            })
            return
        }
        now := time.Now()
        mu.Lock()
        todo := Todo{
            ID: nextID, Title: req.Title,
            Description: req.Description, Priority: req.Priority,
            CreatedAt: now, UpdatedAt: now,
        }
        todos[nextID] = todo
        nextID++
        mu.Unlock()
        writeJSON(w, http.StatusCreated, map[string]interface{}{
            "code": 0, "message": "创建成功", "data": todo,
        })
    case http.MethodPut:
        mu.Lock()
        defer mu.Unlock()
        t, ok := todos[id]
        if !ok {
            writeError(w, http.StatusNotFound, "Todo 不存在")
            return
        }
        var req CreateTodoRequest
        json.NewDecoder(r.Body).Decode(&req)
        if errs := req.Validate(); len(errs) > 0 {
            writeJSON(w, http.StatusBadRequest, map[string]interface{}{
                "code": 400, "message": "验证失败", "errors": errs,
            })
            return
        }
        t.Title = req.Title
        t.Description = req.Description
        t.Priority = req.Priority
        t.UpdatedAt = time.Now()
        todos[id] = t
        writeSuccess(w, t)
    case http.MethodDelete:
        mu.Lock()
        defer mu.Unlock()
        if _, ok := todos[id]; !ok {
            writeError(w, http.StatusNotFound, "Todo 不存在")
            return
        }
        delete(todos, id)
        writeSuccess(w, map[string]string{"message": "删除成功"})
    }
}

练习 2:实现分页 + 过滤 + 排序的用户列表 API

实现 GET /api/v1/users 接口,支持以下查询参数:

  • page(页码,默认1)、page_size(每页数量,默认20,最大100)
  • role(过滤:按角色筛选)
  • status(过滤:按状态筛选)
  • sort_by(排序字段,默认 created_at)
  • sort_order(排序方向:asc / desc,默认 desc)

返回统一分页格式。

参考答案
package main

import (
    "encoding/json"
    "fmt"
    "net/http"
    "sort"
    "strconv"
    "strings"
    "time"
)

// User is the record served by the paginated, filterable user-list example.
type User struct {
    ID        int       `json:"id"`
    Name      string    `json:"name"`
    Email     string    `json:"email"`
    Role      string    `json:"role"`
    Status    string    `json:"status"`
    CreatedAt time.Time `json:"created_at"`
}

// allUsers is a fixed in-memory data set standing in for a database table.
// The CreatedAt values are computed from time.Now() when the package-level
// variable is initialized.
var allUsers = []User{
    {ID: 1, Name: "张三", Email: "zhangsan@example.com", Role: "admin", Status: "active", CreatedAt: time.Now().Add(-48 * time.Hour)},
    {ID: 2, Name: "李四", Email: "lisi@example.com", Role: "user", Status: "active", CreatedAt: time.Now().Add(-24 * time.Hour)},
    {ID: 3, Name: "王五", Email: "wangwu@example.com", Role: "user", Status: "inactive", CreatedAt: time.Now()},
}

// allowedSortFields is the allowlist of accepted sort_by values; restricting
// sort_by to known fields keeps arbitrary client input out of the sort logic.
var allowedSortFields = map[string]bool{
    "id": true, "name": true, "created_at": true, "role": true,
}

// listUsersHandler implements GET /api/v1/users with paging, filtering and
// sorting, and responds with the standard envelope
// {"code":0,"message":"success","data":{"items":[...],"page":{...}}}.
//
// Query parameters:
//   - page (default 1) and page_size (default 20, maximum 100)
//   - role, status: exact-match filters, ignored when empty
//   - sort_by: one of allowedSortFields, default created_at
//   - sort_order: "asc" or anything else for descending (default desc)
func listUsersHandler(w http.ResponseWriter, r *http.Request) {
    q := r.URL.Query()

    // Paging parameters; parse errors yield 0 and fall back to the defaults.
    page, _ := strconv.Atoi(q.Get("page"))
    pageSize, _ := strconv.Atoi(q.Get("page_size"))
    if page < 1 {
        page = 1
    }
    if pageSize < 1 || pageSize > 100 {
        pageSize = 20
    }

    // Filtering.
    role := q.Get("role")
    status := q.Get("status")
    filtered := make([]User, 0, len(allUsers))
    for _, u := range allUsers {
        if role != "" && u.Role != role {
            continue
        }
        if status != "" && u.Status != status {
            continue
        }
        filtered = append(filtered, u)
    }

    // Sorting.
    sortBy := q.Get("sort_by")
    if !allowedSortFields[sortBy] {
        sortBy = "created_at"
    }
    desc := strings.ToLower(q.Get("sort_order")) != "asc"

    less := func(a, b User) bool {
        switch sortBy {
        case "id":
            return a.ID < b.ID
        case "name":
            return a.Name < b.Name
        case "role":
            // "role" is allowlisted; without this case it silently sorted by created_at.
            return a.Role < b.Role
        default: // "created_at"
            return a.CreatedAt.Before(b.CreatedAt)
        }
    }
    // Descending order swaps the operands instead of negating the result:
    // !less(a, b) is true for equal keys, which breaks the strict ordering
    // the sort package requires. The stable sort keeps ties in input order.
    sort.SliceStable(filtered, func(i, j int) bool {
        if desc {
            return less(filtered[j], filtered[i])
        }
        return less(filtered[i], filtered[j])
    })

    // Paging window, clamped without multiplying first: page is
    // client-controlled and (page-1)*pageSize can overflow to a negative
    // offset, which would make the slice expression panic.
    total := len(filtered)
    offset := total
    if page-1 <= total/pageSize {
        offset = (page - 1) * pageSize
    }
    end := total
    if pageSize < total-offset {
        end = offset + pageSize
    }
    items := filtered[offset:end]

    writeJSON(w, http.StatusOK, map[string]interface{}{
        "code": 0, "message": "success",
        "data": map[string]interface{}{
            "items": items,
            "page": map[string]interface{}{
                "page": page, "page_size": pageSize, "total": int64(total),
            },
        },
    })
}

练习 3:实现 CORS 中间件并配置允许特定域名

实现一个 CORS 中间件,支持通过环境变量或配置项配置允许的域名列表。对于非预检请求,验证 Origin 是否在允许列表中,不在则返回 403。

参考答案
package main

import (
    "log"
    "net/http"
    "os"
    "strings"
)

// NewCORSMiddleware returns middleware that enforces a CORS origin allowlist.
//
// Entries of allowedOrigins are trimmed and blank entries are ignored. For
// example, strings.Split("", ",") yields [""]; such an entry must not make
// requests without an Origin header look "allowed" and receive an empty
// Access-Control-Allow-Origin header.
//
// Behavior:
//   - OPTIONS (preflight) requests are answered with 204; CORS headers are
//     added only for allowed origins, so the browser rejects the others.
//   - Other requests whose Origin is present but not allowed get 403 with a
//     JSON error body.
//   - Requests without an Origin header (same-origin or non-browser clients)
//     pass through without CORS headers.
//
// "Vary: Origin" is always set because the response depends on the Origin
// header. Without it, a shared cache could replay one origin's
// Access-Control-Allow-Origin to another.
func NewCORSMiddleware(allowedOrigins []string) func(http.Handler) http.Handler {
    originsSet := make(map[string]bool, len(allowedOrigins))
    for _, o := range allowedOrigins {
        if o = strings.TrimSpace(o); o != "" {
            originsSet[o] = true
        }
    }

    return func(next http.Handler) http.Handler {
        return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
            origin := r.Header.Get("Origin")
            allowed := originsSet[origin] // never true for "" (blank entries are skipped)
            h := w.Header()
            h.Add("Vary", "Origin")

            // Preflight handling.
            if r.Method == http.MethodOptions {
                if allowed {
                    h.Set("Access-Control-Allow-Origin", origin)
                    h.Set("Access-Control-Allow-Methods",
                        "GET, POST, PUT, PATCH, DELETE, OPTIONS")
                    h.Set("Access-Control-Allow-Headers",
                        "Content-Type, Authorization")
                    h.Set("Access-Control-Max-Age", "3600")
                    h.Set("Access-Control-Allow-Credentials", "true")
                }
                w.WriteHeader(http.StatusNoContent)
                return
            }

            // Non-preflight request from a browser origin that is not allowed.
            if origin != "" && !allowed {
                // The body is JSON, so label it as such (http.Error would send text/plain).
                h.Set("Content-Type", "application/json; charset=utf-8")
                h.Set("X-Content-Type-Options", "nosniff")
                w.WriteHeader(http.StatusForbidden)
                _, _ = w.Write([]byte(`{"code":403,"message":"Origin not allowed"}` + "\n"))
                return
            }

            if allowed {
                h.Set("Access-Control-Allow-Origin", origin)
                h.Set("Access-Control-Allow-Credentials", "true")
                h.Set("Access-Control-Expose-Headers",
                    "X-Total-Count, X-Request-ID")
            }

            next.ServeHTTP(w, r)
        })
    }
}

// main reads the allowed CORS origins from CORS_ALLOWED_ORIGINS (a
// comma-separated list). It falls back to http://localhost:3000 when no
// usable origin is configured, then serves the API on :8080.
func main() {
    // strings.Split("", ",") returns [""] (length 1), so checking len() == 0
    // on the raw split never applied the default. Keep only non-blank entries.
    var allowedOrigins []string
    for _, o := range strings.Split(os.Getenv("CORS_ALLOWED_ORIGINS"), ",") {
        if o = strings.TrimSpace(o); o != "" {
            allowedOrigins = append(allowedOrigins, o)
        }
    }
    if len(allowedOrigins) == 0 {
        allowedOrigins = []string{"http://localhost:3000"}
    }

    cors := NewCORSMiddleware(allowedOrigins)
    mux := http.NewServeMux()
    mux.HandleFunc("/api/users", usersHandler)

    log.Println("Server running on :8080")
    // ListenAndServe only returns on failure; report it and exit non-zero.
    log.Fatal(http.ListenAndServe(":8080", cors(mux)))
}

搜索