multiple continuous conversations
This commit is contained in:
parent
51464046ed
commit
9c77046f3b
|
@ -1,14 +1,17 @@
|
||||||
package bot
|
package bot
|
||||||
|
|
||||||
import "github.com/sashabaranov/go-openai"
|
import (
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
|
"maunium.net/go/mautrix/id"
|
||||||
|
)
|
||||||
|
|
||||||
const systemPrompt = "You are a chatbot that helps people by responding to their questions with short messages."
|
const systemPrompt = "You are a chatbot that helps people by responding to their questions with short messages."
|
||||||
|
|
||||||
type Message struct {
|
type Message struct {
|
||||||
EventID string
|
EventID id.EventID
|
||||||
Role string
|
Role string
|
||||||
Content string
|
Content string
|
||||||
ReplyToID string
|
ParentID id.EventID
|
||||||
}
|
}
|
||||||
|
|
||||||
type Conversation struct {
|
type Conversation struct {
|
||||||
|
@ -30,9 +33,9 @@ func NewConversation(question string) *Conversation {
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
func (c *Conversation) Contains(EventID string) bool {
|
func (c *Conversation) Contains(EventID id.EventID) bool {
|
||||||
for _, m := range c.Messages {
|
for _, m := range c.Messages {
|
||||||
if m.EventID == EventID {
|
if m.EventID.String() == EventID.String() {
|
||||||
return true
|
return true
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
@ -46,24 +49,12 @@ func (c *Conversation) Add(msg Message) {
|
||||||
|
|
||||||
type Conversations []*Conversation
|
type Conversations []*Conversation
|
||||||
|
|
||||||
func (cs Conversations) Contains(EventID string) bool {
|
func (cs Conversations) FindByEventID(EventID id.EventID) *Conversation {
|
||||||
for _, c := range cs {
|
for _, c := range cs {
|
||||||
if c.Contains(EventID) {
|
if c.Contains(EventID) {
|
||||||
return true
|
return c
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
return false
|
return nil
|
||||||
}
|
|
||||||
|
|
||||||
func (cs Conversations) Add(msg Message) {
|
|
||||||
for _, c := range cs {
|
|
||||||
if c.Contains(msg.EventID) {
|
|
||||||
c.Add(msg)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
}
|
|
||||||
|
|
||||||
c := NewConversation(msg.Content)
|
|
||||||
cs = append(cs, c)
|
|
||||||
}
|
}
|
||||||
|
|
|
@ -6,6 +6,7 @@ import (
|
||||||
|
|
||||||
"github.com/chzyer/readline"
|
"github.com/chzyer/readline"
|
||||||
"github.com/rs/zerolog"
|
"github.com/rs/zerolog"
|
||||||
|
"github.com/sashabaranov/go-openai"
|
||||||
"maunium.net/go/mautrix"
|
"maunium.net/go/mautrix"
|
||||||
"maunium.net/go/mautrix/crypto/cryptohelper"
|
"maunium.net/go/mautrix/crypto/cryptohelper"
|
||||||
"maunium.net/go/mautrix/event"
|
"maunium.net/go/mautrix/event"
|
||||||
|
@ -118,30 +119,63 @@ func (m *Matrix) InviteHandler() (event.Type, mautrix.EventHandler) {
|
||||||
func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) {
|
func (m *Matrix) RespondHandler() (event.Type, mautrix.EventHandler) {
|
||||||
return event.EventMessage, func(source mautrix.EventSource, evt *event.Event) {
|
return event.EventMessage, func(source mautrix.EventSource, evt *event.Event) {
|
||||||
content := evt.Content.AsMessage()
|
content := evt.Content.AsMessage()
|
||||||
eventID := evt.ID
|
|
||||||
m.client.Log.Info().
|
m.client.Log.Info().
|
||||||
Str("content", content.Body).
|
Str("content", content.Body).
|
||||||
Msg("Received message")
|
Msg("Received message")
|
||||||
|
|
||||||
if evt.Sender != id.UserID(m.config.UserID) {
|
if evt.Sender != id.UserID(m.config.UserID) {
|
||||||
resp, err := m.gptClient.Complete(NewConversation(content.Body))
|
eventID := evt.ID
|
||||||
|
parentID := id.EventID("")
|
||||||
|
if relatesTo := content.GetRelatesTo(); relatesTo != nil {
|
||||||
|
parentID = relatesTo.GetReplyTo()
|
||||||
|
}
|
||||||
|
|
||||||
|
// find existing conversation and add message, or start a new one
|
||||||
|
var conv *Conversation
|
||||||
|
if parentID != "" {
|
||||||
|
conv = m.conversations.FindByEventID(parentID)
|
||||||
|
}
|
||||||
|
if conv != nil {
|
||||||
|
conv.Add(Message{
|
||||||
|
EventID: eventID,
|
||||||
|
ParentID: parentID,
|
||||||
|
Role: openai.ChatMessageRoleUser,
|
||||||
|
Content: content.Body,
|
||||||
|
})
|
||||||
|
|
||||||
|
} else {
|
||||||
|
conv = NewConversation(content.Body)
|
||||||
|
m.conversations = append(m.conversations, conv)
|
||||||
|
}
|
||||||
|
|
||||||
|
// get reply from GPT
|
||||||
|
reply, err := m.gptClient.Complete(conv)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
m.client.Log.Error().Err(err).Msg("OpenAI API returned with ")
|
m.client.Log.Error().Err(err).Msg("OpenAI API returned with ")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
formattedResp := format.RenderMarkdown(resp, true, false)
|
formattedReply := format.RenderMarkdown(reply, true, false)
|
||||||
formattedResp.RelatesTo = &event.RelatesTo{
|
formattedReply.RelatesTo = &event.RelatesTo{
|
||||||
InReplyTo: &event.InReplyTo{
|
InReplyTo: &event.InReplyTo{
|
||||||
EventID: eventID,
|
EventID: eventID,
|
||||||
},
|
},
|
||||||
}
|
}
|
||||||
if _, err := m.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedResp); err != nil {
|
resp, err := m.client.SendMessageEvent(evt.RoomID, event.EventMessage, &formattedReply)
|
||||||
|
if err != nil {
|
||||||
m.client.Log.Err(err).Msg("failed to send message")
|
m.client.Log.Err(err).Msg("failed to send message")
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
m.client.Log.Info().Str("message", fmt.Sprintf("%+v", formattedResp.Body)).Msg("Sent message")
|
// add reply to conversation
|
||||||
|
conv.Add(Message{
|
||||||
|
EventID: resp.EventID,
|
||||||
|
Role: openai.ChatMessageRoleAssistant,
|
||||||
|
Content: reply,
|
||||||
|
ParentID: eventID,
|
||||||
|
})
|
||||||
|
|
||||||
|
m.client.Log.Info().Str("message", fmt.Sprintf("%+v", formattedReply.Body)).Msg("Sent reply")
|
||||||
|
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
Loading…
Reference in New Issue