标签 即时消息 下的文章

本文是该系列的第四篇。

在这篇文章中,我们将对端点进行编码,以创建一条消息并列出它们,同时还将编写一个端点以更新参与者上次阅读消息的时间。 首先在 main() 函数中添加这些路由。

router.HandleFunc("POST", "/api/conversations/:conversationID/messages", requireJSON(guard(createMessage)))
router.HandleFunc("GET", "/api/conversations/:conversationID/messages", guard(getMessages))
router.HandleFunc("POST", "/api/conversations/:conversationID/read_messages", guard(readMessages))

消息会进入对话,因此端点包含对话 ID。

创建消息

该端点处理对 /api/conversations/{conversationID}/messages 的 POST 请求,其 JSON 主体仅包含消息内容,并返回新创建的消息。它有两个副作用:更新对话 last_message_id 以及更新参与者 messages_read_at

func createMessage(w http.ResponseWriter, r *http.Request) {
    var input struct {
        Content string `json:"content"`
    }
    defer r.Body.Close()
    if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }

    errs := make(map[string]string)
    input.Content = removeSpaces(input.Content)
    if input.Content == "" {
        errs["content"] = "Message content required"
    } else if len([]rune(input.Content)) > 480 {
        errs["content"] = "Message too long. 480 max"
    }
    if len(errs) != 0 {
        respond(w, Errors{errs}, http.StatusUnprocessableEntity)
        return
    }

    ctx := r.Context()
    authUserID := ctx.Value(keyAuthUserID).(string)
    conversationID := way.Param(ctx, "conversationID")

    tx, err := db.BeginTx(ctx, nil)
    if err != nil {
        respondError(w, fmt.Errorf("could not begin tx: %v", err))
        return
    }
    defer tx.Rollback()

    isParticipant, err := queryParticipantExistance(ctx, tx, authUserID, conversationID)
    if err != nil {
        respondError(w, fmt.Errorf("could not query participant existance: %v", err))
        return
    }

    if !isParticipant {
        http.Error(w, "Conversation not found", http.StatusNotFound)
        return
    }

    var message Message
    if err := tx.QueryRowContext(ctx, `
        INSERT INTO messages (content, user_id, conversation_id) VALUES
            ($1, $2, $3)
        RETURNING id, created_at
    `, input.Content, authUserID, conversationID).Scan(
        &message.ID,
        &message.CreatedAt,
    ); err != nil {
        respondError(w, fmt.Errorf("could not insert message: %v", err))
        return
    }

    if _, err := tx.ExecContext(ctx, `
        UPDATE conversations SET last_message_id = $1
        WHERE id = $2
    `, message.ID, conversationID); err != nil {
        respondError(w, fmt.Errorf("could not update conversation last message ID: %v", err))
        return
    }

    if err = tx.Commit(); err != nil {
        respondError(w, fmt.Errorf("could not commit tx to create a message: %v", err))
        return
    }

    go func() {
        if err = updateMessagesReadAt(nil, authUserID, conversationID); err != nil {
            log.Printf("could not update messages read at: %v\n", err)
        }
    }()

    message.Content = input.Content
    message.UserID = authUserID
    message.ConversationID = conversationID
    // TODO: notify about new message.
    message.Mine = true

    respond(w, message, http.StatusCreated)
}

首先,它将请求正文解码为包含消息内容的结构。然后,它验证内容不为空并且少于 480 个字符。

var rxSpaces = regexp.MustCompile("\\s+")

func removeSpaces(s string) string {
    if s == "" {
        return s
    }

    lines := make([]string, 0)
    for _, line := range strings.Split(s, "\n") {
        line = rxSpaces.ReplaceAllLiteralString(line, " ")
        line = strings.TrimSpace(line)
        if line != "" {
            lines = append(lines, line)
        }
    }
    return strings.Join(lines, "\n")
}

这是删除空格的函数。它遍历每一行,删除两个以上的连续空格,然后回非空行。

验证之后,它将启动一个 SQL 事务。首先,它查询对话中的参与者是否存在。

func queryParticipantExistance(ctx context.Context, tx *sql.Tx, userID, conversationID string) (bool, error) {
    if ctx == nil {
        ctx = context.Background()
    }
    var exists bool
    if err := tx.QueryRowContext(ctx, `SELECT EXISTS (
        SELECT 1 FROM participants
        WHERE user_id = $1 AND conversation_id = $2
    )`, userID, conversationID).Scan(&exists); err != nil {
        return false, err
    }
    return exists, nil
}

我将其提取到一个函数中,因为稍后可以重用。

如果用户不是对话参与者,我们将返回一个 404 NOT Found 错误。

然后,它插入消息并更新对话 last_message_id。从这时起,由于我们不允许删除消息,因此 last_message_id 不能为 NULL

接下来提交事务,并在 goroutine 中更新参与者 messages_read_at

func updateMessagesReadAt(ctx context.Context, userID, conversationID string) error {
    if ctx == nil {
        ctx = context.Background()
    }

    if _, err := db.ExecContext(ctx, `
        UPDATE participants SET messages_read_at = now()
        WHERE user_id = $1 AND conversation_id = $2
    `, userID, conversationID); err != nil {
        return err
    }
    return nil
}

在回复这条新消息之前,我们必须通知一下。这是我们将要在下一篇文章中编写的实时部分,因此我在那里留一了个注释。

获取消息

这个端点处理对 /api/conversations/{conversationID}/messages 的 GET 请求。 它用一个包含会话中所有消息的 JSON 数组进行响应。它还具有更新参与者 messages_read_at 的副作用。

func getMessages(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    authUserID := ctx.Value(keyAuthUserID).(string)
    conversationID := way.Param(ctx, "conversationID")

    tx, err := db.BeginTx(ctx, &sql.TxOptions{ReadOnly: true})
    if err != nil {
        respondError(w, fmt.Errorf("could not begin tx: %v", err))
        return
    }
    defer tx.Rollback()

    isParticipant, err := queryParticipantExistance(ctx, tx, authUserID, conversationID)
    if err != nil {
        respondError(w, fmt.Errorf("could not query participant existance: %v", err))
        return
    }

    if !isParticipant {
        http.Error(w, "Conversation not found", http.StatusNotFound)
        return
    }

    rows, err := tx.QueryContext(ctx, `
        SELECT
            id,
            content,
            created_at,
            user_id = $1 AS mine
        FROM messages
        WHERE messages.conversation_id = $2
        ORDER BY messages.created_at DESC
    `, authUserID, conversationID)
    if err != nil {
        respondError(w, fmt.Errorf("could not query messages: %v", err))
        return
    }
    defer rows.Close()

    messages := make([]Message, 0)
    for rows.Next() {
        var message Message
        if err = rows.Scan(
            &message.ID,
            &message.Content,
            &message.CreatedAt,
            &message.Mine,
        ); err != nil {
            respondError(w, fmt.Errorf("could not scan message: %v", err))
            return
        }

        messages = append(messages, message)
    }

    if err = rows.Err(); err != nil {
        respondError(w, fmt.Errorf("could not iterate over messages: %v", err))
        return
    }

    if err = tx.Commit(); err != nil {
        respondError(w, fmt.Errorf("could not commit tx to get messages: %v", err))
        return
    }

    go func() {
        if err = updateMessagesReadAt(nil, authUserID, conversationID); err != nil {
            log.Printf("could not update messages read at: %v\n", err)
        }
    }()

    respond(w, messages, http.StatusOK)
}

首先,它以只读模式开始一个 SQL 事务。检查参与者是否存在,并查询所有消息。在每条消息中,我们使用当前经过身份验证的用户 ID 来了解用户是否拥有该消息(mine)。 然后,它提交事务,在 goroutine 中更新参与者 messages_read_at 并以消息响应。

读取消息

该端点处理对 /api/conversations/{conversationID}/read_messages 的 POST 请求。 没有任何请求或响应主体。 在前端,每次有新消息到达实时流时,我们都会发出此请求。

func readMessages(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    authUserID := ctx.Value(keyAuthUserID).(string)
    conversationID := way.Param(ctx, "conversationID")

    if err := updateMessagesReadAt(ctx, authUserID, conversationID); err != nil {
        respondError(w, fmt.Errorf("could not update messages read at: %v", err))
        return
    }

    w.WriteHeader(http.StatusNoContent)
}

它使用了与更新参与者 messages_read_at 相同的函数。


到此为止。实时消息是后台仅剩的部分了。请等待下一篇文章。


via: https://nicolasparada.netlify.com/posts/go-messenger-messages/

作者:Nicolás Parada 选题:lujun9972 译者:gxlct008 校对:wxy

本文由 LCTT 原创编译,Linux中国 荣誉推出

本文是该系列的第三篇。

在我们的即时消息应用中,消息表现为两个参与者对话的堆叠。如果你想要开始一场对话,就应该向应用提供你想要交谈的用户,而当对话创建后(如果该对话此前并不存在),就可以向该对话发送消息。

就前端而言,我们可能想要显示一份近期对话列表。并在此处显示对话的最后一条消息以及另一个参与者的姓名和头像。

在这篇帖子中,我们将会编写一些 端点 endpoint 来完成像“创建对话”、“获取对话列表”以及“找到单个对话”这样的任务。

首先,要在主函数 main() 中添加下面的路由。

router.HandleFunc("POST", "/api/conversations", requireJSON(guard(createConversation)))
router.HandleFunc("GET", "/api/conversations", guard(getConversations))
router.HandleFunc("GET", "/api/conversations/:conversationID", guard(getConversation))

这三个端点都需要进行身份验证,所以我们将会使用 guard() 中间件。我们也会构建一个新的中间件,用于检查请求内容是否为 JSON 格式。

JSON 请求检查中间件

func requireJSON(handler http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        if ct := r.Header.Get("Content-Type"); !strings.HasPrefix(ct, "application/json") {
            http.Error(w, "Content type of application/json required", http.StatusUnsupportedMediaType)
            return
        }
        handler(w, r)
    }
}

如果 请求 request 不是 JSON 格式,那么它会返回 415 Unsupported Media Type(不支持的媒体类型)错误。

创建对话

type Conversation struct {
    ID                string   `json:"id"`
    OtherParticipant  *User    `json:"otherParticipant"`
    LastMessage       *Message `json:"lastMessage"`
    HasUnreadMessages bool     `json:"hasUnreadMessages"`
}

就像上面的代码那样,对话中保持对另一个参与者和最后一条消息的引用,还有一个 bool 类型的字段,用来告知是否有未读消息。

type Message struct {
    ID             string    `json:"id"`
    Content        string    `json:"content"`
    UserID         string    `json:"-"`
    ConversationID string    `json:"conversationID,omitempty"`
    CreatedAt      time.Time `json:"createdAt"`
    Mine           bool      `json:"mine"`
    ReceiverID     string    `json:"-"`
}

我们会在下一篇文章介绍与消息相关的内容,但由于我们这里也需要用到它,所以先定义了 Message 结构体。其中大多数字段与数据库表一致。我们需要使用 Mine 来断定消息是否属于当前已验证用户所有。一旦加入实时功能,ReceiverID 可以帮助我们过滤消息。

接下来让我们编写 HTTP 处理程序。尽管它有些长,但也没什么好怕的。

func createConversation(w http.ResponseWriter, r *http.Request) {
    var input struct {
        Username string `json:"username"`
    }
    defer r.Body.Close()
    if err := json.NewDecoder(r.Body).Decode(&input); err != nil {
        http.Error(w, err.Error(), http.StatusBadRequest)
        return
    }

    input.Username = strings.TrimSpace(input.Username)
    if input.Username == "" {
        respond(w, Errors{map[string]string{
            "username": "Username required",
        }}, http.StatusUnprocessableEntity)
        return
    }

    ctx := r.Context()
    authUserID := ctx.Value(keyAuthUserID).(string)

    tx, err := db.BeginTx(ctx, nil)
    if err != nil {
        respondError(w, fmt.Errorf("could not begin tx: %v", err))
        return
    }
    defer tx.Rollback()

    var otherParticipant User
    if err := tx.QueryRowContext(ctx, `
        SELECT id, avatar_url FROM users WHERE username = $1
    `, input.Username).Scan(
        &otherParticipant.ID,
        &otherParticipant.AvatarURL,
    ); err == sql.ErrNoRows {
        http.Error(w, "User not found", http.StatusNotFound)
        return
    } else if err != nil {
        respondError(w, fmt.Errorf("could not query other participant: %v", err))
        return
    }

    otherParticipant.Username = input.Username

    if otherParticipant.ID == authUserID {
        http.Error(w, "Try start a conversation with someone else", http.StatusForbidden)
        return
    }

    var conversationID string
    if err := tx.QueryRowContext(ctx, `
        SELECT conversation_id FROM participants WHERE user_id = $1
        INTERSECT
        SELECT conversation_id FROM participants WHERE user_id = $2
    `, authUserID, otherParticipant.ID).Scan(&conversationID); err != nil && err != sql.ErrNoRows {
        respondError(w, fmt.Errorf("could not query common conversation id: %v", err))
        return
    } else if err == nil {
        http.Redirect(w, r, "/api/conversations/"+conversationID, http.StatusFound)
        return
    }

    var conversation Conversation
    if err = tx.QueryRowContext(ctx, `
        INSERT INTO conversations DEFAULT VALUES
        RETURNING id
    `).Scan(&conversation.ID); err != nil {
        respondError(w, fmt.Errorf("could not insert conversation: %v", err))
        return
    }

    if _, err = tx.ExecContext(ctx, `
        INSERT INTO participants (user_id, conversation_id) VALUES
            ($1, $2),
            ($3, $2)
    `, authUserID, conversation.ID, otherParticipant.ID); err != nil {
        respondError(w, fmt.Errorf("could not insert participants: %v", err))
        return
    }

    if err = tx.Commit(); err != nil {
        respondError(w, fmt.Errorf("could not commit tx to create conversation: %v", err))
        return
    }

    conversation.OtherParticipant = &otherParticipant

    respond(w, conversation, http.StatusCreated)
}

在此端点,你会向 /api/conversations 发送 POST 请求,请求的 JSON 主体中包含要对话的用户的用户名。

因此,首先需要将请求主体解析成包含用户名的结构。然后,校验用户名不能为空。

type Errors struct {
    Errors map[string]string `json:"errors"`
}

这是错误消息的结构体 Errors,它仅仅是一个映射。如果输入空用户名,你就会得到一段带有 422 Unprocessable Entity(无法处理的实体)错误消息的 JSON 。

{
    "errors": {
        "username": "Username required"
    }
}

然后,我们开始执行 SQL 事务。收到的仅仅是用户名,但事实上,我们需要知道实际的用户 ID 。因此,事务的第一项内容是查询另一个参与者的 ID 和头像。如果找不到该用户,我们将会返回 404 Not Found(未找到) 错误。另外,如果找到的用户恰好和“当前已验证用户”相同,我们应该返回 403 Forbidden(拒绝处理)错误。这是由于对话只应当在两个不同的用户之间发起,而不能是同一个。

然后,我们试图找到这两个用户所共有的对话,所以需要使用 INTERSECT 语句。如果存在,只需要通过 /api/conversations/{conversationID} 重定向到该对话并将其返回。

如果未找到共有的对话,我们需要创建一个新的对话并添加指定的两个参与者。最后,我们 COMMIT 该事务并使用新创建的对话进行响应。

获取对话列表

端点 /api/conversations 将获取当前已验证用户的所有对话。

func getConversations(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    authUserID := ctx.Value(keyAuthUserID).(string)

    rows, err := db.QueryContext(ctx, `
        SELECT
            conversations.id,
            auth_user.messages_read_at < messages.created_at AS has_unread_messages,
            messages.id,
            messages.content,
            messages.created_at,
            messages.user_id = $1 AS mine,
            other_users.id,
            other_users.username,
            other_users.avatar_url
        FROM conversations
        INNER JOIN messages ON conversations.last_message_id = messages.id
        INNER JOIN participants other_participants
            ON other_participants.conversation_id = conversations.id
                AND other_participants.user_id != $1
        INNER JOIN users other_users ON other_participants.user_id = other_users.id
        INNER JOIN participants auth_user
            ON auth_user.conversation_id = conversations.id
                AND auth_user.user_id = $1
        ORDER BY messages.created_at DESC
    `, authUserID)
    if err != nil {
        respondError(w, fmt.Errorf("could not query conversations: %v", err))
        return
    }
    defer rows.Close()

    conversations := make([]Conversation, 0)
    for rows.Next() {
        var conversation Conversation
        var lastMessage Message
        var otherParticipant User
        if err = rows.Scan(
            &conversation.ID,
            &conversation.HasUnreadMessages,
            &lastMessage.ID,
            &lastMessage.Content,
            &lastMessage.CreatedAt,
            &lastMessage.Mine,
            &otherParticipant.ID,
            &otherParticipant.Username,
            &otherParticipant.AvatarURL,
        ); err != nil {
            respondError(w, fmt.Errorf("could not scan conversation: %v", err))
            return
        }

        conversation.LastMessage = &lastMessage
        conversation.OtherParticipant = &otherParticipant
        conversations = append(conversations, conversation)
    }

    if err = rows.Err(); err != nil {
        respondError(w, fmt.Errorf("could not iterate over conversations: %v", err))
        return
    }

    respond(w, conversations, http.StatusOK)
}

该处理程序仅对数据库进行查询。它通过一些联接来查询对话表……首先,从消息表中获取最后一条消息。然后依据“ID 与当前已验证用户不同”的条件,从参与者表找到对话的另一个参与者。然后联接到用户表以获取该用户的用户名和头像。最后,再次联接参与者表,并以相反的条件从该表中找出参与对话的另一个用户,其实就是当前已验证用户。我们会对比消息中的 messages_read_atcreated_at 两个字段,以确定对话中是否存在未读消息。然后,我们通过 user_id 字段来判定该消息是否属于“我”(指当前已验证用户)。

注意,此查询过程假定对话中只有两个用户参与,它也仅仅适用于这种情况。另外,该设计也不很适用于需要显示未读消息数量的情况。如果需要显示未读消息的数量,我认为可以在 participants 表上添加一个unread_messages_count INT 字段,并在每次创建新消息的时候递增它,如果用户已读则重置该字段。

接下来需要遍历每一条记录,通过扫描每一个存在的对话来建立一个 对话切片 slice of conversations 并在最后进行响应。

找到单个对话

端点 /api/conversations/{conversationID} 会根据 ID 对单个对话进行响应。

func getConversation(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    authUserID := ctx.Value(keyAuthUserID).(string)
    conversationID := way.Param(ctx, "conversationID")

    var conversation Conversation
    var otherParticipant User
    if err := db.QueryRowContext(ctx, `
        SELECT
            IFNULL(auth_user.messages_read_at < messages.created_at, false) AS has_unread_messages,
            other_users.id,
            other_users.username,
            other_users.avatar_url
        FROM conversations
        LEFT JOIN messages ON conversations.last_message_id = messages.id
        INNER JOIN participants other_participants
            ON other_participants.conversation_id = conversations.id
                AND other_participants.user_id != $1
        INNER JOIN users other_users ON other_participants.user_id = other_users.id
        INNER JOIN participants auth_user
            ON auth_user.conversation_id = conversations.id
                AND auth_user.user_id = $1
        WHERE conversations.id = $2
    `, authUserID, conversationID).Scan(
        &conversation.HasUnreadMessages,
        &otherParticipant.ID,
        &otherParticipant.Username,
        &otherParticipant.AvatarURL,
    ); err == sql.ErrNoRows {
        http.Error(w, "Conversation not found", http.StatusNotFound)
        return
    } else if err != nil {
        respondError(w, fmt.Errorf("could not query conversation: %v", err))
        return
    }

    conversation.ID = conversationID
    conversation.OtherParticipant = &otherParticipant

    respond(w, conversation, http.StatusOK)
}

这里的查询与之前有点类似。尽管我们并不关心最后一条消息的显示问题,并因此忽略了与之相关的一些字段,但是我们需要根据这条消息来判断对话中是否存在未读消息。此时,我们使用 LEFT JOIN 来代替 INNER JOIN,因为 last_message_id 字段是 NULLABLE(可以为空)的;而其他情况下,我们无法得到任何记录。基于同样的理由,我们在 has_unread_messages 的比较中使用了 IFNULL 语句。最后,我们按 ID 进行过滤。

如果查询没有返回任何记录,我们的响应会返回 404 Not Found 错误,否则响应将会返回 200 OK 以及找到的对话。


本篇帖子以创建了一些对话端点结束。

在下一篇帖子中,我们将会看到如何创建并列出消息。


via: https://nicolasparada.netlify.com/posts/go-messenger-conversations/

作者:Nicolás Parada 选题:lujun9972 译者:PsiACE 校对:wxy

本文由 LCTT 原创编译,Linux中国 荣誉推出

上一篇:模式

在这篇帖子中,我们将会通过为应用添加社交登录功能进入后端开发。

社交登录的工作方式十分简单:用户点击链接,然后重定向到 GitHub 授权页面。当用户授予我们对他的个人信息的访问权限之后,就会重定向回登录页面。下一次尝试登录时,系统将不会再次请求授权,也就是说,我们的应用已经记住了这个用户。这使得整个登录流程看起来就和你用鼠标单击一样快。

如果进一步考虑其内部实现的话,过程就会变得复杂起来。首先,我们需要注册一个新的 GitHub OAuth 应用

这一步中,比较重要的是回调 URL。我们将它设置为 http://localhost:3000/api/oauth/github/callback。这是因为,在开发过程中,我们总是在本地主机上工作。一旦你要将应用交付生产,请使用正确的回调 URL 注册一个新的应用。

注册以后,你将会收到“客户端 id”和“安全密钥”。安全起见,请不要与任何人分享他们 ?

顺便让我们开始写一些代码吧。现在,创建一个 main.go 文件:

package main

import (
    "database/sql"
    "fmt"
    "log"
    "net/http"
    "net/url"
    "os"
    "strconv"

    "github.com/gorilla/securecookie"
    "github.com/joho/godotenv"
    "github.com/knq/jwt"
    _ "github.com/lib/pq"
    "github.com/matryer/way"
    "golang.org/x/oauth2"
    "golang.org/x/oauth2/github"
)

var origin *url.URL
var db *sql.DB
var githubOAuthConfig *oauth2.Config
var cookieSigner *securecookie.SecureCookie
var jwtSigner jwt.Signer

func main() {
    godotenv.Load()

    port := intEnv("PORT", 3000)
    originString := env("ORIGIN", fmt.Sprintf("http://localhost:%d/", port))
    databaseURL := env("DATABASE_URL", "postgresql://[email protected]:26257/messenger?sslmode=disable")
    githubClientID := os.Getenv("GITHUB_CLIENT_ID")
    githubClientSecret := os.Getenv("GITHUB_CLIENT_SECRET")
    hashKey := env("HASH_KEY", "secret")
    jwtKey := env("JWT_KEY", "secret")

    var err error
    if origin, err = url.Parse(originString); err != nil || !origin.IsAbs() {
        log.Fatal("invalid origin")
        return
    }

    if i, err := strconv.Atoi(origin.Port()); err == nil {
        port = i
    }

    if githubClientID == "" || githubClientSecret == "" {
        log.Fatalf("remember to set both $GITHUB_CLIENT_ID and $GITHUB_CLIENT_SECRET")
        return
    }

    if db, err = sql.Open("postgres", databaseURL); err != nil {
        log.Fatalf("could not open database connection: %v\n", err)
        return
    }
    defer db.Close()
    if err = db.Ping(); err != nil {
        log.Fatalf("could not ping to db: %v\n", err)
        return
    }

    githubRedirectURL := *origin
    githubRedirectURL.Path = "/api/oauth/github/callback"
    githubOAuthConfig = &oauth2.Config{
        ClientID:     githubClientID,
        ClientSecret: githubClientSecret,
        Endpoint:     github.Endpoint,
        RedirectURL:  githubRedirectURL.String(),
        Scopes:       []string{"read:user"},
    }

    cookieSigner = securecookie.New([]byte(hashKey), nil).MaxAge(0)

    jwtSigner, err = jwt.HS256.New([]byte(jwtKey))
    if err != nil {
        log.Fatalf("could not create JWT signer: %v\n", err)
        return
    }

    router := way.NewRouter()
    router.HandleFunc("GET", "/api/oauth/github", githubOAuthStart)
    router.HandleFunc("GET", "/api/oauth/github/callback", githubOAuthCallback)
    router.HandleFunc("GET", "/api/auth_user", guard(getAuthUser))

    log.Printf("accepting connections on port %d\n", port)
    log.Printf("starting server at %s\n", origin.String())
    addr := fmt.Sprintf(":%d", port)
    if err = http.ListenAndServe(addr, router); err != nil {
        log.Fatalf("could not start server: %v\n", err)
    }
}

func env(key, fallbackValue string) string {
    v, ok := os.LookupEnv(key)
    if !ok {
        return fallbackValue
    }
    return v
}

func intEnv(key string, fallbackValue int) int {
    v, ok := os.LookupEnv(key)
    if !ok {
        return fallbackValue
    }
    i, err := strconv.Atoi(v)
    if err != nil {
        return fallbackValue
    }
    return i
}

安装依赖项:

go get -u github.com/gorilla/securecookie
go get -u github.com/joho/godotenv
go get -u github.com/knq/jwt
go get -u github.com/lib/pq
ge get -u github.com/matoous/go-nanoid
go get -u github.com/matryer/way
go get -u golang.org/x/oauth2

我们将会使用 .env 文件来保存密钥和其他配置。请创建这个文件,并保证里面至少包含以下内容:

GITHUB_CLIENT_ID=your_github_client_id
GITHUB_CLIENT_SECRET=your_github_client_secret

我们还要用到的其他环境变量有:

  • PORT:服务器运行的端口,默认值是 3000
  • ORIGIN:你的域名,默认值是 http://localhost:3000/。我们也可以在这里指定端口。
  • DATABASE_URL:Cockroach 数据库的地址。默认值是 postgresql://[email protected]:26257/messenger?sslmode=disable
  • HASH_KEY:用于为 cookie 签名的密钥。没错,我们会使用已签名的 cookie 来确保安全。
  • JWT_KEY:用于签署 JSON 网络令牌 Web Token 的密钥。

因为代码中已经设定了默认值,所以你也不用把它们写到 .env 文件中。

在读取配置并连接到数据库之后,我们会创建一个 OAuth 配置。我们会使用 ORIGIN 信息来构建回调 URL(就和我们在 GitHub 页面上注册的一样)。我们的数据范围设置为 “read:user”。这会允许我们读取公开的用户信息,这里我们只需要他的用户名和头像就够了。然后我们会初始化 cookie 和 JWT 签名器。定义一些端点并启动服务器。

在实现 HTTP 处理程序之前,让我们编写一些函数来发送 HTTP 响应。

func respond(w http.ResponseWriter, v interface{}, statusCode int) {
    b, err := json.Marshal(v)
    if err != nil {
        respondError(w, fmt.Errorf("could not marshal response: %v", err))
        return
    }
    w.Header().Set("Content-Type", "application/json; charset=utf-8")
    w.WriteHeader(statusCode)
    w.Write(b)
}

func respondError(w http.ResponseWriter, err error) {
    log.Println(err)
    http.Error(w, http.StatusText(http.StatusInternalServerError), http.StatusInternalServerError)
}

第一个函数用来发送 JSON,而第二个将错误记录到控制台并返回一个 500 Internal Server Error 错误信息。

OAuth 开始

所以,用户点击写着 “Access with GitHub” 的链接。该链接指向 /api/oauth/github,这将会把用户重定向到 github。

func githubOAuthStart(w http.ResponseWriter, r *http.Request) {
    state, err := gonanoid.Nanoid()
    if err != nil {
        respondError(w, fmt.Errorf("could not generte state: %v", err))
        return
    }

    stateCookieValue, err := cookieSigner.Encode("state", state)
    if err != nil {
        respondError(w, fmt.Errorf("could not encode state cookie: %v", err))
        return
    }

    http.SetCookie(w, &http.Cookie{
        Name:     "state",
        Value:    stateCookieValue,
        Path:     "/api/oauth/github",
        HttpOnly: true,
    })
    http.Redirect(w, r, githubOAuthConfig.AuthCodeURL(state), http.StatusTemporaryRedirect)
}

OAuth2 使用一种机制来防止 CSRF 攻击,因此它需要一个“状态”(state)。我们使用 Nanoid() 来创建一个随机字符串,并用这个字符串作为状态。我们也把它保存为一个 cookie。

OAuth 回调

一旦用户授权我们访问他的个人信息,他将会被重定向到这个端点。这个 URL 的查询字符串上将会包含状态(state)和授权码(code): /api/oauth/github/callback?state=&code=

const jwtLifetime = time.Hour * 24 * 14

type GithubUser struct {
    ID        int     `json:"id"`
    Login     string  `json:"login"`
    AvatarURL *string `json:"avatar_url,omitempty"`
}

type User struct {
    ID        string  `json:"id"`
    Username  string  `json:"username"`
    AvatarURL *string `json:"avatarUrl"`
}

func githubOAuthCallback(w http.ResponseWriter, r *http.Request) {
    stateCookie, err := r.Cookie("state")
    if err != nil {
        http.Error(w, http.StatusText(http.StatusTeapot), http.StatusTeapot)
        return
    }

    http.SetCookie(w, &http.Cookie{
        Name:     "state",
        Value:    "",
        MaxAge:   -1,
        HttpOnly: true,
    })

    var state string
    if err = cookieSigner.Decode("state", stateCookie.Value, &state); err != nil {
        http.Error(w, http.StatusText(http.StatusTeapot), http.StatusTeapot)
        return
    }

    q := r.URL.Query()

    if state != q.Get("state") {
        http.Error(w, http.StatusText(http.StatusTeapot), http.StatusTeapot)
        return
    }

    ctx := r.Context()

    t, err := githubOAuthConfig.Exchange(ctx, q.Get("code"))
    if err != nil {
        respondError(w, fmt.Errorf("could not fetch github token: %v", err))
        return
    }

    client := githubOAuthConfig.Client(ctx, t)
    resp, err := client.Get("https://api.github.com/user")
    if err != nil {
        respondError(w, fmt.Errorf("could not fetch github user: %v", err))
        return
    }

    var githubUser GithubUser
    if err = json.NewDecoder(resp.Body).Decode(&githubUser); err != nil {
        respondError(w, fmt.Errorf("could not decode github user: %v", err))
        return
    }
    defer resp.Body.Close()

    tx, err := db.BeginTx(ctx, nil)
    if err != nil {
        respondError(w, fmt.Errorf("could not begin tx: %v", err))
        return
    }

    var user User
    if err = tx.QueryRowContext(ctx, `
        SELECT id, username, avatar_url FROM users WHERE github_id = $1
    `, githubUser.ID).Scan(&user.ID, &user.Username, &user.AvatarURL); err == sql.ErrNoRows {
        if err = tx.QueryRowContext(ctx, `
            INSERT INTO users (username, avatar_url, github_id) VALUES ($1, $2, $3)
            RETURNING id
        `, githubUser.Login, githubUser.AvatarURL, githubUser.ID).Scan(&user.ID); err != nil {
            respondError(w, fmt.Errorf("could not insert user: %v", err))
            return
        }
        user.Username = githubUser.Login
        user.AvatarURL = githubUser.AvatarURL
    } else if err != nil {
        respondError(w, fmt.Errorf("could not query user by github ID: %v", err))
        return
    }

    if err = tx.Commit(); err != nil {
        respondError(w, fmt.Errorf("could not commit to finish github oauth: %v", err))
        return
    }

    exp := time.Now().Add(jwtLifetime)
    token, err := jwtSigner.Encode(jwt.Claims{
        Subject:    user.ID,
        Expiration: json.Number(strconv.FormatInt(exp.Unix(), 10)),
    })
    if err != nil {
        respondError(w, fmt.Errorf("could not create token: %v", err))
        return
    }

    expiresAt, _ := exp.MarshalText()

    data := make(url.Values)
    data.Set("token", string(token))
    data.Set("expires_at", string(expiresAt))

    http.Redirect(w, r, "/callback?"+data.Encode(), http.StatusTemporaryRedirect)
}

首先,我们会尝试使用之前保存的状态对 cookie 进行解码。并将其与查询字符串中的状态进行比较。如果它们不匹配,我们会返回一个 418 I'm teapot(未知来源)错误。

接着,我们使用授权码生成一个令牌。这个令牌被用于创建 HTTP 客户端来向 GitHub API 发出请求。所以最终我们会向 https://api.github.com/user 发送一个 GET 请求。这个端点将会以 JSON 格式向我们提供当前经过身份验证的用户信息。我们将会解码这些内容,一并获取用户的 ID、登录名(用户名)和头像 URL。

然后我们将会尝试在数据库上找到具有该 GitHub ID 的用户。如果没有找到,就使用该数据创建一个新的。

之后,对于新创建的用户,我们会发出一个将用户 ID 作为主题(Subject)的 JSON 网络令牌,并使用该令牌重定向到前端,查询字符串中一并包含该令牌的到期日(Expiration)。

这一 Web 应用也会被用在其他帖子,但是重定向的链接会是 /callback?token=&expires_at=。在那里,我们将会利用 JavaScript 从 URL 中获取令牌和到期日,并通过 Authorization 标头中的令牌以 Bearer token_here 的形式对 /api/auth_user 进行 GET 请求,来获取已认证的身份用户并将其保存到 localStorage。

Guard 中间件

为了获取当前已经过身份验证的用户,我们设计了 Guard 中间件。这是因为在接下来的文章中,我们会有很多需要进行身份认证的端点,而中间件将会允许我们共享这一功能。

type ContextKey struct {
    Name string
}

var keyAuthUserID = ContextKey{"auth_user_id"}

func guard(handler http.HandlerFunc) http.HandlerFunc {
    return func(w http.ResponseWriter, r *http.Request) {
        var token string
        if a := r.Header.Get("Authorization"); strings.HasPrefix(a, "Bearer ") {
            token = a[7:]
        } else if t := r.URL.Query().Get("token"); t != "" {
            token = t
        } else {
            http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
            return
        }

        var claims jwt.Claims
        if err := jwtSigner.Decode([]byte(token), &claims); err != nil {
            http.Error(w, http.StatusText(http.StatusUnauthorized), http.StatusUnauthorized)
            return
        }

        ctx := r.Context()
        ctx = context.WithValue(ctx, keyAuthUserID, claims.Subject)

        handler(w, r.WithContext(ctx))
    }
}

首先,我们尝试从 Authorization 标头或者是 URL 查询字符串中的 token 字段中读取令牌。如果没有找到,我们需要返回 401 Unauthorized(未授权)错误。然后我们将会对令牌中的申明进行解码,并使用该主题作为当前已经过身份验证的用户 ID。

现在,我们可以用这一中间件来封装任何需要授权的 http.handlerFunc,并且在处理函数的上下文中保有已经过身份验证的用户 ID。

var guarded = guard(func(w http.ResponseWriter, r *http.Request) {
    authUserID := r.Context().Value(keyAuthUserID).(string)
})

获取认证用户

func getAuthUser(w http.ResponseWriter, r *http.Request) {
    ctx := r.Context()
    authUserID := ctx.Value(keyAuthUserID).(string)

    var user User
    if err := db.QueryRowContext(ctx, `
        SELECT username, avatar_url FROM users WHERE id = $1
    `, authUserID).Scan(&user.Username, &user.AvatarURL); err == sql.ErrNoRows {
        http.Error(w, http.StatusText(http.StatusTeapot), http.StatusTeapot)
        return
    } else if err != nil {
        respondError(w, fmt.Errorf("could not query auth user: %v", err))
        return
    }

    user.ID = authUserID

    respond(w, user, http.StatusOK)
}

我们使用 Guard 中间件来获取当前经过身份认证的用户 ID 并查询数据库。

这一部分涵盖了后端的 OAuth 流程。在下一篇帖子中,我们将会看到如何开始与其他用户的对话。


via: https://nicolasparada.netlify.com/posts/go-messenger-oauth/

作者:Nicolás Parada 选题:lujun9972 译者:PsiACE 校对:wxy

本文由 LCTT 原创编译,Linux中国 荣誉推出