From 652f7a78f34cf16e21ecdee375eea6f2c09e806d Mon Sep 17 00:00:00 2001 From: Erik Winter Date: Fri, 21 Feb 2025 15:22:17 +0100 Subject: [PATCH] use ast --- llm/ollama.go | 82 +++++++++++++++++++++++++++++++++++++++++++--- main.go | 91 +++++++++++++++++++++++++++++++++++++-------------- 2 files changed, 144 insertions(+), 29 deletions(-) diff --git a/llm/ollama.go b/llm/ollama.go index ceadcdc..e99c7b2 100644 --- a/llm/ollama.go +++ b/llm/ollama.go @@ -12,10 +12,11 @@ import ( ) type CompletionRequest struct { - System string `json:"system"` - Prompt string `json:"prompt"` - Model string `json:"model"` - Streaming bool `json:"stream"` + System string `json:"system"` + Prompt string `json:"prompt"` + Model string `json:"model"` + Streaming bool `json:"stream"` + Format json.RawMessage `json:"format,omitempty"` } type CompletionResponse struct { @@ -38,13 +39,16 @@ func NewOllama(baseURL, embedModel, completeModel string) *Ollama { } } -func (o *Ollama) Complete(system, prompt string) (string, error) { +func (o *Ollama) Complete(system, prompt string, format json.RawMessage) (string, error) { url := fmt.Sprintf("%s/api/generate", o.baseURL) requestBody := CompletionRequest{ Prompt: prompt, Model: o.completeModel, System: system, } + if format != nil { + requestBody.Format = format + } jsonData, err := json.Marshal(requestBody) if err != nil { return "", fmt.Errorf("could not marshal request to json: %v", err) @@ -64,6 +68,7 @@ func (o *Ollama) Complete(system, prompt string) (string, error) { if err != nil { return "", fmt.Errorf("could not read response: %v", err) } + // fmt.Println(string(body)) var completionResponse CompletionResponse err = json.Unmarshal(body, &completionResponse) @@ -74,6 +79,73 @@ func (o *Ollama) Complete(system, prompt string) (string, error) { return completionResponse.Response, nil } +const snippetSchema = `{ + "type": "object", + "properties": { + "snippets": { + "type": "array", + "items": { + "type": "object", + "properties": { + "identifier": { + "type": "string" + }, + "kind": { + "type": "string", + "enum": ["function", "type", "constant", "variable", "other"] + }, + "lineRange": { + "type": "object", + "properties": { + "start": { + "type": "integer", + "minimum": 1 + }, + "end": { + "type": "integer", + "minimum": 1 + } + }, + "required": ["start", "end"], + "additionalProperties": false + } + }, + "required": ["identifier", "kind", "lineRange"], + "additionalProperties": false + } + } + }, + "required": ["snippets"], + "additionalProperties": false +}` + +type Snippet struct { + Identifier string `json:"identifier"` + Kind string `json:"kind"` + LineRange struct { + Start int `json:"start"` + End int `json:"end"` + } `json:"lineRange"` +} + +type SnippetCompletionResponse struct { + Snippets []Snippet `json:"snippets"` +} + +func (o *Ollama) CompleteWithSnippets(system, prompt string) ([]Snippet, error) { + resp, err := o.Complete(system, prompt, []byte(snippetSchema)) + if err != nil { + return nil, err + } + var snippetCompletionResponse SnippetCompletionResponse + err = json.Unmarshal([]byte(resp), &snippetCompletionResponse) + if err != nil { + return nil, fmt.Errorf("could not unmarshal response: %v ", err) + } + + return snippetCompletionResponse.Snippets, nil +} + func (o *Ollama) Embed(inputText string) ([]float32, error) { reqBody := map[string]interface{}{ "model": "text-embedding-3-small", diff --git a/main.go b/main.go index fec2389..965574b 100644 --- a/main.go +++ b/main.go @@ -1,41 +1,84 @@ package main import ( + "bytes" "fmt" - "log" - "os" - - "go-mod.ewintr.nl/henk/llm" + "go/ast" + "go/format" + "go/parser" + "go/token" ) -func main() { - - // startDir := "." - // err := filepath.Walk(startDir, walkFunc) - // if err != nil { - // log.Fatalf("Error walking the path: %v\n", err) - // } - ollamaClient := llm.NewOllama("http://192.168.1.12:11434", "nomic-embed-text:latest", "qwen2.5-coder:3b-instruct-q8_0") - - response, err := ollamaClient.Complete("You are a nice person.", "Say Hi!") +// printNode prints a single AST node back to Go source code. +func printNode(node ast.Node) (string, error) { + var writer bytes.Buffer + err := format.Node(&writer, token.NewFileSet(), node) if err != nil { - fmt.Println("Error:", err) - return + return "", err } - fmt.Println(response) + return writer.String(), nil } -func walkFunc(path string, info os.FileInfo, err error) error { +// walkFile walks through the AST and collects top-level declarations. +func walkFile(f *ast.File) ([]ast.Decl, error) { + var topLevelDecls []ast.Decl + + for _, decl := range f.Decls { + topLevelDecls = append(topLevelDecls, decl) + } + return topLevelDecls, nil +} + +// processGoFile processes a Go file and prints each top-level declaration. +func processGoFile(filePath string) error { + fset := token.NewFileSet() + f, err := parser.ParseFile(fset, filePath, nil, parser.AllErrors|parser.ParseComments) if err != nil { - return err + return fmt.Errorf("error parsing %s: %w", filePath, err) } - if !info.IsDir() { - data, err := os.ReadFile(path) + + topLevelDecls, err := walkFile(f) + if err != nil { + return fmt.Errorf("error walking file: %w", err) + } + + for i, decl := range topLevelDecls { + snippet, err := printNode(decl) if err != nil { - log.Printf("Error reading file %s: %v\n", path, err) - return nil + return fmt.Errorf("error printing node: %w", err) } - fmt.Printf("Contents of file %s:\n%s\n", path, string(data)) + fmt.Printf("Top-level Declaration %d:\n%s\n---\n", i+1, snippet) } + return nil } + +func main() { + filePath := "llm/ollama.go" // Replace with your Go file path + + err := processGoFile(filePath) + if err != nil { + fmt.Println(err) + } +} + +// startDir := "." +// err := filepath.Walk(startDir, walkFunc) +// if err != nil { +// log.Fatalf("Error walking the path: %v\n", err) +// } +// ollamaClient := llm.NewOllama("http://192.168.1.12:11434", "nomic-embed-text:latest", "qwen2.5-coder:32b-instruct-q8_0") + +// response, err := ollamaClient.Complete("You are a nice person.", "Say Hi!") +// if err != nil { +// fmt.Println("Error:", err) +// return +// } +// fmt.Println(response) +// prompt := fmt.Sprintf("The following is a file with Go source code. Split the code up into logical snippets. Snippets are either a function, a type, a constant or a variable. List the identifier and the line range for each snippet. Respond in JSON. \n\n Here comes the source code:\n\n```\n%s\n```", sourceDoc) +// response, err := ollamaClient.CompleteWithSnippets(systemMessage, prompt) +// if err != nil { +// fmt.Println("Error:", err) +// return +// } +// fmt.Printf("%+v\n", response)