2025-02-21 10:53:24 +01:00
|
|
|
package llm
|
|
|
|
|
|
|
|
import (
|
|
|
|
"bytes"
|
|
|
|
"context"
|
|
|
|
"encoding/json"
|
|
|
|
"fmt"
|
|
|
|
"io"
|
|
|
|
"net/http"
|
|
|
|
"strings"
|
|
|
|
"time"
|
|
|
|
)
|
|
|
|
|
|
|
|
// CompletionRequest is the JSON body that Complete posts to the generate
// endpoint.
type CompletionRequest struct {
	// System is the system prompt, serialized as "system".
	System string `json:"system"`
	// Prompt is the user prompt, serialized as "prompt".
	Prompt string `json:"prompt"`
	// Model names the model to run, serialized as "model".
	Model string `json:"model"`
	// Streaming is serialized as "stream". The tag has no omitempty, so the
	// zero value is sent explicitly as false.
	Streaming bool `json:"stream"`
	// Format is an optional raw JSON value serialized as "format"; it is
	// left out of the request when empty.
	Format json.RawMessage `json:"format,omitempty"`
}
|
|
|
|
|
|
|
|
// CompletionResponse is the part of the generate reply that Complete decodes.
// Only the "response" field is read; any other fields in the reply are
// ignored by the decoder.
type CompletionResponse struct {
	// Response is the generated text, read from "response".
	Response string `json:"response"`
}
|
|
|
|
|
|
|
|
// Ollama is an HTTP client for an Ollama-style server. Construct it with
// NewOllama so that client is populated.
type Ollama struct {
	// baseURL is prepended to every endpoint path.
	baseURL string
	// embedModel is the embedding model name given to NewOllama.
	embedModel string
	// completeModel is the model name sent with completion requests.
	completeModel string
	// client is the HTTP client used to send requests.
	client *http.Client
}
|
|
|
|
|
|
|
|
func NewOllama(baseURL, embedModel, completeModel string) *Ollama {
|
|
|
|
return &Ollama{
|
|
|
|
baseURL: baseURL,
|
|
|
|
embedModel: embedModel,
|
|
|
|
completeModel: completeModel,
|
|
|
|
client: &http.Client{Timeout: 600 * time.Second},
|
|
|
|
}
|
|
|
|
}
|
|
|
|
|
2025-02-21 15:22:17 +01:00
|
|
|
func (o *Ollama) Complete(system, prompt string, format json.RawMessage) (string, error) {
|
2025-02-21 10:53:24 +01:00
|
|
|
url := fmt.Sprintf("%s/api/generate", o.baseURL)
|
|
|
|
requestBody := CompletionRequest{
|
|
|
|
Prompt: prompt,
|
|
|
|
Model: o.completeModel,
|
|
|
|
System: system,
|
|
|
|
}
|
2025-02-21 15:22:17 +01:00
|
|
|
if format != nil {
|
|
|
|
requestBody.Format = format
|
|
|
|
}
|
2025-02-21 10:53:24 +01:00
|
|
|
jsonData, err := json.Marshal(requestBody)
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("could not marshal request to json: %v", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
resp, err := http.Post(url, "application/json", bytes.NewBuffer(jsonData))
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("could not post request to ollama: %v", err)
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
if resp.StatusCode < 200 || resp.StatusCode >= 300 {
|
|
|
|
return "", fmt.Errorf("received non-successful status code: %d", resp.StatusCode)
|
|
|
|
}
|
|
|
|
|
|
|
|
body, err := io.ReadAll(resp.Body)
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("could not read response: %v", err)
|
|
|
|
}
|
2025-02-21 15:22:17 +01:00
|
|
|
// fmt.Println(string(body))
|
2025-02-21 10:53:24 +01:00
|
|
|
|
|
|
|
var completionResponse CompletionResponse
|
|
|
|
err = json.Unmarshal(body, &completionResponse)
|
|
|
|
if err != nil {
|
|
|
|
return "", fmt.Errorf("could not unmarshal response: %v ", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return completionResponse.Response, nil
|
|
|
|
}
|
|
|
|
|
2025-02-21 15:22:17 +01:00
|
|
|
// snippetSchema is the JSON Schema passed as the "format" of a completion
// request by CompleteWithSnippets. It describes an object with a single
// required "snippets" array; each element carries an identifier, a kind drawn
// from a fixed enum, and a lineRange whose start and end are integers >= 1.
// It mirrors the SnippetCompletionResponse and Snippet types.
const snippetSchema = `{
	"type": "object",
	"properties": {
		"snippets": {
			"type": "array",
			"items": {
				"type": "object",
				"properties": {
					"identifier": {
						"type": "string"
					},
					"kind": {
						"type": "string",
						"enum": ["function", "type", "constant", "variable", "other"]
					},
					"lineRange": {
						"type": "object",
						"properties": {
							"start": {
								"type": "integer",
								"minimum": 1
							},
							"end": {
								"type": "integer",
								"minimum": 1
							}
						},
						"required": ["start", "end"],
						"additionalProperties": false
					}
				},
				"required": ["identifier", "kind", "lineRange"],
				"additionalProperties": false
			}
		}
	},
	"required": ["snippets"],
	"additionalProperties": false
}`
|
|
|
|
|
|
|
|
// Snippet is one element of the "snippets" array in a structured completion
// reply.
type Snippet struct {
	// Identifier is read from "identifier".
	Identifier string `json:"identifier"`
	// Kind is read from "kind".
	Kind string `json:"kind"`
	// LineRange is read from "lineRange" and holds its "start" and "end"
	// integers.
	LineRange struct {
		Start int `json:"start"`
		End   int `json:"end"`
	} `json:"lineRange"`
}
|
|
|
|
|
|
|
|
type SnippetCompletionResponse struct {
|
|
|
|
Snippets []Snippet `json:"snippets"`
|
|
|
|
}
|
|
|
|
|
|
|
|
func (o *Ollama) CompleteWithSnippets(system, prompt string) ([]Snippet, error) {
|
|
|
|
resp, err := o.Complete(system, prompt, []byte(snippetSchema))
|
|
|
|
if err != nil {
|
|
|
|
return nil, err
|
|
|
|
}
|
|
|
|
var snippetCompletionResponse SnippetCompletionResponse
|
|
|
|
err = json.Unmarshal([]byte(resp), &snippetCompletionResponse)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("could not unmarshal response: %v ", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
return snippetCompletionResponse.Snippets, nil
|
|
|
|
}
|
|
|
|
|
2025-02-21 10:53:24 +01:00
|
|
|
func (o *Ollama) Embed(inputText string) ([]float32, error) {
|
|
|
|
reqBody := map[string]interface{}{
|
|
|
|
"model": "text-embedding-3-small",
|
|
|
|
"input": inputText,
|
|
|
|
}
|
|
|
|
|
|
|
|
jsonBody, err := json.Marshal(reqBody)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error marshaling request: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
req, err := http.NewRequestWithContext(
|
|
|
|
context.Background(),
|
|
|
|
"POST",
|
|
|
|
o.baseURL+"/v1/embeddings",
|
|
|
|
strings.NewReader(string(jsonBody)),
|
|
|
|
)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error creating request: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
req.Header.Set("Content-Type", "application/json")
|
|
|
|
// req.Header.Set("Authorization", "Bearer "+o.apiKey)
|
|
|
|
|
|
|
|
resp, err := o.client.Do(req)
|
|
|
|
if err != nil {
|
|
|
|
return nil, fmt.Errorf("error making request: %w", err)
|
|
|
|
}
|
|
|
|
defer resp.Body.Close()
|
|
|
|
|
|
|
|
if resp.StatusCode != http.StatusOK {
|
|
|
|
body, _ := io.ReadAll(resp.Body)
|
|
|
|
return nil, fmt.Errorf("unexpected status code: %d, resp: %s", resp.StatusCode, string(body))
|
|
|
|
}
|
|
|
|
|
|
|
|
var result struct {
|
|
|
|
Data []struct {
|
|
|
|
Embedding []float32 `json:"embedding"`
|
|
|
|
} `json:"data"`
|
|
|
|
}
|
|
|
|
|
|
|
|
if err := json.NewDecoder(resp.Body).Decode(&result); err != nil {
|
|
|
|
return nil, fmt.Errorf("error decoding response: %w", err)
|
|
|
|
}
|
|
|
|
|
|
|
|
if len(result.Data) == 0 || len(result.Data[0].Embedding) == 0 {
|
|
|
|
return nil, fmt.Errorf("no embeddings returned")
|
|
|
|
}
|
|
|
|
|
|
|
|
return result.Data[0].Embedding, nil
|
|
|
|
}
|