transactions

This commit is contained in:
Erik Winter 2025-01-13 09:13:48 +01:00
parent 428b828c45
commit 7f4d870311
23 changed files with 278 additions and 171 deletions

BIN
dist/plan vendored

Binary file not shown.

View File

@ -8,6 +8,7 @@ import (
"github.com/google/uuid" "github.com/google/uuid"
"go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/sync/client"
) )
type AddArgs struct { type AddArgs struct {
@ -92,16 +93,22 @@ type Add struct {
Args AddArgs Args AddArgs
} }
func (a Add) Do(deps Dependencies) (CommandResult, error) { func (a Add) Do(repos Repositories, _ client.Client) (CommandResult, error) {
if err := deps.TaskRepo.Store(a.Args.Task); err != nil { tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
if err := repos.Task(tx).Store(a.Args.Task); err != nil {
return nil, fmt.Errorf("could not store event: %v", err) return nil, fmt.Errorf("could not store event: %v", err)
} }
localID, err := deps.LocalIDRepo.Next() localID, err := repos.LocalID(tx).Next()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not create next local id: %v", err) return nil, fmt.Errorf("could not create next local id: %v", err)
} }
if err := deps.LocalIDRepo.Store(a.Args.Task.ID, localID); err != nil { if err := repos.LocalID(tx).Store(a.Args.Task.ID, localID); err != nil {
return nil, fmt.Errorf("could not store local id: %v", err) return nil, fmt.Errorf("could not store local id: %v", err)
} }
@ -109,10 +116,14 @@ func (a Add) Do(deps Dependencies) (CommandResult, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("could not convert event to sync item: %v", err) return nil, fmt.Errorf("could not convert event to sync item: %v", err)
} }
if err := deps.SyncRepo.Store(it); err != nil { if err := repos.Sync(tx).Store(it); err != nil {
return nil, fmt.Errorf("could not store sync item: %v", err) return nil, fmt.Errorf("could not store sync item: %v", err)
} }
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("could not add task: %v", err)
}
return AddRender{}, nil return AddRender{}, nil
} }

View File

@ -60,9 +60,7 @@ func TestAdd(t *testing.T) {
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// setup // setup
taskRepo := memory.NewTask() mems := memory.New()
localIDRepo := memory.NewLocalID()
syncRepo := memory.NewSync()
// parse // parse
cmd, actParseErr := command.NewAddArgs().Parse(tc.main, tc.fields) cmd, actParseErr := command.NewAddArgs().Parse(tc.main, tc.fields)
@ -74,16 +72,12 @@ func TestAdd(t *testing.T) {
} }
// do // do
if _, err := cmd.Do(command.Dependencies{ if _, err := cmd.Do(mems, nil); err != nil {
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
}); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
// check // check
actTasks, err := taskRepo.FindMany(storage.TaskListParams{}) actTasks, err := mems.Task(nil).FindMany(storage.TaskListParams{})
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
@ -91,7 +85,7 @@ func TestAdd(t *testing.T) {
t.Errorf("exp 1, got %d", len(actTasks)) t.Errorf("exp 1, got %d", len(actTasks))
} }
actLocalIDs, err := localIDRepo.FindAll() actLocalIDs, err := mems.LocalID(nil).FindAll()
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
@ -110,7 +104,7 @@ func TestAdd(t *testing.T) {
t.Errorf("(exp -, got +)\n%s", diff) t.Errorf("(exp -, got +)\n%s", diff)
} }
updated, err := syncRepo.FindAll() updated, err := mems.Sync(nil).FindAll()
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }

View File

@ -20,11 +20,11 @@ var (
ErrInvalidArg = errors.New("invalid argument") ErrInvalidArg = errors.New("invalid argument")
) )
type Dependencies struct { type Repositories interface {
LocalIDRepo storage.LocalID Begin() (*storage.Tx, error)
TaskRepo storage.Task LocalID(tx *storage.Tx) storage.LocalID
SyncRepo storage.Sync Sync(tx *storage.Tx) storage.Sync
SyncClient client.Client Task(tx *storage.Tx) storage.Task
} }
type CommandArgs interface { type CommandArgs interface {
@ -32,7 +32,7 @@ type CommandArgs interface {
} }
type Command interface { type Command interface {
Do(deps Dependencies) (CommandResult, error) Do(repos Repositories, client client.Client) (CommandResult, error)
} }
type CommandResult interface { type CommandResult interface {
@ -40,13 +40,15 @@ type CommandResult interface {
} }
type CLI struct { type CLI struct {
deps Dependencies repos Repositories
client client.Client
cmdArgs []CommandArgs cmdArgs []CommandArgs
} }
func NewCLI(deps Dependencies) *CLI { func NewCLI(repos Repositories, client client.Client) *CLI {
return &CLI{ return &CLI{
deps: deps, repos: repos,
client: client,
cmdArgs: []CommandArgs{ cmdArgs: []CommandArgs{
NewShowArgs(), NewProjectsArgs(), NewShowArgs(), NewProjectsArgs(),
NewAddArgs(), NewDeleteArgs(), NewListArgs(), NewAddArgs(), NewDeleteArgs(), NewListArgs(),
@ -66,7 +68,7 @@ func (cli *CLI) Run(args []string) error {
return err return err
} }
result, err := cmd.Do(cli.deps) result, err := cmd.Do(cli.repos, cli.client)
if err != nil { if err != nil {
return err return err
} }

View File

@ -4,6 +4,8 @@ import (
"fmt" "fmt"
"slices" "slices"
"strconv" "strconv"
"go-mod.ewintr.nl/planner/sync/client"
) )
type DeleteArgs struct { type DeleteArgs struct {
@ -45,9 +47,15 @@ type Delete struct {
Args DeleteArgs Args DeleteArgs
} }
func (del Delete) Do(deps Dependencies) (CommandResult, error) { func (del Delete) Do(repos Repositories, _ client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
var id string var id string
idMap, err := deps.LocalIDRepo.FindAll() idMap, err := repos.LocalID(tx).FindAll()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get local ids: %v", err) return nil, fmt.Errorf("could not get local ids: %v", err)
} }
@ -60,7 +68,7 @@ func (del Delete) Do(deps Dependencies) (CommandResult, error) {
return nil, fmt.Errorf("could not find local id") return nil, fmt.Errorf("could not find local id")
} }
tsk, err := deps.TaskRepo.FindOne(id) tsk, err := repos.Task(tx).FindOne(id)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get task: %v", err) return nil, fmt.Errorf("could not get task: %v", err)
} }
@ -70,15 +78,15 @@ func (del Delete) Do(deps Dependencies) (CommandResult, error) {
return nil, fmt.Errorf("could not convert task to sync item: %v", err) return nil, fmt.Errorf("could not convert task to sync item: %v", err)
} }
it.Deleted = true it.Deleted = true
if err := deps.SyncRepo.Store(it); err != nil { if err := repos.Sync(tx).Store(it); err != nil {
return nil, fmt.Errorf("could not store sync item: %v", err) return nil, fmt.Errorf("could not store sync item: %v", err)
} }
if err := deps.LocalIDRepo.Delete(id); err != nil { if err := repos.LocalID(tx).Delete(id); err != nil {
return nil, fmt.Errorf("could not delete local id: %v", err) return nil, fmt.Errorf("could not delete local id: %v", err)
} }
if err := deps.TaskRepo.Delete(id); err != nil { if err := repos.Task(tx).Delete(id); err != nil {
return nil, fmt.Errorf("could not delete task: %v", err) return nil, fmt.Errorf("could not delete task: %v", err)
} }

View File

@ -49,13 +49,12 @@ func TestDelete(t *testing.T) {
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// setup // setup
taskRepo := memory.NewTask() mems := memory.New()
syncRepo := memory.NewSync()
if err := taskRepo.Store(e); err != nil { if err := mems.Task(nil).Store(e); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
localIDRepo := memory.NewLocalID() if err := mems.LocalID(nil).Store(e.ID, 1); err != nil {
if err := localIDRepo.Store(e.ID, 1); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
@ -69,11 +68,7 @@ func TestDelete(t *testing.T) {
} }
// do // do
_, actDoErr := cmd.Do(command.Dependencies{ _, actDoErr := cmd.Do(mems, nil)
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
})
if tc.expDoErr != (actDoErr != nil) { if tc.expDoErr != (actDoErr != nil) {
t.Errorf("exp false, got %v", actDoErr) t.Errorf("exp false, got %v", actDoErr)
} }
@ -82,18 +77,18 @@ func TestDelete(t *testing.T) {
} }
// check // check
_, repoErr := taskRepo.FindOne(e.ID) _, repoErr := mems.Task(nil).FindOne(e.ID)
if !errors.Is(repoErr, storage.ErrNotFound) { if !errors.Is(repoErr, storage.ErrNotFound) {
t.Errorf("exp %v, got %v", storage.ErrNotFound, repoErr) t.Errorf("exp %v, got %v", storage.ErrNotFound, repoErr)
} }
idMap, idErr := localIDRepo.FindAll() idMap, idErr := mems.LocalID(nil).FindAll()
if idErr != nil { if idErr != nil {
t.Errorf("exp nil, got %v", idErr) t.Errorf("exp nil, got %v", idErr)
} }
if len(idMap) != 0 { if len(idMap) != 0 {
t.Errorf("exp 0, got %v", len(idMap)) t.Errorf("exp 0, got %v", len(idMap))
} }
updated, err := syncRepo.FindAll() updated, err := mems.Sync(nil).FindAll()
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }

View File

@ -9,6 +9,7 @@ import (
"go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/format" "go-mod.ewintr.nl/planner/plan/format"
"go-mod.ewintr.nl/planner/plan/storage" "go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
) )
type ListArgs struct { type ListArgs struct {
@ -96,12 +97,18 @@ type List struct {
Args ListArgs Args ListArgs
} }
func (list List) Do(deps Dependencies) (CommandResult, error) { func (list List) Do(repos Repositories, _ client.Client) (CommandResult, error) {
localIDs, err := deps.LocalIDRepo.FindAll() tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
localIDs, err := repos.LocalID(tx).FindAll()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get local ids: %v", err) return nil, fmt.Errorf("could not get local ids: %v", err)
} }
all, err := deps.TaskRepo.FindMany(storage.TaskListParams{ all, err := repos.Task(tx).FindMany(storage.TaskListParams{
HasRecurrer: list.Args.HasRecurrer, HasRecurrer: list.Args.HasRecurrer,
From: list.Args.From, From: list.Args.From,
To: list.Args.To, To: list.Args.To,

View File

@ -79,8 +79,8 @@ func TestListParse(t *testing.T) {
func TestList(t *testing.T) { func TestList(t *testing.T) {
t.Parallel() t.Parallel()
taskRepo := memory.NewTask() mems := memory.New()
localRepo := memory.NewLocalID()
e := item.Task{ e := item.Task{
ID: "id", ID: "id",
Date: item.NewDate(2024, 10, 7), Date: item.NewDate(2024, 10, 7),
@ -88,10 +88,10 @@ func TestList(t *testing.T) {
Title: "name", Title: "name",
}, },
} }
if err := taskRepo.Store(e); err != nil { if err := mems.Task(nil).Store(e); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if err := localRepo.Store(e.ID, 1); err != nil { if err := mems.LocalID(nil).Store(e.ID, 1); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
@ -115,10 +115,7 @@ func TestList(t *testing.T) {
}, },
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
res, err := tc.cmd.Do(command.Dependencies{ res, err := tc.cmd.Do(mems, nil)
TaskRepo: taskRepo,
LocalIDRepo: localRepo,
})
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }

View File

@ -5,6 +5,7 @@ import (
"sort" "sort"
"go-mod.ewintr.nl/planner/plan/format" "go-mod.ewintr.nl/planner/plan/format"
"go-mod.ewintr.nl/planner/sync/client"
) )
type ProjectsArgs struct{} type ProjectsArgs struct{}
@ -23,8 +24,14 @@ func (pa ProjectsArgs) Parse(main []string, fields map[string]string) (Command,
type Projects struct{} type Projects struct{}
func (ps Projects) Do(deps Dependencies) (CommandResult, error) { func (ps Projects) Do(repos Repositories, _ client.Client) (CommandResult, error) {
projects, err := deps.TaskRepo.Projects() tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
projects, err := repos.Task(tx).Projects()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not find projects: %v", err) return nil, fmt.Errorf("could not find projects: %v", err)
} }

View File

@ -8,6 +8,7 @@ import (
"go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/format" "go-mod.ewintr.nl/planner/plan/format"
"go-mod.ewintr.nl/planner/plan/storage" "go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
) )
type ShowArgs struct { type ShowArgs struct {
@ -38,8 +39,14 @@ type Show struct {
args ShowArgs args ShowArgs
} }
func (s Show) Do(deps Dependencies) (CommandResult, error) { func (s Show) Do(repos Repositories, _ client.Client) (CommandResult, error) {
id, err := deps.LocalIDRepo.FindOne(s.args.localID) tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
id, err := repos.LocalID(tx).FindOne(s.args.localID)
switch { switch {
case errors.Is(err, storage.ErrNotFound): case errors.Is(err, storage.ErrNotFound):
return nil, fmt.Errorf("could not find local id") return nil, fmt.Errorf("could not find local id")
@ -47,7 +54,7 @@ func (s Show) Do(deps Dependencies) (CommandResult, error) {
return nil, err return nil, err
} }
tsk, err := deps.TaskRepo.FindOne(id) tsk, err := repos.Task(tx).FindOne(id)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not find task") return nil, fmt.Errorf("could not find task")
} }

View File

@ -12,8 +12,8 @@ import (
func TestShow(t *testing.T) { func TestShow(t *testing.T) {
t.Parallel() t.Parallel()
taskRepo := memory.NewTask() mems := memory.New()
localRepo := memory.NewLocalID()
tsk := item.Task{ tsk := item.Task{
ID: "id", ID: "id",
Date: item.NewDate(2024, 10, 7), Date: item.NewDate(2024, 10, 7),
@ -21,10 +21,10 @@ func TestShow(t *testing.T) {
Title: "name", Title: "name",
}, },
} }
if err := taskRepo.Store(tsk); err != nil { if err := mems.Task(nil).Store(tsk); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if err := localRepo.Store(tsk.ID, 1); err != nil { if err := mems.LocalID(nil).Store(tsk.ID, 1); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
@ -69,10 +69,7 @@ func TestShow(t *testing.T) {
} }
// do // do
_, actDoErr := cmd.Do(command.Dependencies{ _, actDoErr := cmd.Do(mems, nil)
TaskRepo: taskRepo,
LocalIDRepo: localRepo,
})
if tc.expDoErr != (actDoErr != nil) { if tc.expDoErr != (actDoErr != nil) {
t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr != nil) t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr != nil)
} }

View File

@ -8,6 +8,7 @@ import (
"go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/storage" "go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
) )
type SyncArgs struct{} type SyncArgs struct{}
@ -26,25 +27,31 @@ func (sa SyncArgs) Parse(main []string, flags map[string]string) (Command, error
type Sync struct{} type Sync struct{}
func (s Sync) Do(deps Dependencies) (CommandResult, error) { func (s Sync) Do(repos Repositories, client client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
// local new and updated // local new and updated
sendItems, err := deps.SyncRepo.FindAll() sendItems, err := repos.Sync(tx).FindAll()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get updated items: %v", err) return nil, fmt.Errorf("could not get updated items: %v", err)
} }
if err := deps.SyncClient.Update(sendItems); err != nil { if err := client.Update(sendItems); err != nil {
return nil, fmt.Errorf("could not send updated items: %v", err) return nil, fmt.Errorf("could not send updated items: %v", err)
} }
if err := deps.SyncRepo.DeleteAll(); err != nil { if err := repos.Sync(tx).DeleteAll(); err != nil {
return nil, fmt.Errorf("could not clear updated items: %v", err) return nil, fmt.Errorf("could not clear updated items: %v", err)
} }
// get new/updated items // get new/updated items
oldTS, err := deps.SyncRepo.LastUpdate() oldTS, err := repos.Sync(tx).LastUpdate()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not find timestamp of last update: %v", err) return nil, fmt.Errorf("could not find timestamp of last update: %v", err)
} }
recItems, err := deps.SyncClient.Updated([]item.Kind{item.KindTask}, oldTS) recItems, err := client.Updated([]item.Kind{item.KindTask}, oldTS)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not receive updates: %v", err) return nil, fmt.Errorf("could not receive updates: %v", err)
} }
@ -56,10 +63,10 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) {
newTS = ri.Updated newTS = ri.Updated
} }
if ri.Deleted { if ri.Deleted {
if err := deps.LocalIDRepo.Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) { if err := repos.LocalID(tx).Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) {
return nil, fmt.Errorf("could not delete local id: %v", err) return nil, fmt.Errorf("could not delete local id: %v", err)
} }
if err := deps.TaskRepo.Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) { if err := repos.Task(tx).Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) {
return nil, fmt.Errorf("could not delete task: %v", err) return nil, fmt.Errorf("could not delete task: %v", err)
} }
continue continue
@ -67,7 +74,7 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) {
updated = append(updated, ri) updated = append(updated, ri)
} }
lidMap, err := deps.LocalIDRepo.FindAll() lidMap, err := repos.LocalID(tx).FindAll()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get local ids: %v", err) return nil, fmt.Errorf("could not get local ids: %v", err)
} }
@ -83,23 +90,23 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) {
RecurNext: u.RecurNext, RecurNext: u.RecurNext,
TaskBody: tskBody, TaskBody: tskBody,
} }
if err := deps.TaskRepo.Store(tsk); err != nil { if err := repos.Task(tx).Store(tsk); err != nil {
return nil, fmt.Errorf("could not store task: %v", err) return nil, fmt.Errorf("could not store task: %v", err)
} }
lid, ok := lidMap[u.ID] lid, ok := lidMap[u.ID]
if !ok { if !ok {
lid, err = deps.LocalIDRepo.Next() lid, err = repos.LocalID(tx).Next()
if err != nil { if err != nil {
return nil, fmt.Errorf("could not get next local id: %v", err) return nil, fmt.Errorf("could not get next local id: %v", err)
} }
if err := deps.LocalIDRepo.Store(u.ID, lid); err != nil { if err := repos.LocalID(tx).Store(u.ID, lid); err != nil {
return nil, fmt.Errorf("could not store local id: %v", err) return nil, fmt.Errorf("could not store local id: %v", err)
} }
} }
} }
if err := deps.SyncRepo.SetLastUpdate(newTS); err != nil { if err := repos.Sync(tx).SetLastUpdate(newTS); err != nil {
return nil, fmt.Errorf("could not store update timestamp: %v", err) return nil, fmt.Errorf("could not store update timestamp: %v", err)
} }

View File

@ -47,9 +47,7 @@ func TestSyncSend(t *testing.T) {
t.Parallel() t.Parallel()
syncClient := client.NewMemory() syncClient := client.NewMemory()
syncRepo := memory.NewSync() mems := memory.New()
localIDRepo := memory.NewLocalID()
taskRepo := memory.NewTask()
it := item.Item{ it := item.Item{
ID: "a", ID: "a",
@ -60,7 +58,7 @@ func TestSyncSend(t *testing.T) {
"duration":"1h" "duration":"1h"
}`, }`,
} }
if err := syncRepo.Store(it); err != nil { if err := mems.Sync(nil).Store(it); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
@ -81,12 +79,7 @@ func TestSyncSend(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if _, err := cmd.Do(command.Dependencies{ if _, err := cmd.Do(mems, syncClient); err != nil {
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
SyncClient: syncClient,
}); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
actItems, actErr := syncClient.Updated(tc.ks, tc.ts) actItems, actErr := syncClient.Updated(tc.ks, tc.ts)
@ -97,7 +90,7 @@ func TestSyncSend(t *testing.T) {
t.Errorf("(exp +, got -)\n%s", diff) t.Errorf("(exp +, got -)\n%s", diff)
} }
actLeft, actErr := syncRepo.FindAll() actLeft, actErr := mems.Sync(nil).FindAll()
if actErr != nil { if actErr != nil {
t.Errorf("exp nil, got %v", actErr) t.Errorf("exp nil, got %v", actErr)
} }
@ -185,16 +178,14 @@ func TestSyncReceive(t *testing.T) {
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
syncClient := client.NewMemory() syncClient := client.NewMemory()
syncRepo := memory.NewSync() mems := memory.New()
localIDRepo := memory.NewLocalID()
taskRepo := memory.NewTask()
// setup // setup
for i, p := range tc.present { for i, p := range tc.present {
if err := taskRepo.Store(p); err != nil { if err := mems.Task(nil).Store(p); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if err := localIDRepo.Store(p.ID, i+1); err != nil { if err := mems.LocalID(nil).Store(p.ID, i+1); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
} }
@ -207,24 +198,19 @@ func TestSyncReceive(t *testing.T) {
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if _, err := cmd.Do(command.Dependencies{ if _, err := cmd.Do(mems, syncClient); err != nil {
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
SyncClient: syncClient,
}); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
// check result // check result
actTasks, err := taskRepo.FindMany(storage.TaskListParams{}) actTasks, err := mems.Task(nil).FindMany(storage.TaskListParams{})
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if diff := item.TaskDiffs(tc.expTask, actTasks); diff != "" { if diff := item.TaskDiffs(tc.expTask, actTasks); diff != "" {
t.Errorf("(exp +, got -)\n%s", diff) t.Errorf("(exp +, got -)\n%s", diff)
} }
actLocalIDs, err := localIDRepo.FindAll() actLocalIDs, err := mems.LocalID(nil).FindAll()
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }

View File

@ -10,6 +10,7 @@ import (
"go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/storage" "go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
) )
type UpdateArgs struct { type UpdateArgs struct {
@ -116,8 +117,14 @@ type Update struct {
args UpdateArgs args UpdateArgs
} }
func (u Update) Do(deps Dependencies) (CommandResult, error) { func (u Update) Do(repos Repositories, _ client.Client) (CommandResult, error) {
id, err := deps.LocalIDRepo.FindOne(u.args.LocalID) tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
id, err := repos.LocalID(tx).FindOne(u.args.LocalID)
switch { switch {
case errors.Is(err, storage.ErrNotFound): case errors.Is(err, storage.ErrNotFound):
return nil, fmt.Errorf("could not find local id") return nil, fmt.Errorf("could not find local id")
@ -125,7 +132,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) {
return nil, err return nil, err
} }
tsk, err := deps.TaskRepo.FindOne(id) tsk, err := repos.Task(tx).FindOne(id)
if err != nil { if err != nil {
return nil, fmt.Errorf("could not find task") return nil, fmt.Errorf("could not find task")
} }
@ -154,7 +161,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) {
return nil, fmt.Errorf("task is unvalid") return nil, fmt.Errorf("task is unvalid")
} }
if err := deps.TaskRepo.Store(tsk); err != nil { if err := repos.Task(tx).Store(tsk); err != nil {
return nil, fmt.Errorf("could not store task: %v", err) return nil, fmt.Errorf("could not store task: %v", err)
} }
@ -162,7 +169,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) {
if err != nil { if err != nil {
return nil, fmt.Errorf("could not convert task to sync item: %v", err) return nil, fmt.Errorf("could not convert task to sync item: %v", err)
} }
if err := deps.SyncRepo.Store(it); err != nil { if err := repos.Sync(tx).Store(it); err != nil {
return nil, fmt.Errorf("could not store sync item: %v", err) return nil, fmt.Errorf("could not store sync item: %v", err)
} }

View File

@ -189,10 +189,8 @@ func TestUpdateExecute(t *testing.T) {
} { } {
t.Run(tc.name, func(t *testing.T) { t.Run(tc.name, func(t *testing.T) {
// setup // setup
taskRepo := memory.NewTask() mems := memory.New()
localIDRepo := memory.NewLocalID() if err := mems.Task(nil).Store(item.Task{
syncRepo := memory.NewSync()
if err := taskRepo.Store(item.Task{
ID: tskID, ID: tskID,
Date: aDate, Date: aDate,
TaskBody: item.TaskBody{ TaskBody: item.TaskBody{
@ -204,7 +202,7 @@ func TestUpdateExecute(t *testing.T) {
}); err != nil { }); err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if err := localIDRepo.Store(tskID, lid); err != nil { if err := mems.LocalID(nil).Store(tskID, lid); err != nil {
t.Errorf("exp nil, ,got %v", err) t.Errorf("exp nil, ,got %v", err)
} }
@ -218,11 +216,7 @@ func TestUpdateExecute(t *testing.T) {
} }
// do // do
_, actDoErr := cmd.Do(command.Dependencies{ _, actDoErr := cmd.Do(mems, nil)
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
})
if tc.expDoErr != (actDoErr != nil) { if tc.expDoErr != (actDoErr != nil) {
t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr) t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr)
} }
@ -231,14 +225,14 @@ func TestUpdateExecute(t *testing.T) {
} }
// check // check
actTask, err := taskRepo.FindOne(tskID) actTask, err := mems.Task(nil).FindOne(tskID)
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }
if diff := item.TaskDiff(tc.expTask, actTask); diff != "" { if diff := item.TaskDiff(tc.expTask, actTask); diff != "" {
t.Errorf("(exp -, got +)\n%s", diff) t.Errorf("(exp -, got +)\n%s", diff)
} }
updated, err := syncRepo.FindAll() updated, err := mems.Sync(nil).FindAll()
if err != nil { if err != nil {
t.Errorf("exp nil, got %v", err) t.Errorf("exp nil, got %v", err)
} }

View File

@ -27,7 +27,7 @@ func main() {
os.Exit(1) os.Exit(1)
} }
localIDRepo, taskRepo, syncRepo, err := sqlite.NewSqlites(conf.DBPath) repos, err := sqlite.NewSqlites(conf.DBPath)
if err != nil { if err != nil {
fmt.Printf("could not open db file: %s\n", err) fmt.Printf("could not open db file: %s\n", err)
os.Exit(1) os.Exit(1)
@ -35,12 +35,7 @@ func main() {
syncClient := client.New(conf.SyncURL, conf.ApiKey) syncClient := client.New(conf.SyncURL, conf.ApiKey)
cli := command.NewCLI(command.Dependencies{ cli := command.NewCLI(repos, syncClient)
LocalIDRepo: localIDRepo,
TaskRepo: taskRepo,
SyncRepo: syncRepo,
SyncClient: syncClient,
})
if err := cli.Run(os.Args[1:]); err != nil { if err := cli.Run(os.Args[1:]); err != nil {
fmt.Println(err) fmt.Println(err)
os.Exit(1) os.Exit(1)

View File

@ -0,0 +1,35 @@
package memory
import (
"go-mod.ewintr.nl/planner/plan/storage"
)
type Memories struct {
localID *LocalID
sync *Sync
task *Task
}
func New() *Memories {
return &Memories{
localID: NewLocalID(),
sync: NewSync(),
task: NewTask(),
}
}
func (mems *Memories) Begin() (*storage.Tx, error) {
return &storage.Tx{}, nil
}
func (mems *Memories) LocalID(_ *storage.Tx) storage.LocalID {
return mems.localID
}
func (mems *Memories) Sync(_ *storage.Tx) storage.Sync {
return mems.sync
}
func (mems *Memories) Task(_ *storage.Tx) storage.Task {
return mems.task
}

View File

@ -9,12 +9,12 @@ import (
) )
type LocalID struct { type LocalID struct {
db *sql.DB tx *storage.Tx
} }
func (l *LocalID) FindOne(lid int) (string, error) { func (l *LocalID) FindOne(lid int) (string, error) {
var id string var id string
err := l.db.QueryRow(` err := l.tx.QueryRow(`
SELECT id SELECT id
FROM localids FROM localids
WHERE local_id = ? WHERE local_id = ?
@ -30,7 +30,7 @@ WHERE local_id = ?
} }
func (l *LocalID) FindAll() (map[string]int, error) { func (l *LocalID) FindAll() (map[string]int, error) {
rows, err := l.db.Query(` rows, err := l.tx.Query(`
SELECT id, local_id SELECT id, local_id
FROM localids FROM localids
`) `)
@ -69,7 +69,7 @@ func (l *LocalID) Next() (int, error) {
} }
func (l *LocalID) Store(id string, localID int) error { func (l *LocalID) Store(id string, localID int) error {
if _, err := l.db.Exec(` if _, err := l.tx.Exec(`
INSERT INTO localids INSERT INTO localids
(id, local_id) (id, local_id)
VALUES VALUES
@ -81,7 +81,7 @@ VALUES
} }
func (l *LocalID) Delete(id string) error { func (l *LocalID) Delete(id string) error {
result, err := l.db.Exec(` result, err := l.tx.Exec(`
DELETE FROM localids DELETE FROM localids
WHERE id = ?`, id) WHERE id = ?`, id)
if err != nil { if err != nil {

View File

@ -5,6 +5,7 @@ import (
"errors" "errors"
"fmt" "fmt"
"go-mod.ewintr.nl/planner/plan/storage"
_ "modernc.org/sqlite" _ "modernc.org/sqlite"
) )
@ -19,27 +20,43 @@ var (
ErrSqliteFailure = errors.New("sqlite returned an error") ErrSqliteFailure = errors.New("sqlite returned an error")
) )
func NewSqlites(dbPath string) (*LocalID, *SqliteTask, *SqliteSync, error) { type Sqlites struct {
db *sql.DB
}
func (sqs *Sqlites) Begin() (*storage.Tx, error) {
tx, err := sqs.db.Begin()
if err != nil {
return nil, err
}
return storage.NewTx(tx), nil
}
func (sqs *Sqlites) LocalID(tx *storage.Tx) storage.LocalID {
return &LocalID{tx: tx}
}
func (sqs *Sqlites) Sync(tx *storage.Tx) storage.Sync {
return &Sync{tx: tx}
}
func (sqs *Sqlites) Task(tx *storage.Tx) storage.Task {
return &SqliteTask{tx: tx}
}
func NewSqlites(dbPath string) (*Sqlites, error) {
db, err := sql.Open("sqlite", dbPath) db, err := sql.Open("sqlite", dbPath)
if err != nil { if err != nil {
return nil, nil, nil, fmt.Errorf("%w: %v", ErrInvalidConfiguration, err) return nil, fmt.Errorf("%w: %v", ErrInvalidConfiguration, err)
}
sl := &LocalID{
db: db,
}
se := &SqliteTask{
db: db,
}
ss := &SqliteSync{
db: db,
} }
if err := migrate(db, migrations); err != nil { if err := migrate(db, migrations); err != nil {
return nil, nil, nil, err return nil, err
} }
return sl, se, ss, nil return &Sqlites{
db: db,
}, nil
} }
func migrate(db *sql.DB, wanted []string) error { func migrate(db *sql.DB, wanted []string) error {

View File

@ -6,18 +6,15 @@ import (
"time" "time"
"go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/storage"
) )
type SqliteSync struct { type Sync struct {
db *sql.DB tx *storage.Tx
} }
func NewSqliteSync(db *sql.DB) *SqliteSync { func (s *Sync) FindAll() ([]item.Item, error) {
return &SqliteSync{db: db} rows, err := s.tx.Query("SELECT id, kind, updated, deleted, date, recurrer, recur_next, body FROM items")
}
func (s *SqliteSync) FindAll() ([]item.Item, error) {
rows, err := s.db.Query("SELECT id, kind, updated, deleted, date, recurrer, recur_next, body FROM items")
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: failed to query items: %v", ErrSqliteFailure, err) return nil, fmt.Errorf("%w: failed to query items: %v", ErrSqliteFailure, err)
} }
@ -49,7 +46,7 @@ func (s *SqliteSync) FindAll() ([]item.Item, error) {
return items, nil return items, nil
} }
func (s *SqliteSync) Store(i item.Item) error { func (s *Sync) Store(i item.Item) error {
if i.Updated.IsZero() { if i.Updated.IsZero() {
i.Updated = time.Now() i.Updated = time.Now()
} }
@ -58,7 +55,7 @@ func (s *SqliteSync) Store(i item.Item) error {
recurStr = i.Recurrer.String() recurStr = i.Recurrer.String()
} }
_, err := s.db.Exec( _, err := s.tx.Exec(
`INSERT OR REPLACE INTO items (id, kind, updated, deleted, date, recurrer, recur_next, body) `INSERT OR REPLACE INTO items (id, kind, updated, deleted, date, recurrer, recur_next, body)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
i.ID, i.ID,
@ -76,24 +73,24 @@ func (s *SqliteSync) Store(i item.Item) error {
return nil return nil
} }
func (s *SqliteSync) DeleteAll() error { func (s *Sync) DeleteAll() error {
_, err := s.db.Exec("DELETE FROM items") _, err := s.tx.Exec("DELETE FROM items")
if err != nil { if err != nil {
return fmt.Errorf("%w: failed to delete all items: %v", ErrSqliteFailure, err) return fmt.Errorf("%w: failed to delete all items: %v", ErrSqliteFailure, err)
} }
return nil return nil
} }
func (s *SqliteSync) SetLastUpdate(ts time.Time) error { func (s *Sync) SetLastUpdate(ts time.Time) error {
if _, err := s.db.Exec(`UPDATE syncupdate SET timestamp = ?`, ts.Format(time.RFC3339)); err != nil { if _, err := s.tx.Exec(`UPDATE syncupdate SET timestamp = ?`, ts.Format(time.RFC3339)); err != nil {
return fmt.Errorf("%w: could not store timestamp: %v", ErrSqliteFailure, err) return fmt.Errorf("%w: could not store timestamp: %v", ErrSqliteFailure, err)
} }
return nil return nil
} }
func (s *SqliteSync) LastUpdate() (time.Time, error) { func (s *Sync) LastUpdate() (time.Time, error) {
var tsStr string var tsStr string
if err := s.db.QueryRow("SELECT timestamp FROM syncupdate").Scan(&tsStr); err != nil { if err := s.tx.QueryRow("SELECT timestamp FROM syncupdate").Scan(&tsStr); err != nil {
return time.Time{}, fmt.Errorf("%w: failed to get last update: %v", ErrSqliteFailure, err) return time.Time{}, fmt.Errorf("%w: failed to get last update: %v", ErrSqliteFailure, err)
} }
ts, err := time.Parse(time.RFC3339, tsStr) ts, err := time.Parse(time.RFC3339, tsStr)

View File

@ -10,7 +10,7 @@ import (
) )
type SqliteTask struct { type SqliteTask struct {
db *sql.DB tx *storage.Tx
} }
func (t *SqliteTask) Store(tsk item.Task) error { func (t *SqliteTask) Store(tsk item.Task) error {
@ -18,7 +18,7 @@ func (t *SqliteTask) Store(tsk item.Task) error {
if tsk.Recurrer != nil { if tsk.Recurrer != nil {
recurStr = tsk.Recurrer.String() recurStr = tsk.Recurrer.String()
} }
if _, err := t.db.Exec(` if _, err := t.tx.Exec(`
INSERT INTO tasks INSERT INTO tasks
(id, title, project, date, time, duration, recurrer) (id, title, project, date, time, duration, recurrer)
VALUES VALUES
@ -42,7 +42,7 @@ recurrer=?
func (t *SqliteTask) FindOne(id string) (item.Task, error) { func (t *SqliteTask) FindOne(id string) (item.Task, error) {
var tsk item.Task var tsk item.Task
var dateStr, timeStr, recurStr, durStr string var dateStr, timeStr, recurStr, durStr string
err := t.db.QueryRow(` err := t.tx.QueryRow(`
SELECT id, title, project, date, time, duration, recurrer SELECT id, title, project, date, time, duration, recurrer
FROM tasks FROM tasks
WHERE id = ?`, id).Scan(&tsk.ID, &tsk.Title, &tsk.Project, &dateStr, &timeStr, &durStr, &recurStr) WHERE id = ?`, id).Scan(&tsk.ID, &tsk.Title, &tsk.Project, &dateStr, &timeStr, &durStr, &recurStr)
@ -98,7 +98,7 @@ func (t *SqliteTask) FindMany(params storage.TaskListParams) ([]item.Task, error
} }
} }
rows, err := t.db.Query(query, args...) rows, err := t.tx.Query(query, args...)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err) return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err)
} }
@ -126,7 +126,7 @@ func (t *SqliteTask) FindMany(params storage.TaskListParams) ([]item.Task, error
} }
func (t *SqliteTask) Delete(id string) error { func (t *SqliteTask) Delete(id string) error {
result, err := t.db.Exec(` result, err := t.tx.Exec(`
DELETE FROM tasks DELETE FROM tasks
WHERE id = ?`, id) WHERE id = ?`, id)
if err != nil { if err != nil {
@ -146,7 +146,7 @@ WHERE id = ?`, id)
} }
func (t *SqliteTask) Projects() (map[string]int, error) { func (t *SqliteTask) Projects() (map[string]int, error) {
rows, err := t.db.Query(`SELECT project, count(*) FROM tasks GROUP BY project`) rows, err := t.tx.Query(`SELECT project, count(*) FROM tasks GROUP BY project`)
if err != nil { if err != nil {
return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err) return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err)
} }

View File

@ -12,6 +12,10 @@ var (
ErrNotFound = errors.New("not found") ErrNotFound = errors.New("not found")
) )
type Txer interface {
Begin() (*Tx, error)
}
type LocalID interface { type LocalID interface {
FindOne(lid int) (string, error) FindOne(lid int) (string, error)
FindAll() (map[string]int, error) FindAll() (map[string]int, error)

View File

@ -0,0 +1,40 @@
package storage
import "database/sql"
// Tx wraps sql.Tx so transactions can be skipped for in-memory repositories
type Tx struct {
tx *sql.Tx
}
func NewTx(tx *sql.Tx) *Tx {
return &Tx{tx: tx}
}
func (tx *Tx) Rollback() error {
if tx.tx == nil {
return nil
}
return tx.tx.Rollback()
}
func (tx *Tx) Commit() error {
if tx.tx == nil {
return nil
}
return tx.tx.Commit()
}
func (tx *Tx) QueryRow(query string, args ...any) *sql.Row {
return tx.tx.QueryRow(query, args...)
}
func (tx *Tx) Query(query string, args ...any) (*sql.Rows, error) {
return tx.tx.Query(query, args...)
}
func (tx *Tx) Exec(query string, args ...any) (sql.Result, error) {
return tx.tx.Exec(query, args...)
}