matrix-gptzoo/bot/conversation.go

69 lines
1.1 KiB
Go
Raw Permalink Normal View History

2023-05-17 19:51:45 +02:00
package bot
2023-05-18 09:59:01 +02:00
import (
"github.com/sashabaranov/go-openai"
"maunium.net/go/mautrix"
2023-05-18 09:59:01 +02:00
"maunium.net/go/mautrix/id"
)
2023-05-17 19:51:45 +02:00
// Character holds the account settings and runtime client for one bot
// persona.
type Character struct {
	// Account credentials and persona settings, all plain strings.
	// NOTE(review): AccessKey is presumably the OpenAI API key and Prompt the
	// system prompt handed to NewConversation — confirm against callers.
	UserID, Password, AccessKey, Prompt string

	// client is the mautrix client for this character. It is unexported, so
	// it can only be set from inside package bot.
	client *mautrix.Client
}
2023-05-17 19:51:45 +02:00
type Message struct {
2023-05-18 09:59:01 +02:00
EventID id.EventID
Role string
Content string
ParentID id.EventID
2023-05-17 19:51:45 +02:00
}
// Conversation is an ordered history of chat messages.
type Conversation struct {
	// Messages is kept in insertion order; Add appends to the end.
	Messages []Message
}
2023-06-08 19:18:38 +02:00
func NewConversation(id id.EventID, systemPrompt, question string) *Conversation {
2023-05-17 19:51:45 +02:00
return &Conversation{
Messages: []Message{
{
Role: openai.ChatMessageRoleSystem,
Content: systemPrompt,
},
{
2023-05-23 15:40:16 +02:00
EventID: id,
2023-05-17 19:51:45 +02:00
Role: openai.ChatMessageRoleUser,
Content: question,
},
},
}
}
2023-05-18 09:59:01 +02:00
func (c *Conversation) Contains(EventID id.EventID) bool {
2023-05-17 19:51:45 +02:00
for _, m := range c.Messages {
2023-05-18 09:59:01 +02:00
if m.EventID.String() == EventID.String() {
2023-05-17 19:51:45 +02:00
return true
}
}
return false
}
// Add appends msg to the end of the conversation's message history.
func (c *Conversation) Add(msg Message) {
	history := c.Messages
	history = append(history, msg)
	c.Messages = history
}
// Conversations is a list of conversation pointers. FindByEventID searches it
// in order and returns the first conversation that contains a given event.
type Conversations []*Conversation
2023-05-18 09:59:01 +02:00
func (cs Conversations) FindByEventID(EventID id.EventID) *Conversation {
2023-05-17 19:51:45 +02:00
for _, c := range cs {
if c.Contains(EventID) {
2023-05-18 09:59:01 +02:00
return c
2023-05-17 19:51:45 +02:00
}
}
2023-05-18 09:59:01 +02:00
return nil
2023-05-17 19:51:45 +02:00
}