Skip to content

Commit

Permalink
Add the bot's userinfo to the context struct to ensure we have all th…
Browse files Browse the repository at this point in the history
…e necessary information to determine update ownership at runtime
  • Loading branch information
PaulSonOfLars committed Sep 25, 2024
1 parent b45dee8 commit dd7351b
Show file tree
Hide file tree
Showing 6 changed files with 35 additions and 31 deletions.
9 changes: 7 additions & 2 deletions ext/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,6 +10,10 @@ import (
type Context struct {
// gotgbot.Update is inlined so that we can access all fields immediately if necessary.
*gotgbot.Update
// We store the info of the Bot that received this update so we can keep track of update ownership.
// We do NOT store full gotgbot.Bot struct, as that would leak the bot token and bot client information in what
// should be a data-only struct.
Bot gotgbot.User
// Data represents update-local storage.
// This can be used to pass data across handlers - for example, to cache operations relevant to the current update,
// such as admin checks.
Expand All @@ -35,9 +39,9 @@ type Context struct {
EffectiveSender *gotgbot.Sender
}

// NewContext populates a context with the relevant fields from the current update.
// NewContext populates a context with the relevant fields from the current bot and update.
// It takes a data field in the case where custom data needs to be passed.
func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context {
func NewContext(b *gotgbot.Bot, update *gotgbot.Update, data map[string]interface{}) *Context {
var msg *gotgbot.Message
var chat *gotgbot.Chat
var user *gotgbot.User
Expand Down Expand Up @@ -162,6 +166,7 @@ func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context {

return &Context{
Update: update,
Bot: b.User,
Data: data,
EffectiveMessage: msg,
EffectiveChat: chat,
Expand Down
2 changes: 1 addition & 1 deletion ext/dispatcher.go
Original file line number Diff line number Diff line change
Expand Up @@ -268,7 +268,7 @@ func (d *Dispatcher) processRawUpdate(b *gotgbot.Bot, r json.RawMessage) error {
// ProcessUpdate iterates over the list of groups to execute the matching handlers.
// This is also where we recover from any panics that are thrown by user code, to avoid taking down the bot.
func (d *Dispatcher) ProcessUpdate(b *gotgbot.Bot, u *gotgbot.Update, data map[string]interface{}) (err error) {
ctx := NewContext(u, data)
ctx := NewContext(b, u, data)

defer func() {
if r := recover(); r != nil {
Expand Down
2 changes: 1 addition & 1 deletion ext/dispatcher_ext_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -100,7 +100,7 @@ func TestDispatcher(t *testing.T) {
}

t.Log("Processing one update...")
err := d.ProcessUpdate(nil, &gotgbot.Update{
err := d.ProcessUpdate(&gotgbot.Bot{}, &gotgbot.Update{
Message: &gotgbot.Message{Text: "test text"},
}, nil)
if err != nil {
Expand Down
12 changes: 6 additions & 6 deletions ext/handlers/common_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,13 @@ func NewTestBot() *gotgbot.Bot {
}
}

func NewMessage(userId int64, chatId int64, message string) *ext.Context {
return newMessage(userId, chatId, message, nil)
func NewMessage(b *gotgbot.Bot, userId int64, chatId int64, message string) *ext.Context {
return newMessage(b, userId, chatId, message, nil)
}

func NewCommandMessage(userId int64, chatId int64, command string, args []string) *ext.Context {
func NewCommandMessage(b *gotgbot.Bot, userId int64, chatId int64, command string, args []string) *ext.Context {
msg, ents := buildCommand(command, args)
return newMessage(userId, chatId, msg, ents)
return newMessage(b, userId, chatId, msg, ents)
}

func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) {
Expand All @@ -53,13 +53,13 @@ func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) {
}
}

func newMessage(userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context {
func newMessage(b *gotgbot.Bot, userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context {
chatType := "supergroup"
if userId == chatId {
chatType = "private"
}

return ext.NewContext(&gotgbot.Update{
return ext.NewContext(b, &gotgbot.Update{
UpdateId: rand.Int63(), // should this be consistent?
Message: &gotgbot.Message{
MessageId: rand.Int63(), // should this be consistent?
Expand Down
7 changes: 3 additions & 4 deletions ext/handlers/conversation/key_strategies.go
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,6 @@ package conversation
import (
"errors"
"fmt"
"strconv"

"github.com/PaulSonOfLars/gotgbot/v2/ext"
)
Expand All @@ -27,23 +26,23 @@ func KeyStrategySenderAndChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil || ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing sender or chat fields: %w", ErrEmptyKey)
}
return fmt.Sprintf("%d/%d", ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil
return fmt.Sprintf("%d/%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil
}

// KeyStrategySender gives a unique conversation to each sender, and that single conversation is available in all chats.
func KeyStrategySender(ctx *ext.Context) (string, error) {
if ctx.EffectiveSender == nil {
return "", fmt.Errorf("missing sender field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveSender.Id(), 10), nil
return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id()), nil
}

// KeyStrategyChat gives a unique conversation to each chat, which all senders can interact in together.
func KeyStrategyChat(ctx *ext.Context) (string, error) {
if ctx.EffectiveChat == nil {
return "", fmt.Errorf("missing chat field: %w", ErrEmptyKey)
}
return strconv.FormatInt(ctx.EffectiveChat.Id, 10), nil
return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveChat.Id), nil
}

// StateKey provides a sane default for handling incoming updates.
Expand Down
34 changes: 17 additions & 17 deletions ext/handlers/conversation_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,14 @@ func TestBasicConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
textMessage := NewMessage(userId, chatId, "message")
textMessage := NewMessage(b, userId, chatId, "message")
runHandler(t, b, &conv, textMessage, nextStep, "")
if !ended {
t.Fatalf("expected the internal handler to have run")
Expand Down Expand Up @@ -79,8 +79,8 @@ func TestBasicKeyedConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startFromUserOne := NewCommandMessage(userIdOne, chatId, "start", []string{})
messageFromTwo := NewMessage(userIdTwo, chatId, "message")
startFromUserOne := NewCommandMessage(b, userIdOne, chatId, "start", []string{})
messageFromTwo := NewMessage(b, userIdTwo, chatId, "message")

runHandler(t, b, &conv, startFromUserOne, "", nextStep)

Expand Down Expand Up @@ -121,14 +121,14 @@ func TestBasicConversationExit(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint, and starting the conversation.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "cancel" command, triggering the exitpoint, and immediately ending the conversation.
cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{})
cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{})
runHandler(t, b, &conv, cancelCommand, nextStep, "")
if !ended {
t.Fatalf("expected the cancel command to have run")
Expand All @@ -138,7 +138,7 @@ func TestBasicConversationExit(t *testing.T) {
checkExpectedState(t, &conv, cancelCommand, "")

// Emulate sending the "message" text, which now should not interact with the conversation.
textMessage := NewMessage(userId, chatId, "message")
textMessage := NewMessage(b, userId, chatId, "message")
if conv.CheckUpdate(b, textMessage) {
t.Fatalf("did not expect the internal handler to run")
}
Expand Down Expand Up @@ -177,14 +177,14 @@ func TestFallbackConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if !started {
t.Fatalf("expected the entrypoint handler to have run")
}

// Emulate sending the "cancel" command, triggering the fallback handler (and causing it to "end").
cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{})
cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{})
runHandler(t, b, &conv, cancelCommand, nextStep, "")
if !fallback {
t.Fatalf("expected the fallback handler to have run")
Expand Down Expand Up @@ -220,14 +220,14 @@ func TestReEntryConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
startCommand := NewCommandMessage(userId, chatId, "start", []string{})
startCommand := NewCommandMessage(b, userId, chatId, "start", []string{})
runHandler(t, b, &conv, startCommand, "", nextStep)
if startCount != 1 {
t.Fatalf("expected the entrypoint handler to have run")
}

// Send a message which matches both the entrypoint, and the "nextStep" state.
cancelCommand := NewCommandMessage(userId, chatId, "start", []string{"message"})
cancelCommand := NewCommandMessage(b, userId, chatId, "start", []string{"message"})
runHandler(t, b, &conv, cancelCommand, nextStep, nextStep) // Should hit
if startCount != 2 {
t.Fatalf("expected the entrypoint handler to have run a second time")
Expand Down Expand Up @@ -285,20 +285,20 @@ func TestNestedConversation(t *testing.T) {
var chatId int64 = 1234

// Emulate sending the "start" command, triggering the entrypoint.
start := NewCommandMessage(userId, chatId, startCmd, []string{})
start := NewCommandMessage(b, userId, chatId, startCmd, []string{})
runHandler(t, b, &conv, start, "", firstStep)

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
textMessage := NewMessage(userId, chatId, messageText)
textMessage := NewMessage(b, userId, chatId, messageText)
runHandler(t, b, &conv, textMessage, firstStep, secondStep)

// Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation.
nestedStart := NewCommandMessage(userId, chatId, nestedStartCmd, []string{})
nestedStart := NewCommandMessage(b, userId, chatId, nestedStartCmd, []string{})
willRunHandler(t, b, &nestedConv, nestedStart, "")
runHandler(t, b, &conv, nestedStart, secondStep, secondStep)

// Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation.
nestedFinish := NewMessage(userId, chatId, finishNestedText)
nestedFinish := NewMessage(b, userId, chatId, finishNestedText)
willRunHandler(t, b, &nestedConv, nestedFinish, nestedStep)
runHandler(t, b, &conv, nestedFinish, secondStep, thirdStep)

Expand All @@ -307,7 +307,7 @@ func TestNestedConversation(t *testing.T) {
t.Log("Nested conversation finished")

// Emulate sending the "message" text, triggering the internal handler (and causing it to "end").
finish := NewMessage(userId, chatId, finishText)
finish := NewMessage(b, userId, chatId, finishText)
runHandler(t, b, &conv, finish, thirdStep, "")

checkExpectedState(t, &conv, textMessage, "")
Expand All @@ -329,7 +329,7 @@ func TestEmptyKeyConversation(t *testing.T) {
)

// Run an empty
pollUpd := ext.NewContext(&gotgbot.Update{
pollUpd := ext.NewContext(b, &gotgbot.Update{
UpdateId: rand.Int63(), // should this be consistent?
Poll: &gotgbot.Poll{
Id: "some_id",
Expand Down

0 comments on commit dd7351b

Please sign in to comment.