From 9c77046f3be51a5ed1651fd7945ca8242899a7e7 Mon Sep 17 00:00:00 2001 From: Erik Winter Date: Thu, 18 May 2023 09:59:01 +0200 Subject: [PATCH] multiple continuous conversations --- bot/conversation.go | 35 +++++++++++++--------------------- bot/matrix.go | 46 +++++++++++++++++++++++++++++++++++++++------ 2 files changed, 53 insertions(+), 28 deletions(-) diff --git a/bot/conversation.go b/bot/conversation.go index 57da5cf..ae58f1d 100644 --- a/bot/conversation.go +++ b/bot/conversation.go @@ -1,14 +1,17 @@ package bot -import "github.com/sashabaranov/go-openai" +import ( + "github.com/sashabaranov/go-openai" + "maunium.net/go/mautrix/id" +) const systemPrompt = "You are a chatbot that helps people by responding to their questions with short messages." type Message struct { - EventID string - Role string - Content string - ReplyToID string + EventID id.EventID + Role string + Content string + ParentID id.EventID } type Conversation struct { @@ -30,9 +33,9 @@ func NewConversation(question string) *Conversation { } } -func (c *Conversation) Contains(EventID string) bool { +func (c *Conversation) Contains(EventID id.EventID) bool { for _, m := range c.Messages { - if m.EventID == EventID { + if m.EventID.String() == EventID.String() { return true } } @@ -46,24 +49,12 @@ func (c *Conversation) Add(msg Message) { type Conversations []*Conversation -func (cs Conversations) Contains(EventID string) bool { +func (cs Conversations) FindByEventID(EventID id.EventID) *Conversation { for _, c := range cs { if c.Contains(EventID) { - return true + return c } } - return false -} - -func (cs Conversations) Add(msg Message) { - for _, c := range cs { - if c.Contains(msg.EventID) { - c.Add(msg) - return - } - } - - c := NewConversation(msg.Content) - cs = append(cs, c) + return nil } diff --git a/bot/matrix.go b/bot/matrix.go index 3f2e4af..2d6a2c4 100644 --- a/bot/matrix.go +++ b/bot/matrix.go @@ -6,6 +6,7 @@ import ( "github.com/chzyer/readline" "github.com/rs/zerolog" + "github.com/sashabaranov/go-openai" "maunium.net/go/mautrix" "maunium.net/go/mautrix/crypto/cryptohelper" "maunium.net/go/mautrix/event" @@ -118,30 +119,63 @@ func (m *Matrix) InviteHandler() (event.Type, mautrix.EventHandler) { func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) { return event.EventMessage, func(source mautrix.EventSource, evt *event.Event) { content := evt.Content.AsMessage() - eventID := evt.ID m.client.Log.Info(). Str("content", content.Body). Msg("Received message") if evt.Sender != id.UserID(m.config.UserID) { - resp, err := m.gptClient.Complete(NewConversation(content.Body)) + eventID := evt.ID + parentID := id.EventID("") + if relatesTo := content.GetRelatesTo(); relatesTo != nil { + parentID = relatesTo.GetReplyTo() + } + + // find existing conversation and add message, or start a new one + var conv *Conversation + if parentID != "" { + conv = m.conversations.FindByEventID(parentID) + } + if conv != nil { + conv.Add(Message{ + EventID: eventID, + ParentID: parentID, + Role: openai.ChatMessageRoleUser, + Content: content.Body, + }) + + } else { + conv = NewConversation(content.Body) + m.conversations = append(m.conversations, conv) + } + + // get reply from GPT + reply, err := m.gptClient.Complete(conv) if err != nil { m.client.Log.Error().Err(err).Msg("OpenAI API returned with ") return } - formattedResp := format.RenderMarkdown(resp, true, false) - formattedResp.RelatesTo = &event.RelatesTo{ + formattedReply := format.RenderMarkdown(reply, true, false) + formattedReply.RelatesTo = &event.RelatesTo{ InReplyTo: &event.InReplyTo{ EventID: eventID, }, } - if _, err := m.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedResp); err != nil { + resp, err := m.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedReply) + if err != nil { m.client.Log.Err(err).Msg("failed to send message") return } - m.client.Log.Info().Str("message", fmt.Sprintf("%+v", formattedResp.Body)).Msg("Sent message") + // add reply to conversation + conv.Add(Message{ + EventID: resp.EventID, + Role: openai.ChatMessageRoleAssistant, + Content: reply, + ParentID: eventID, + }) + + m.client.Log.Info().Str("message", fmt.Sprintf("%+v", formattedReply.Body)).Msg("Sent reply") } }