diff --git a/dist/plan b/dist/plan index 8c88a59..f662649 100755 Binary files a/dist/plan and b/dist/plan differ diff --git a/plan/command/add.go b/plan/command/add.go index ac5416a..e5a8688 100644 --- a/plan/command/add.go +++ b/plan/command/add.go @@ -8,6 +8,7 @@ import ( "github.com/google/uuid" "go-mod.ewintr.nl/planner/item" + "go-mod.ewintr.nl/planner/sync/client" ) type AddArgs struct { @@ -92,16 +93,22 @@ type Add struct { Args AddArgs } -func (a Add) Do(deps Dependencies) (CommandResult, error) { - if err := deps.TaskRepo.Store(a.Args.Task); err != nil { +func (a Add) Do(repos Repositories, _ client.Client) (CommandResult, error) { + tx, err := repos.Begin() + if err != nil { + return nil, fmt.Errorf("could not start transaction: %v", err) + } + defer tx.Rollback() + + if err := repos.Task(tx).Store(a.Args.Task); err != nil { return nil, fmt.Errorf("could not store event: %v", err) } - localID, err := deps.LocalIDRepo.Next() + localID, err := repos.LocalID(tx).Next() if err != nil { return nil, fmt.Errorf("could not create next local id: %v", err) } - if err := deps.LocalIDRepo.Store(a.Args.Task.ID, localID); err != nil { + if err := repos.LocalID(tx).Store(a.Args.Task.ID, localID); err != nil { return nil, fmt.Errorf("could not store local id: %v", err) } @@ -109,10 +116,14 @@ func (a Add) Do(deps Dependencies) (CommandResult, error) { if err != nil { return nil, fmt.Errorf("could not convert event to sync item: %v", err) } - if err := deps.SyncRepo.Store(it); err != nil { + if err := repos.Sync(tx).Store(it); err != nil { return nil, fmt.Errorf("could not store sync item: %v", err) } + if err := tx.Commit(); err != nil { + return nil, fmt.Errorf("could not add task: %v", err) + } + return AddRender{}, nil } diff --git a/plan/command/add_test.go b/plan/command/add_test.go index 9fd8e7d..e8d29c2 100644 --- a/plan/command/add_test.go +++ b/plan/command/add_test.go @@ -60,9 +60,7 @@ func TestAdd(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { // setup - taskRepo := memory.NewTask() - localIDRepo := memory.NewLocalID() - syncRepo := memory.NewSync() + mems := memory.New() // parse cmd, actParseErr := command.NewAddArgs().Parse(tc.main, tc.fields) @@ -74,16 +72,12 @@ func TestAdd(t *testing.T) { } // do - if _, err := cmd.Do(command.Dependencies{ - TaskRepo: taskRepo, - LocalIDRepo: localIDRepo, - SyncRepo: syncRepo, - }); err != nil { + if _, err := cmd.Do(mems, nil); err != nil { t.Errorf("exp nil, got %v", err) } // check - actTasks, err := taskRepo.FindMany(storage.TaskListParams{}) + actTasks, err := mems.Task(nil).FindMany(storage.TaskListParams{}) if err != nil { t.Errorf("exp nil, got %v", err) } @@ -91,7 +85,7 @@ func TestAdd(t *testing.T) { t.Errorf("exp 1, got %d", len(actTasks)) } - actLocalIDs, err := localIDRepo.FindAll() + actLocalIDs, err := mems.LocalID(nil).FindAll() if err != nil { t.Errorf("exp nil, got %v", err) } @@ -110,7 +104,7 @@ func TestAdd(t *testing.T) { t.Errorf("(exp -, got +)\n%s", diff) } - updated, err := syncRepo.FindAll() + updated, err := mems.Sync(nil).FindAll() if err != nil { t.Errorf("exp nil, got %v", err) } diff --git a/plan/command/command.go b/plan/command/command.go index 0400225..4b64508 100644 --- a/plan/command/command.go +++ b/plan/command/command.go @@ -20,11 +20,11 @@ var ( ErrInvalidArg = errors.New("invalid argument") ) -type Dependencies struct { - LocalIDRepo storage.LocalID - TaskRepo storage.Task - SyncRepo storage.Sync - SyncClient client.Client +type Repositories interface { + Begin() (*storage.Tx, error) + LocalID(tx *storage.Tx) storage.LocalID + Sync(tx *storage.Tx) storage.Sync + Task(tx *storage.Tx) storage.Task } type CommandArgs interface { @@ -32,7 +32,7 @@ type CommandArgs interface { } type Command interface { - Do(deps Dependencies) (CommandResult, error) + Do(repos Repositories, client client.Client) (CommandResult, error) } type CommandResult interface { @@ -40,13 +40,15 @@ type CommandResult interface { } type CLI struct { - deps Dependencies + repos Repositories + client client.Client cmdArgs []CommandArgs } -func NewCLI(deps Dependencies) *CLI { +func NewCLI(repos Repositories, client client.Client) *CLI { return &CLI{ - deps: deps, + repos: repos, + client: client, cmdArgs: []CommandArgs{ NewShowArgs(), NewProjectsArgs(), NewAddArgs(), NewDeleteArgs(), NewListArgs(), @@ -66,7 +68,7 @@ func (cli *CLI) Run(args []string) error { return err } - result, err := cmd.Do(cli.deps) + result, err := cmd.Do(cli.repos, cli.client) if err != nil { return err } diff --git a/plan/command/delete.go b/plan/command/delete.go index 742b49d..31bb225 100644 --- a/plan/command/delete.go +++ b/plan/command/delete.go @@ -4,6 +4,8 @@ import ( "fmt" "slices" "strconv" + + "go-mod.ewintr.nl/planner/sync/client" ) type DeleteArgs struct { @@ -45,9 +47,15 @@ type Delete struct { Args DeleteArgs } -func (del Delete) Do(deps Dependencies) (CommandResult, error) { +func (del Delete) Do(repos Repositories, _ client.Client) (CommandResult, error) { + tx, err := repos.Begin() + if err != nil { + return nil, fmt.Errorf("could not start transaction: %v", err) + } + defer tx.Rollback() + var id string - idMap, err := deps.LocalIDRepo.FindAll() + idMap, err := repos.LocalID(tx).FindAll() if err != nil { return nil, fmt.Errorf("could not get local ids: %v", err) } @@ -60,7 +68,7 @@ func (del Delete) Do(deps Dependencies) (CommandResult, error) { return nil, fmt.Errorf("could not find local id") } - tsk, err := deps.TaskRepo.FindOne(id) + tsk, err := repos.Task(tx).FindOne(id) if err != nil { return nil, fmt.Errorf("could not get task: %v", err) } @@ -70,15 +78,15 @@ func (del Delete) Do(deps Dependencies) (CommandResult, error) { return nil, fmt.Errorf("could not convert task to sync item: %v", err) } it.Deleted = true - if err := deps.SyncRepo.Store(it); err != nil { + if err := repos.Sync(tx).Store(it); err != nil { return nil, fmt.Errorf("could not store sync item: %v", err) } - if err := deps.LocalIDRepo.Delete(id); err != nil { + if err := repos.LocalID(tx).Delete(id); err != nil { return nil, fmt.Errorf("could not delete local id: %v", err) } - if err := deps.TaskRepo.Delete(id); err != nil { + if err := repos.Task(tx).Delete(id); err != nil { return nil, fmt.Errorf("could not delete task: %v", err) } diff --git a/plan/command/delete_test.go b/plan/command/delete_test.go index 3f2899a..d6cfc90 100644 --- a/plan/command/delete_test.go +++ b/plan/command/delete_test.go @@ -49,13 +49,12 @@ func TestDelete(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { // setup - taskRepo := memory.NewTask() - syncRepo := memory.NewSync() - if err := taskRepo.Store(e); err != nil { + mems := memory.New() + + if err := mems.Task(nil).Store(e); err != nil { t.Errorf("exp nil, got %v", err) } - localIDRepo := memory.NewLocalID() - if err := localIDRepo.Store(e.ID, 1); err != nil { + if err := mems.LocalID(nil).Store(e.ID, 1); err != nil { t.Errorf("exp nil, got %v", err) } @@ -69,11 +68,7 @@ func TestDelete(t *testing.T) { } // do - _, actDoErr := cmd.Do(command.Dependencies{ - TaskRepo: taskRepo, - LocalIDRepo: localIDRepo, - SyncRepo: syncRepo, - }) + _, actDoErr := cmd.Do(mems, nil) if tc.expDoErr != (actDoErr != nil) { t.Errorf("exp false, got %v", actDoErr) } @@ -82,18 +77,18 @@ func TestDelete(t *testing.T) { } // check - _, repoErr := taskRepo.FindOne(e.ID) + _, repoErr := mems.Task(nil).FindOne(e.ID) if !errors.Is(repoErr, storage.ErrNotFound) { t.Errorf("exp %v, got %v", storage.ErrNotFound, repoErr) } - idMap, idErr := localIDRepo.FindAll() + idMap, idErr := mems.LocalID(nil).FindAll() if idErr != nil { t.Errorf("exp nil, got %v", idErr) } if len(idMap) != 0 { t.Errorf("exp 0, got %v", len(idMap)) } - updated, err := syncRepo.FindAll() + updated, err := mems.Sync(nil).FindAll() if err != nil { t.Errorf("exp nil, got %v", err) } diff --git a/plan/command/list.go b/plan/command/list.go index b462287..5636433 100644 --- a/plan/command/list.go +++ b/plan/command/list.go @@ -9,6 +9,7 @@ import ( "go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/plan/format" "go-mod.ewintr.nl/planner/plan/storage" + "go-mod.ewintr.nl/planner/sync/client" ) type ListArgs struct { @@ -96,12 +97,18 @@ type List struct { Args ListArgs } -func (list List) Do(deps Dependencies) (CommandResult, error) { - localIDs, err := deps.LocalIDRepo.FindAll() +func (list List) Do(repos Repositories, _ client.Client) (CommandResult, error) { + tx, err := repos.Begin() + if err != nil { + return nil, fmt.Errorf("could not start transaction: %v", err) + } + defer tx.Rollback() + + localIDs, err := repos.LocalID(tx).FindAll() if err != nil { return nil, fmt.Errorf("could not get local ids: %v", err) } - all, err := deps.TaskRepo.FindMany(storage.TaskListParams{ + all, err := repos.Task(tx).FindMany(storage.TaskListParams{ HasRecurrer: list.Args.HasRecurrer, From: list.Args.From, To: list.Args.To, diff --git a/plan/command/list_test.go b/plan/command/list_test.go index a19386e..133998e 100644 --- a/plan/command/list_test.go +++ b/plan/command/list_test.go @@ -79,8 +79,8 @@ func TestListParse(t *testing.T) { func TestList(t *testing.T) { t.Parallel() - taskRepo := memory.NewTask() - localRepo := memory.NewLocalID() + mems := memory.New() + e := item.Task{ ID: "id", Date: item.NewDate(2024, 10, 7), @@ -88,10 +88,10 @@ func TestList(t *testing.T) { Title: "name", }, } - if err := taskRepo.Store(e); err != nil { + if err := mems.Task(nil).Store(e); err != nil { t.Errorf("exp nil, got %v", err) } - if err := localRepo.Store(e.ID, 1); err != nil { + if err := mems.LocalID(nil).Store(e.ID, 1); err != nil { t.Errorf("exp nil, got %v", err) } @@ -115,10 +115,7 @@ func TestList(t *testing.T) { }, } { t.Run(tc.name, func(t *testing.T) { - res, err := tc.cmd.Do(command.Dependencies{ - TaskRepo: taskRepo, - LocalIDRepo: localRepo, - }) + res, err := tc.cmd.Do(mems, nil) if err != nil { t.Errorf("exp nil, got %v", err) } diff --git a/plan/command/projects.go b/plan/command/projects.go index a472601..af06af2 100644 --- a/plan/command/projects.go +++ b/plan/command/projects.go @@ -5,6 +5,7 @@ import ( "sort" "go-mod.ewintr.nl/planner/plan/format" + "go-mod.ewintr.nl/planner/sync/client" ) type ProjectsArgs struct{} @@ -23,8 +24,14 @@ func (pa ProjectsArgs) Parse(main []string, fields map[string]string) (Command, type Projects struct{} -func (ps Projects) Do(deps Dependencies) (CommandResult, error) { - projects, err := deps.TaskRepo.Projects() +func (ps Projects) Do(repos Repositories, _ client.Client) (CommandResult, error) { + tx, err := repos.Begin() + if err != nil { + return nil, fmt.Errorf("could not start transaction: %v", err) + } + defer tx.Rollback() + + projects, err := repos.Task(tx).Projects() if err != nil { return nil, fmt.Errorf("could not find projects: %v", err) } diff --git a/plan/command/show.go b/plan/command/show.go index 23c86d0..94f3c1d 100644 --- a/plan/command/show.go +++ b/plan/command/show.go @@ -8,6 +8,7 @@ import ( "go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/plan/format" "go-mod.ewintr.nl/planner/plan/storage" + "go-mod.ewintr.nl/planner/sync/client" ) type ShowArgs struct { @@ -38,8 +39,14 @@ type Show struct { args ShowArgs } -func (s Show) Do(deps Dependencies) (CommandResult, error) { - id, err := deps.LocalIDRepo.FindOne(s.args.localID) +func (s Show) Do(repos Repositories, _ client.Client) (CommandResult, error) { + tx, err := repos.Begin() + if err != nil { + return nil, fmt.Errorf("could not start transaction: %v", err) + } + defer tx.Rollback() + + id, err := repos.LocalID(tx).FindOne(s.args.localID) switch { case errors.Is(err, storage.ErrNotFound): return nil, fmt.Errorf("could not find local id") @@ -47,7 +54,7 @@ func (s Show) Do(deps Dependencies) (CommandResult, error) { return nil, err } - tsk, err := deps.TaskRepo.FindOne(id) + tsk, err := repos.Task(tx).FindOne(id) if err != nil { return nil, fmt.Errorf("could not find task") } diff --git a/plan/command/show_test.go b/plan/command/show_test.go index 1414ab8..26dd84a 100644 --- a/plan/command/show_test.go +++ b/plan/command/show_test.go @@ -12,8 +12,8 @@ import ( func TestShow(t *testing.T) { t.Parallel() - taskRepo := memory.NewTask() - localRepo := memory.NewLocalID() + mems := memory.New() + tsk := item.Task{ ID: "id", Date: item.NewDate(2024, 10, 7), @@ -21,10 +21,10 @@ func TestShow(t *testing.T) { Title: "name", }, } - if err := taskRepo.Store(tsk); err != nil { + if err := mems.Task(nil).Store(tsk); err != nil { t.Errorf("exp nil, got %v", err) } - if err := localRepo.Store(tsk.ID, 1); err != nil { + if err := mems.LocalID(nil).Store(tsk.ID, 1); err != nil { t.Errorf("exp nil, got %v", err) } @@ -69,10 +69,7 @@ func TestShow(t *testing.T) { } // do - _, actDoErr := cmd.Do(command.Dependencies{ - TaskRepo: taskRepo, - LocalIDRepo: localRepo, - }) + _, actDoErr := cmd.Do(mems, nil) if tc.expDoErr != (actDoErr != nil) { t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr != nil) } diff --git a/plan/command/sync.go b/plan/command/sync.go index 0f7d13f..f248a83 100644 --- a/plan/command/sync.go +++ b/plan/command/sync.go @@ -8,6 +8,7 @@ import ( "go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/plan/storage" + "go-mod.ewintr.nl/planner/sync/client" ) type SyncArgs struct{} @@ -26,25 +27,31 @@ func (sa SyncArgs) Parse(main []string, flags map[string]string) (Command, error type Sync struct{} -func (s Sync) Do(deps Dependencies) (CommandResult, error) { +func (s Sync) Do(repos Repositories, client client.Client) (CommandResult, error) { + tx, err := repos.Begin() + if err != nil { + return nil, fmt.Errorf("could not start transaction: %v", err) + } + defer tx.Rollback() + // local new and updated - sendItems, err := deps.SyncRepo.FindAll() + sendItems, err := repos.Sync(tx).FindAll() if err != nil { return nil, fmt.Errorf("could not get updated items: %v", err) } - if err := deps.SyncClient.Update(sendItems); err != nil { + if err := client.Update(sendItems); err != nil { return nil, fmt.Errorf("could not send updated items: %v", err) } - if err := deps.SyncRepo.DeleteAll(); err != nil { + if err := repos.Sync(tx).DeleteAll(); err != nil { return nil, fmt.Errorf("could not clear updated items: %v", err) } // get new/updated items - oldTS, err := deps.SyncRepo.LastUpdate() + oldTS, err := repos.Sync(tx).LastUpdate() if err != nil { return nil, fmt.Errorf("could not find timestamp of last update: %v", err) } - recItems, err := deps.SyncClient.Updated([]item.Kind{item.KindTask}, oldTS) + recItems, err := client.Updated([]item.Kind{item.KindTask}, oldTS) if err != nil { return nil, fmt.Errorf("could not receive updates: %v", err) } @@ -56,10 +63,10 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) { newTS = ri.Updated } if ri.Deleted { - if err := deps.LocalIDRepo.Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) { + if err := repos.LocalID(tx).Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) { return nil, fmt.Errorf("could not delete local id: %v", err) } - if err := deps.TaskRepo.Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) { + if err := repos.Task(tx).Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) { return nil, fmt.Errorf("could not delete task: %v", err) } continue @@ -67,7 +74,7 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) { updated = append(updated, ri) } - lidMap, err := deps.LocalIDRepo.FindAll() + lidMap, err := repos.LocalID(tx).FindAll() if err != nil { return nil, fmt.Errorf("could not get local ids: %v", err) } @@ -83,23 +90,23 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) { RecurNext: u.RecurNext, TaskBody: tskBody, } - if err := deps.TaskRepo.Store(tsk); err != nil { + if err := repos.Task(tx).Store(tsk); err != nil { return nil, fmt.Errorf("could not store task: %v", err) } lid, ok := lidMap[u.ID] if !ok { - lid, err = deps.LocalIDRepo.Next() + lid, err = repos.LocalID(tx).Next() if err != nil { return nil, fmt.Errorf("could not get next local id: %v", err) } - if err := deps.LocalIDRepo.Store(u.ID, lid); err != nil { + if err := repos.LocalID(tx).Store(u.ID, lid); err != nil { return nil, fmt.Errorf("could not store local id: %v", err) } } } - if err := deps.SyncRepo.SetLastUpdate(newTS); err != nil { + if err := repos.Sync(tx).SetLastUpdate(newTS); err != nil { return nil, fmt.Errorf("could not store update timestamp: %v", err) } diff --git a/plan/command/sync_test.go b/plan/command/sync_test.go index 9e3b55a..4f1ab91 100644 --- a/plan/command/sync_test.go +++ b/plan/command/sync_test.go @@ -47,9 +47,7 @@ func TestSyncSend(t *testing.T) { t.Parallel() syncClient := client.NewMemory() - syncRepo := memory.NewSync() - localIDRepo := memory.NewLocalID() - taskRepo := memory.NewTask() + mems := memory.New() it := item.Item{ ID: "a", @@ -60,7 +58,7 @@ func TestSyncSend(t *testing.T) { "duration":"1h" }`, } - if err := syncRepo.Store(it); err != nil { + if err := mems.Sync(nil).Store(it); err != nil { t.Errorf("exp nil, got %v", err) } @@ -81,12 +79,7 @@ func TestSyncSend(t *testing.T) { if err != nil { t.Errorf("exp nil, got %v", err) } - if _, err := cmd.Do(command.Dependencies{ - TaskRepo: taskRepo, - LocalIDRepo: localIDRepo, - SyncRepo: syncRepo, - SyncClient: syncClient, - }); err != nil { + if _, err := cmd.Do(mems, syncClient); err != nil { t.Errorf("exp nil, got %v", err) } actItems, actErr := syncClient.Updated(tc.ks, tc.ts) @@ -97,7 +90,7 @@ func TestSyncSend(t *testing.T) { t.Errorf("(exp +, got -)\n%s", diff) } - actLeft, actErr := syncRepo.FindAll() + actLeft, actErr := mems.Sync(nil).FindAll() if actErr != nil { t.Errorf("exp nil, got %v", actErr) } @@ -185,16 +178,14 @@ func TestSyncReceive(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { syncClient := client.NewMemory() - syncRepo := memory.NewSync() - localIDRepo := memory.NewLocalID() - taskRepo := memory.NewTask() + mems := memory.New() // setup for i, p := range tc.present { - if err := taskRepo.Store(p); err != nil { + if err := mems.Task(nil).Store(p); err != nil { t.Errorf("exp nil, got %v", err) } - if err := localIDRepo.Store(p.ID, i+1); err != nil { + if err := mems.LocalID(nil).Store(p.ID, i+1); err != nil { t.Errorf("exp nil, got %v", err) } } @@ -207,24 +198,19 @@ func TestSyncReceive(t *testing.T) { if err != nil { t.Errorf("exp nil, got %v", err) } - if _, err := cmd.Do(command.Dependencies{ - TaskRepo: taskRepo, - LocalIDRepo: localIDRepo, - SyncRepo: syncRepo, - SyncClient: syncClient, - }); err != nil { + if _, err := cmd.Do(mems, syncClient); err != nil { t.Errorf("exp nil, got %v", err) } // check result - actTasks, err := taskRepo.FindMany(storage.TaskListParams{}) + actTasks, err := mems.Task(nil).FindMany(storage.TaskListParams{}) if err != nil { t.Errorf("exp nil, got %v", err) } if diff := item.TaskDiffs(tc.expTask, actTasks); diff != "" { t.Errorf("(exp +, got -)\n%s", diff) } - actLocalIDs, err := localIDRepo.FindAll() + actLocalIDs, err := mems.LocalID(nil).FindAll() if err != nil { t.Errorf("exp nil, got %v", err) } diff --git a/plan/command/update.go b/plan/command/update.go index 2fe4779..530c076 100644 --- a/plan/command/update.go +++ b/plan/command/update.go @@ -10,6 +10,7 @@ import ( "go-mod.ewintr.nl/planner/item" "go-mod.ewintr.nl/planner/plan/storage" + "go-mod.ewintr.nl/planner/sync/client" ) type UpdateArgs struct { @@ -116,8 +117,14 @@ type Update struct { args UpdateArgs } -func (u Update) Do(deps Dependencies) (CommandResult, error) { - id, err := deps.LocalIDRepo.FindOne(u.args.LocalID) +func (u Update) Do(repos Repositories, _ client.Client) (CommandResult, error) { + tx, err := repos.Begin() + if err != nil { + return nil, fmt.Errorf("could not start transaction: %v", err) + } + defer tx.Rollback() + + id, err := repos.LocalID(tx).FindOne(u.args.LocalID) switch { case errors.Is(err, storage.ErrNotFound): return nil, fmt.Errorf("could not find local id") @@ -125,7 +132,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) { return nil, err } - tsk, err := deps.TaskRepo.FindOne(id) + tsk, err := repos.Task(tx).FindOne(id) if err != nil { return nil, fmt.Errorf("could not find task") } @@ -154,7 +161,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) { return nil, fmt.Errorf("task is unvalid") } - if err := deps.TaskRepo.Store(tsk); err != nil { + if err := repos.Task(tx).Store(tsk); err != nil { return nil, fmt.Errorf("could not store task: %v", err) } @@ -162,7 +169,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) { if err != nil { return nil, fmt.Errorf("could not convert task to sync item: %v", err) } - if err := deps.SyncRepo.Store(it); err != nil { + if err := repos.Sync(tx).Store(it); err != nil { return nil, fmt.Errorf("could not store sync item: %v", err) } diff --git a/plan/command/update_test.go b/plan/command/update_test.go index d0b1b00..92f0728 100644 --- a/plan/command/update_test.go +++ b/plan/command/update_test.go @@ -189,10 +189,8 @@ func TestUpdateExecute(t *testing.T) { } { t.Run(tc.name, func(t *testing.T) { // setup - taskRepo := memory.NewTask() - localIDRepo := memory.NewLocalID() - syncRepo := memory.NewSync() - if err := taskRepo.Store(item.Task{ + mems := memory.New() + if err := mems.Task(nil).Store(item.Task{ ID: tskID, Date: aDate, TaskBody: item.TaskBody{ @@ -204,7 +202,7 @@ func TestUpdateExecute(t *testing.T) { }); err != nil { t.Errorf("exp nil, got %v", err) } - if err := localIDRepo.Store(tskID, lid); err != nil { + if err := mems.LocalID(nil).Store(tskID, lid); err != nil { t.Errorf("exp nil, ,got %v", err) } @@ -218,11 +216,7 @@ func TestUpdateExecute(t *testing.T) { } // do - _, actDoErr := cmd.Do(command.Dependencies{ - TaskRepo: taskRepo, - LocalIDRepo: localIDRepo, - SyncRepo: syncRepo, - }) + _, actDoErr := cmd.Do(mems, nil) if tc.expDoErr != (actDoErr != nil) { t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr) } @@ -231,14 +225,14 @@ func TestUpdateExecute(t *testing.T) { } // check - actTask, err := taskRepo.FindOne(tskID) + actTask, err := mems.Task(nil).FindOne(tskID) if err != nil { t.Errorf("exp nil, got %v", err) } if diff := item.TaskDiff(tc.expTask, actTask); diff != "" { t.Errorf("(exp -, got +)\n%s", diff) } - updated, err := syncRepo.FindAll() + updated, err := mems.Sync(nil).FindAll() if err != nil { t.Errorf("exp nil, got %v", err) } diff --git a/plan/main.go b/plan/main.go index 14850c1..e4b0dbe 100644 --- a/plan/main.go +++ b/plan/main.go @@ -27,7 +27,7 @@ func main() { os.Exit(1) } - localIDRepo, taskRepo, syncRepo, err := sqlite.NewSqlites(conf.DBPath) + repos, err := sqlite.NewSqlites(conf.DBPath) if err != nil { fmt.Printf("could not open db file: %s\n", err) os.Exit(1) @@ -35,12 +35,7 @@ func main() { syncClient := client.New(conf.SyncURL, conf.ApiKey) - cli := command.NewCLI(command.Dependencies{ - LocalIDRepo: localIDRepo, - TaskRepo: taskRepo, - SyncRepo: syncRepo, - SyncClient: syncClient, - }) + cli := command.NewCLI(repos, syncClient) if err := cli.Run(os.Args[1:]); err != nil { fmt.Println(err) os.Exit(1) diff --git a/plan/storage/memory/memory.go b/plan/storage/memory/memory.go new file mode 100644 index 0000000..4384780 --- /dev/null +++ b/plan/storage/memory/memory.go @@ -0,0 +1,35 @@ +package memory + +import ( + "go-mod.ewintr.nl/planner/plan/storage" +) + +type Memories struct { + localID *LocalID + sync *Sync + task *Task +} + +func New() *Memories { + return &Memories{ + localID: NewLocalID(), + sync: NewSync(), + task: NewTask(), + } +} + +func (mems *Memories) Begin() (*storage.Tx, error) { + return &storage.Tx{}, nil +} + +func (mems *Memories) LocalID(_ *storage.Tx) storage.LocalID { + return mems.localID +} + +func (mems *Memories) Sync(_ *storage.Tx) storage.Sync { + return mems.sync +} + +func (mems *Memories) Task(_ *storage.Tx) storage.Task { + return mems.task +} diff --git a/plan/storage/sqlite/localid.go b/plan/storage/sqlite/localid.go index 08f8629..5cfac90 100644 --- a/plan/storage/sqlite/localid.go +++ b/plan/storage/sqlite/localid.go @@ -9,12 +9,12 @@ import ( ) type LocalID struct { - db *sql.DB + tx *storage.Tx } func (l *LocalID) FindOne(lid int) (string, error) { var id string - err := l.db.QueryRow(` + err := l.tx.QueryRow(` SELECT id FROM localids WHERE local_id = ? @@ -30,7 +30,7 @@ WHERE local_id = ? } func (l *LocalID) FindAll() (map[string]int, error) { - rows, err := l.db.Query(` + rows, err := l.tx.Query(` SELECT id, local_id FROM localids `) @@ -69,7 +69,7 @@ func (l *LocalID) Next() (int, error) { } func (l *LocalID) Store(id string, localID int) error { - if _, err := l.db.Exec(` + if _, err := l.tx.Exec(` INSERT INTO localids (id, local_id) VALUES @@ -81,7 +81,7 @@ VALUES } func (l *LocalID) Delete(id string) error { - result, err := l.db.Exec(` + result, err := l.tx.Exec(` DELETE FROM localids WHERE id = ?`, id) if err != nil { diff --git a/plan/storage/sqlite/sqlite.go b/plan/storage/sqlite/sqlite.go index b61814c..7f884f9 100644 --- a/plan/storage/sqlite/sqlite.go +++ b/plan/storage/sqlite/sqlite.go @@ -5,6 +5,7 @@ import ( "errors" "fmt" + "go-mod.ewintr.nl/planner/plan/storage" _ "modernc.org/sqlite" ) @@ -19,27 +20,43 @@ var ( ErrSqliteFailure = errors.New("sqlite returned an error") ) -func NewSqlites(dbPath string) (*LocalID, *SqliteTask, *SqliteSync, error) { +type Sqlites struct { + db *sql.DB +} + +func (sqs *Sqlites) Begin() (*storage.Tx, error) { + tx, err := sqs.db.Begin() + if err != nil { + return nil, err + } + return storage.NewTx(tx), nil +} + +func (sqs *Sqlites) LocalID(tx *storage.Tx) storage.LocalID { + return &LocalID{tx: tx} +} + +func (sqs *Sqlites) Sync(tx *storage.Tx) storage.Sync { + return &Sync{tx: tx} +} + +func (sqs *Sqlites) Task(tx *storage.Tx) storage.Task { + return &SqliteTask{tx: tx} +} + +func NewSqlites(dbPath string) (*Sqlites, error) { db, err := sql.Open("sqlite", dbPath) if err != nil { - return nil, nil, nil, fmt.Errorf("%w: %v", ErrInvalidConfiguration, err) - } - - sl := &LocalID{ - db: db, - } - se := &SqliteTask{ - db: db, - } - ss := &SqliteSync{ - db: db, + return nil, fmt.Errorf("%w: %v", ErrInvalidConfiguration, err) } if err := migrate(db, migrations); err != nil { - return nil, nil, nil, err + return nil, err } - return sl, se, ss, nil + return &Sqlites{ + db: db, + }, nil } func migrate(db *sql.DB, wanted []string) error { diff --git a/plan/storage/sqlite/sync.go b/plan/storage/sqlite/sync.go index fdf5fba..ba30f5e 100644 --- a/plan/storage/sqlite/sync.go +++ b/plan/storage/sqlite/sync.go @@ -6,18 +6,15 @@ import ( "time" "go-mod.ewintr.nl/planner/item" + "go-mod.ewintr.nl/planner/plan/storage" ) -type SqliteSync struct { - db *sql.DB +type Sync struct { + tx *storage.Tx } -func NewSqliteSync(db *sql.DB) *SqliteSync { - return &SqliteSync{db: db} -} - -func (s *SqliteSync) FindAll() ([]item.Item, error) { - rows, err := s.db.Query("SELECT id, kind, updated, deleted, date, recurrer, recur_next, body FROM items") +func (s *Sync) FindAll() ([]item.Item, error) { + rows, err := s.tx.Query("SELECT id, kind, updated, deleted, date, recurrer, recur_next, body FROM items") if err != nil { return nil, fmt.Errorf("%w: failed to query items: %v", ErrSqliteFailure, err) } @@ -49,7 +46,7 @@ func (s *SqliteSync) FindAll() ([]item.Item, error) { return items, nil } -func (s *SqliteSync) Store(i item.Item) error { +func (s *Sync) Store(i item.Item) error { if i.Updated.IsZero() { i.Updated = time.Now() } @@ -58,7 +55,7 @@ func (s *SqliteSync) Store(i item.Item) error { recurStr = i.Recurrer.String() } - _, err := s.db.Exec( + _, err := s.tx.Exec( `INSERT OR REPLACE INTO items (id, kind, updated, deleted, date, recurrer, recur_next, body) VALUES (?, ?, ?, ?, ?, ?, ?, ?)`, i.ID, @@ -76,24 +73,24 @@ func (s *SqliteSync) Store(i item.Item) error { return nil } -func (s *SqliteSync) DeleteAll() error { - _, err := s.db.Exec("DELETE FROM items") +func (s *Sync) DeleteAll() error { + _, err := s.tx.Exec("DELETE FROM items") if err != nil { return fmt.Errorf("%w: failed to delete all items: %v", ErrSqliteFailure, err) } return nil } -func (s *SqliteSync) SetLastUpdate(ts time.Time) error { - if _, err := s.db.Exec(`UPDATE syncupdate SET timestamp = ?`, ts.Format(time.RFC3339)); err != nil { +func (s *Sync) SetLastUpdate(ts time.Time) error { + if _, err := s.tx.Exec(`UPDATE syncupdate SET timestamp = ?`, ts.Format(time.RFC3339)); err != nil { return fmt.Errorf("%w: could not store timestamp: %v", ErrSqliteFailure, err) } return nil } -func (s *SqliteSync) LastUpdate() (time.Time, error) { +func (s *Sync) LastUpdate() (time.Time, error) { var tsStr string - if err := s.db.QueryRow("SELECT timestamp FROM syncupdate").Scan(&tsStr); err != nil { + if err := s.tx.QueryRow("SELECT timestamp FROM syncupdate").Scan(&tsStr); err != nil { return time.Time{}, fmt.Errorf("%w: failed to get last update: %v", ErrSqliteFailure, err) } ts, err := time.Parse(time.RFC3339, tsStr) diff --git a/plan/storage/sqlite/task.go b/plan/storage/sqlite/task.go index 4fe7dac..88f64d0 100644 --- a/plan/storage/sqlite/task.go +++ b/plan/storage/sqlite/task.go @@ -10,7 +10,7 @@ import ( ) type SqliteTask struct { - db *sql.DB + tx *storage.Tx } func (t *SqliteTask) Store(tsk item.Task) error { @@ -18,7 +18,7 @@ func (t *SqliteTask) Store(tsk item.Task) error { if tsk.Recurrer != nil { recurStr = tsk.Recurrer.String() } - if _, err := t.db.Exec(` + if _, err := t.tx.Exec(` INSERT INTO tasks (id, title, project, date, time, duration, recurrer) VALUES @@ -42,7 +42,7 @@ recurrer=? func (t *SqliteTask) FindOne(id string) (item.Task, error) { var tsk item.Task var dateStr, timeStr, recurStr, durStr string - err := t.db.QueryRow(` + err := t.tx.QueryRow(` SELECT id, title, project, date, time, duration, recurrer FROM tasks WHERE id = ?`, id).Scan(&tsk.ID, &tsk.Title, &tsk.Project, &dateStr, &timeStr, &durStr, &recurStr) @@ -98,7 +98,7 @@ func (t *SqliteTask) FindMany(params storage.TaskListParams) ([]item.Task, error } } - rows, err := t.db.Query(query, args...) + rows, err := t.tx.Query(query, args...) if err != nil { return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err) } @@ -126,7 +126,7 @@ func (t *SqliteTask) FindMany(params storage.TaskListParams) ([]item.Task, error } func (t *SqliteTask) Delete(id string) error { - result, err := t.db.Exec(` + result, err := t.tx.Exec(` DELETE FROM tasks WHERE id = ?`, id) if err != nil { @@ -146,7 +146,7 @@ WHERE id = ?`, id) } func (t *SqliteTask) Projects() (map[string]int, error) { - rows, err := t.db.Query(`SELECT project, count(*) FROM tasks GROUP BY project`) + rows, err := t.tx.Query(`SELECT project, count(*) FROM tasks GROUP BY project`) if err != nil { return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err) } diff --git a/plan/storage/storage.go b/plan/storage/storage.go index 3cb7672..6bbc449 100644 --- a/plan/storage/storage.go +++ b/plan/storage/storage.go @@ -12,6 +12,10 @@ var ( ErrNotFound = errors.New("not found") ) +type Txer interface { + Begin() (*Tx, error) +} + type LocalID interface { FindOne(lid int) (string, error) FindAll() (map[string]int, error) diff --git a/plan/storage/transaction.go b/plan/storage/transaction.go new file mode 100644 index 0000000..d9709f2 --- /dev/null +++ b/plan/storage/transaction.go @@ -0,0 +1,40 @@ +package storage + +import "database/sql" + +// Tx wraps sql.Tx so transactions can be skipped for in-memory repositories +type Tx struct { + tx *sql.Tx +} + +func NewTx(tx *sql.Tx) *Tx { + return &Tx{tx: tx} +} + +func (tx *Tx) Rollback() error { + if tx.tx == nil { + return nil + } + + return tx.tx.Rollback() +} + +func (tx *Tx) Commit() error { + if tx.tx == nil { + return nil + } + + return tx.tx.Commit() +} + +func (tx *Tx) QueryRow(query string, args ...any) *sql.Row { + return tx.tx.QueryRow(query, args...) +} + +func (tx *Tx) Query(query string, args ...any) (*sql.Rows, error) { + return tx.tx.Query(query, args...) +} + +func (tx *Tx) Exec(query string, args ...any) (sql.Result, error) { + return tx.tx.Exec(query, args...) +}