From dd7351b99d5283313af9593e6da663e14a412662 Mon Sep 17 00:00:00 2001 From: Paul Larsen Date: Wed, 25 Sep 2024 08:38:06 +0100 Subject: [PATCH] Add the bot's userinfo to the context struct to ensure we have all the necessary information to determine update ownership at runtime --- ext/context.go | 9 ++++-- ext/dispatcher.go | 2 +- ext/dispatcher_ext_test.go | 2 +- ext/handlers/common_test.go | 12 ++++---- ext/handlers/conversation/key_strategies.go | 7 ++--- ext/handlers/conversation_test.go | 34 ++++++++++----------- 6 files changed, 35 insertions(+), 31 deletions(-) diff --git a/ext/context.go b/ext/context.go index cda02969..98206343 100644 --- a/ext/context.go +++ b/ext/context.go @@ -10,6 +10,10 @@ import ( type Context struct { // gotgbot.Update is inlined so that we can access all fields immediately if necessary. *gotgbot.Update + // We store the info of the Bot that received this update so we can keep track of update ownership. + // We do NOT store full gotgbot.Bot struct, as that would leak the bot token and bot client information in what + // should be a data-only struct. + Bot gotgbot.User // Data represents update-local storage. // This can be used to pass data across handlers - for example, to cache operations relevant to the current update, // such as admin checks. @@ -35,9 +39,9 @@ type Context struct { EffectiveSender *gotgbot.Sender } -// NewContext populates a context with the relevant fields from the current update. +// NewContext populates a context with the relevant fields from the current bot and update. // It takes a data field in the case where custom data needs to be passed. -func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context { +func NewContext(b *gotgbot.Bot, update *gotgbot.Update, data map[string]interface{}) *Context { var msg *gotgbot.Message var chat *gotgbot.Chat var user *gotgbot.User @@ -162,6 +166,7 @@ func NewContext(update *gotgbot.Update, data map[string]interface{}) *Context { return &Context{ Update: update, + Bot: b.User, Data: data, EffectiveMessage: msg, EffectiveChat: chat, diff --git a/ext/dispatcher.go b/ext/dispatcher.go index b7a063e0..5074ad38 100644 --- a/ext/dispatcher.go +++ b/ext/dispatcher.go @@ -268,7 +268,7 @@ func (d *Dispatcher) processRawUpdate(b *gotgbot.Bot, r json.RawMessage) error { // ProcessUpdate iterates over the list of groups to execute the matching handlers. // This is also where we recover from any panics that are thrown by user code, to avoid taking down the bot. func (d *Dispatcher) ProcessUpdate(b *gotgbot.Bot, u *gotgbot.Update, data map[string]interface{}) (err error) { - ctx := NewContext(u, data) + ctx := NewContext(b, u, data) defer func() { if r := recover(); r != nil { diff --git a/ext/dispatcher_ext_test.go b/ext/dispatcher_ext_test.go index 6a122e8f..0f5e85f4 100644 --- a/ext/dispatcher_ext_test.go +++ b/ext/dispatcher_ext_test.go @@ -100,7 +100,7 @@ func TestDispatcher(t *testing.T) { } t.Log("Processing one update...") - err := d.ProcessUpdate(nil, &gotgbot.Update{ + err := d.ProcessUpdate(&gotgbot.Bot{}, &gotgbot.Update{ Message: &gotgbot.Message{Text: "test text"}, }, nil) if err != nil { diff --git a/ext/handlers/common_test.go b/ext/handlers/common_test.go index b876b1f1..fcb7e422 100644 --- a/ext/handlers/common_test.go +++ b/ext/handlers/common_test.go @@ -33,13 +33,13 @@ func NewTestBot() *gotgbot.Bot { } } -func NewMessage(userId int64, chatId int64, message string) *ext.Context { - return newMessage(userId, chatId, message, nil) +func NewMessage(b *gotgbot.Bot, userId int64, chatId int64, message string) *ext.Context { + return newMessage(b, userId, chatId, message, nil) } -func NewCommandMessage(userId int64, chatId int64, command string, args []string) *ext.Context { +func NewCommandMessage(b *gotgbot.Bot, userId int64, chatId int64, command string, args []string) *ext.Context { msg, ents := buildCommand(command, args) - return newMessage(userId, chatId, msg, ents) + return newMessage(b, userId, chatId, msg, ents) } func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) { @@ -53,13 +53,13 @@ func buildCommand(cmd string, args []string) (string, []gotgbot.MessageEntity) { } } -func newMessage(userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context { +func newMessage(b *gotgbot.Bot, userId int64, chatId int64, message string, entities []gotgbot.MessageEntity) *ext.Context { chatType := "supergroup" if userId == chatId { chatType = "private" } - return ext.NewContext(&gotgbot.Update{ + return ext.NewContext(b, &gotgbot.Update{ UpdateId: rand.Int63(), // should this be consistent? Message: &gotgbot.Message{ MessageId: rand.Int63(), // should this be consistent? diff --git a/ext/handlers/conversation/key_strategies.go b/ext/handlers/conversation/key_strategies.go index 472c63a1..b8f9b9b6 100644 --- a/ext/handlers/conversation/key_strategies.go +++ b/ext/handlers/conversation/key_strategies.go @@ -3,7 +3,6 @@ package conversation import ( "errors" "fmt" - "strconv" "github.com/PaulSonOfLars/gotgbot/v2/ext" ) @@ -27,7 +26,7 @@ func KeyStrategySenderAndChat(ctx *ext.Context) (string, error) { if ctx.EffectiveSender == nil || ctx.EffectiveChat == nil { return "", fmt.Errorf("missing sender or chat fields: %w", ErrEmptyKey) } - return fmt.Sprintf("%d/%d", ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil + return fmt.Sprintf("%d/%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id(), ctx.EffectiveChat.Id), nil } // KeyStrategySender gives a unique conversation to each sender, and that single conversation is available in all chats. @@ -35,7 +34,7 @@ func KeyStrategySender(ctx *ext.Context) (string, error) { if ctx.EffectiveSender == nil { return "", fmt.Errorf("missing sender field: %w", ErrEmptyKey) } - return strconv.FormatInt(ctx.EffectiveSender.Id(), 10), nil + return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveSender.Id()), nil } // KeyStrategyChat gives a unique conversation to each chat, which all senders can interact in together. @@ -43,7 +42,7 @@ func KeyStrategyChat(ctx *ext.Context) (string, error) { if ctx.EffectiveChat == nil { return "", fmt.Errorf("missing chat field: %w", ErrEmptyKey) } - return strconv.FormatInt(ctx.EffectiveChat.Id, 10), nil + return fmt.Sprintf("%d/%d", ctx.Bot.Id, ctx.EffectiveChat.Id), nil } // StateKey provides a sane default for handling incoming updates. diff --git a/ext/handlers/conversation_test.go b/ext/handlers/conversation_test.go index d77881a6..61428bb3 100644 --- a/ext/handlers/conversation_test.go +++ b/ext/handlers/conversation_test.go @@ -37,14 +37,14 @@ func TestBasicConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if !started { t.Fatalf("expected the entrypoint handler to have run") } // Emulate sending the "message" text, triggering the internal handler (and causing it to "end"). - textMessage := NewMessage(userId, chatId, "message") + textMessage := NewMessage(b, userId, chatId, "message") runHandler(t, b, &conv, textMessage, nextStep, "") if !ended { t.Fatalf("expected the internal handler to have run") @@ -79,8 +79,8 @@ func TestBasicKeyedConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startFromUserOne := NewCommandMessage(userIdOne, chatId, "start", []string{}) - messageFromTwo := NewMessage(userIdTwo, chatId, "message") + startFromUserOne := NewCommandMessage(b, userIdOne, chatId, "start", []string{}) + messageFromTwo := NewMessage(b, userIdTwo, chatId, "message") runHandler(t, b, &conv, startFromUserOne, "", nextStep) @@ -121,14 +121,14 @@ func TestBasicConversationExit(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint, and starting the conversation. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if !started { t.Fatalf("expected the entrypoint handler to have run") } // Emulate sending the "cancel" command, triggering the exitpoint, and immediately ending the conversation. - cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{}) + cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{}) runHandler(t, b, &conv, cancelCommand, nextStep, "") if !ended { t.Fatalf("expected the cancel command to have run") @@ -138,7 +138,7 @@ func TestBasicConversationExit(t *testing.T) { checkExpectedState(t, &conv, cancelCommand, "") // Emulate sending the "message" text, which now should not interact with the conversation. - textMessage := NewMessage(userId, chatId, "message") + textMessage := NewMessage(b, userId, chatId, "message") if conv.CheckUpdate(b, textMessage) { t.Fatalf("did not expect the internal handler to run") } @@ -177,14 +177,14 @@ func TestFallbackConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if !started { t.Fatalf("expected the entrypoint handler to have run") } // Emulate sending the "cancel" command, triggering the fallback handler (and causing it to "end"). - cancelCommand := NewCommandMessage(userId, chatId, "cancel", []string{}) + cancelCommand := NewCommandMessage(b, userId, chatId, "cancel", []string{}) runHandler(t, b, &conv, cancelCommand, nextStep, "") if !fallback { t.Fatalf("expected the fallback handler to have run") @@ -220,14 +220,14 @@ func TestReEntryConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - startCommand := NewCommandMessage(userId, chatId, "start", []string{}) + startCommand := NewCommandMessage(b, userId, chatId, "start", []string{}) runHandler(t, b, &conv, startCommand, "", nextStep) if startCount != 1 { t.Fatalf("expected the entrypoint handler to have run") } // Send a message which matches both the entrypoint, and the "nextStep" state. - cancelCommand := NewCommandMessage(userId, chatId, "start", []string{"message"}) + cancelCommand := NewCommandMessage(b, userId, chatId, "start", []string{"message"}) runHandler(t, b, &conv, cancelCommand, nextStep, nextStep) // Should hit if startCount != 2 { t.Fatalf("expected the entrypoint handler to have run a second time") @@ -285,20 +285,20 @@ func TestNestedConversation(t *testing.T) { var chatId int64 = 1234 // Emulate sending the "start" command, triggering the entrypoint. - start := NewCommandMessage(userId, chatId, startCmd, []string{}) + start := NewCommandMessage(b, userId, chatId, startCmd, []string{}) runHandler(t, b, &conv, start, "", firstStep) // Emulate sending the "message" text, triggering the internal handler (and causing it to "end"). - textMessage := NewMessage(userId, chatId, messageText) + textMessage := NewMessage(b, userId, chatId, messageText) runHandler(t, b, &conv, textMessage, firstStep, secondStep) // Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation. - nestedStart := NewCommandMessage(userId, chatId, nestedStartCmd, []string{}) + nestedStart := NewCommandMessage(b, userId, chatId, nestedStartCmd, []string{}) willRunHandler(t, b, &nestedConv, nestedStart, "") runHandler(t, b, &conv, nestedStart, secondStep, secondStep) // Emulate sending the "nested_start" command, triggering the entrypoint of the nested conversation. - nestedFinish := NewMessage(userId, chatId, finishNestedText) + nestedFinish := NewMessage(b, userId, chatId, finishNestedText) willRunHandler(t, b, &nestedConv, nestedFinish, nestedStep) runHandler(t, b, &conv, nestedFinish, secondStep, thirdStep) @@ -307,7 +307,7 @@ func TestNestedConversation(t *testing.T) { t.Log("Nested conversation finished") // Emulate sending the "message" text, triggering the internal handler (and causing it to "end"). - finish := NewMessage(userId, chatId, finishText) + finish := NewMessage(b, userId, chatId, finishText) runHandler(t, b, &conv, finish, thirdStep, "") checkExpectedState(t, &conv, textMessage, "") @@ -329,7 +329,7 @@ func TestEmptyKeyConversation(t *testing.T) { ) // Run an empty - pollUpd := ext.NewContext(&gotgbot.Update{ + pollUpd := ext.NewContext(b, &gotgbot.Update{ UpdateId: rand.Int63(), // should this be consistent? Poll: &gotgbot.Poll{ Id: "some_id",