multiple continuous conversations

This commit is contained in:
Erik Winter 2023-05-18 09:59:01 +02:00
parent 51464046ed
commit 9c77046f3b
2 changed files with 53 additions and 28 deletions

View File

@ -1,14 +1,17 @@
package bot package bot
import "github.com/sashabaranov/go-openai" import (
"github.com/sashabaranov/go-openai"
"maunium.net/go/mautrix/id"
)
const systemPrompt = "You are a chatbot that helps people by responding to their questions with short messages." const systemPrompt = "You are a chatbot that helps people by responding to their questions with short messages."
type Message struct { type Message struct {
EventID string EventID id.EventID
Role string Role string
Content string Content string
ReplyToID string ParentID id.EventID
} }
type Conversation struct { type Conversation struct {
@ -30,9 +33,9 @@ func NewConversation(question string) *Conversation {
} }
} }
func (c *Conversation) Contains(EventID string) bool { func (c *Conversation) Contains(EventID id.EventID) bool {
for _, m := range c.Messages { for _, m := range c.Messages {
if m.EventID == EventID { if m.EventID.String() == EventID.String() {
return true return true
} }
} }
@ -46,24 +49,12 @@ func (c *Conversation) Add(msg Message) {
type Conversations []*Conversation type Conversations []*Conversation
func (cs Conversations) Contains(EventID string) bool { func (cs Conversations) FindByEventID(EventID id.EventID) *Conversation {
for _, c := range cs { for _, c := range cs {
if c.Contains(EventID) { if c.Contains(EventID) {
return true return c
} }
} }
return false return nil
}
func (cs Conversations) Add(msg Message) {
for _, c := range cs {
if c.Contains(msg.EventID) {
c.Add(msg)
return
}
}
c := NewConversation(msg.Content)
cs = append(cs, c)
} }

View File

@ -6,6 +6,7 @@ import (
"github.com/chzyer/readline" "github.com/chzyer/readline"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"github.com/sashabaranov/go-openai"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/cryptohelper" "maunium.net/go/mautrix/crypto/cryptohelper"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
@ -118,30 +119,63 @@ func (m *Matrix) InviteHandler() (event.Type, mautrix.EventHandler) {
func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) { func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) {
return event.EventMessage, func(source mautrix.EventSource, evt *event.Event) { return event.EventMessage, func(source mautrix.EventSource, evt *event.Event) {
content := evt.Content.AsMessage() content := evt.Content.AsMessage()
eventID := evt.ID
m.client.Log.Info(). m.client.Log.Info().
Str("content", content.Body). Str("content", content.Body).
Msg("Received message") Msg("Received message")
if evt.Sender != id.UserID(m.config.UserID) { if evt.Sender != id.UserID(m.config.UserID) {
resp, err := m.gptClient.Complete(NewConversation(content.Body)) eventID := evt.ID
parentID := id.EventID("")
if relatesTo := content.GetRelatesTo(); relatesTo != nil {
parentID = relatesTo.GetReplyTo()
}
// find existing conversation and add message, or start a new one
var conv *Conversation
if parentID != "" {
conv = m.conversations.FindByEventID(parentID)
}
if conv != nil {
conv.Add(Message{
EventID: eventID,
ParentID: parentID,
Role: openai.ChatMessageRoleUser,
Content: content.Body,
})
} else {
conv = NewConversation(content.Body)
m.conversations = append(m.conversations, conv)
}
// get reply from GPT
reply, err := m.gptClient.Complete(conv)
if err != nil { if err != nil {
m.client.Log.Error().Err(err).Msg("OpenAI API returned with ") m.client.Log.Error().Err(err).Msg("OpenAI API returned with ")
return return
} }
formattedResp := format.RenderMarkdown(resp, true, false) formattedReply := format.RenderMarkdown(reply, true, false)
formattedResp.RelatesTo = &event.RelatesTo{ formattedReply.RelatesTo = &event.RelatesTo{
InReplyTo: &event.InReplyTo{ InReplyTo: &event.InReplyTo{
EventID: eventID, EventID: eventID,
}, },
} }
if _, err := m.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedResp); err != nil { resp, err := m.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedReply)
if err != nil {
m.client.Log.Err(err).Msg("failed to send message") m.client.Log.Err(err).Msg("failed to send message")
return return
} }
m.client.Log.Info().Str("message", fmt.Sprintf("%+v", formattedResp.Body)).Msg("Sent message") // add reply to conversation
conv.Add(Message{
EventID: resp.EventID,
Role: openai.ChatMessageRoleAssistant,
Content: reply,
ParentID: eventID,
})
m.client.Log.Info().Str("message", fmt.Sprintf("%+v", formattedReply.Body)).Msg("Sent reply")
} }
} }