diff --git a/bot/conversation.go b/bot/conversation.go index ae58f1d..3f6ce5e 100644 --- a/bot/conversation.go +++ b/bot/conversation.go @@ -18,7 +18,7 @@ type Conversation struct { Messages []Message } -func NewConversation(question string) *Conversation { +func NewConversation(id id.EventID, question string) *Conversation { return &Conversation{ Messages: []Message{ { @@ -26,6 +26,7 @@ func NewConversation(question string) *Conversation { Content: systemPrompt, }, { + EventID: id, Role: openai.ChatMessageRoleUser, Content: question, }, diff --git a/bot/conversation_test.go b/bot/conversation_test.go new file mode 100644 index 0000000..35c5bd6 --- /dev/null +++ b/bot/conversation_test.go @@ -0,0 +1,77 @@ +package bot_test + +import ( + "testing" + + "ewintr.nl/matrix-bots/bot" +) + +func TestNewConversation(t *testing.T) { + t.Parallel() + + conv := bot.NewConversation("test", "question") + if conv == nil { + t.Error("NewConversation returned nil") + } + if len(conv.Messages) != 2 { + t.Error("NewConversation did not create 2 messages") + } + if conv.Messages[1].Content != "question" { + t.Error("NewConversation did not set question") + } +} + +func TestConversation_Contains(t *testing.T) { + t.Parallel() + + for _, tc := range []struct { + name string + conv *bot.Conversation + exp bool + }{ + { + name: "empty", + conv: &bot.Conversation{}, + exp: false, + }, + { + name: "not contains", + conv: &bot.Conversation{ + Messages: []bot.Message{ + { + EventID: "other", + Content: "content", + }, + }, + }, + }, + { + name: "contains", + conv: &bot.Conversation{ + Messages: []bot.Message{ + { + EventID: "id", + Content: "content", + }, + }, + }, + exp: true, + }, + } { + t.Run(tc.name, func(t *testing.T) { + if tc.conv.Contains("id") != tc.exp { + t.Errorf("expected %v, got %v", tc.exp, tc.conv.Contains("test")) + } + }) + } +} + +func TestConversation_Add(t *testing.T) { + conv := &bot.Conversation{} + conv.Add(bot.Message{ + EventID: "id", + }) + if !conv.Contains("id") { + t.Error("Add did not add message") + } +} diff --git a/bot/matrix.go b/bot/matrix.go index 11dde80..c348ccf 100644 --- a/bot/matrix.go +++ b/bot/matrix.go @@ -116,7 +116,7 @@ func (m *Matrix) InviteHandler() (event.Type, mautrix.EventHandler) { } } -func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) { +func (m *Matrix) ResponseHandler() (event.Type, mautrix.EventHandler) { return event.EventMessage, func(source mautrix.EventSource, evt *event.Event) { content := evt.Content.AsMessage() eventID := evt.ID @@ -137,6 +137,7 @@ func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) { parentID = relatesTo.GetReplyTo() } if parentID != "" { + m.client.Log.Info().Msg("parent found, looking for conversation") conv = m.conversations.FindByEventID(parentID) } if conv != nil { @@ -148,7 +149,7 @@ func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) { }) m.client.Log.Info().Msg("found parent, appending message to conversation") } else { - conv = NewConversation(content.Body) + conv = NewConversation(eventID, content.Body) m.conversations = append(m.conversations, conv) m.client.Log.Info().Msg("no parent found, starting new conversation") } diff --git a/main.go b/main.go index d12c523..5d92cb3 100644 --- a/main.go +++ b/main.go @@ -29,7 +29,7 @@ func main() { go matrixClient.Run() matrixClient.AddEventHandler(matrixClient.InviteHandler()) - matrixClient.AddEventHandler(matrixClient.RespondHandler()) + matrixClient.AddEventHandler(matrixClient.ResponseHandler()) done := make(chan os.Signal) signal.Notify(done, os.Interrupt)