transactions

This commit is contained in:
Erik Winter 2025-01-13 09:13:48 +01:00
parent 428b828c45
commit 7f4d870311
23 changed files with 278 additions and 171 deletions

BIN
dist/plan vendored

Binary file not shown.

View File

@ -8,6 +8,7 @@ import (
"github.com/google/uuid"
"go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/sync/client"
)
type AddArgs struct {
@ -92,16 +93,22 @@ type Add struct {
Args AddArgs
}
func (a Add) Do(deps Dependencies) (CommandResult, error) {
if err := deps.TaskRepo.Store(a.Args.Task); err != nil {
func (a Add) Do(repos Repositories, _ client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
if err := repos.Task(tx).Store(a.Args.Task); err != nil {
return nil, fmt.Errorf("could not store event: %v", err)
}
localID, err := deps.LocalIDRepo.Next()
localID, err := repos.LocalID(tx).Next()
if err != nil {
return nil, fmt.Errorf("could not create next local id: %v", err)
}
if err := deps.LocalIDRepo.Store(a.Args.Task.ID, localID); err != nil {
if err := repos.LocalID(tx).Store(a.Args.Task.ID, localID); err != nil {
return nil, fmt.Errorf("could not store local id: %v", err)
}
@ -109,10 +116,14 @@ func (a Add) Do(deps Dependencies) (CommandResult, error) {
if err != nil {
return nil, fmt.Errorf("could not convert event to sync item: %v", err)
}
if err := deps.SyncRepo.Store(it); err != nil {
if err := repos.Sync(tx).Store(it); err != nil {
return nil, fmt.Errorf("could not store sync item: %v", err)
}
if err := tx.Commit(); err != nil {
return nil, fmt.Errorf("could not add task: %v", err)
}
return AddRender{}, nil
}

View File

@ -60,9 +60,7 @@ func TestAdd(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
// setup
taskRepo := memory.NewTask()
localIDRepo := memory.NewLocalID()
syncRepo := memory.NewSync()
mems := memory.New()
// parse
cmd, actParseErr := command.NewAddArgs().Parse(tc.main, tc.fields)
@ -74,16 +72,12 @@ func TestAdd(t *testing.T) {
}
// do
if _, err := cmd.Do(command.Dependencies{
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
}); err != nil {
if _, err := cmd.Do(mems, nil); err != nil {
t.Errorf("exp nil, got %v", err)
}
// check
actTasks, err := taskRepo.FindMany(storage.TaskListParams{})
actTasks, err := mems.Task(nil).FindMany(storage.TaskListParams{})
if err != nil {
t.Errorf("exp nil, got %v", err)
}
@ -91,7 +85,7 @@ func TestAdd(t *testing.T) {
t.Errorf("exp 1, got %d", len(actTasks))
}
actLocalIDs, err := localIDRepo.FindAll()
actLocalIDs, err := mems.LocalID(nil).FindAll()
if err != nil {
t.Errorf("exp nil, got %v", err)
}
@ -110,7 +104,7 @@ func TestAdd(t *testing.T) {
t.Errorf("(exp -, got +)\n%s", diff)
}
updated, err := syncRepo.FindAll()
updated, err := mems.Sync(nil).FindAll()
if err != nil {
t.Errorf("exp nil, got %v", err)
}

View File

@ -20,11 +20,11 @@ var (
ErrInvalidArg = errors.New("invalid argument")
)
type Dependencies struct {
LocalIDRepo storage.LocalID
TaskRepo storage.Task
SyncRepo storage.Sync
SyncClient client.Client
type Repositories interface {
Begin() (*storage.Tx, error)
LocalID(tx *storage.Tx) storage.LocalID
Sync(tx *storage.Tx) storage.Sync
Task(tx *storage.Tx) storage.Task
}
type CommandArgs interface {
@ -32,7 +32,7 @@ type CommandArgs interface {
}
type Command interface {
Do(deps Dependencies) (CommandResult, error)
Do(repos Repositories, client client.Client) (CommandResult, error)
}
type CommandResult interface {
@ -40,13 +40,15 @@ type CommandResult interface {
}
type CLI struct {
deps Dependencies
repos Repositories
client client.Client
cmdArgs []CommandArgs
}
func NewCLI(deps Dependencies) *CLI {
func NewCLI(repos Repositories, client client.Client) *CLI {
return &CLI{
deps: deps,
repos: repos,
client: client,
cmdArgs: []CommandArgs{
NewShowArgs(), NewProjectsArgs(),
NewAddArgs(), NewDeleteArgs(), NewListArgs(),
@ -66,7 +68,7 @@ func (cli *CLI) Run(args []string) error {
return err
}
result, err := cmd.Do(cli.deps)
result, err := cmd.Do(cli.repos, cli.client)
if err != nil {
return err
}

View File

@ -4,6 +4,8 @@ import (
"fmt"
"slices"
"strconv"
"go-mod.ewintr.nl/planner/sync/client"
)
type DeleteArgs struct {
@ -45,9 +47,15 @@ type Delete struct {
Args DeleteArgs
}
func (del Delete) Do(deps Dependencies) (CommandResult, error) {
func (del Delete) Do(repos Repositories, _ client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
var id string
idMap, err := deps.LocalIDRepo.FindAll()
idMap, err := repos.LocalID(tx).FindAll()
if err != nil {
return nil, fmt.Errorf("could not get local ids: %v", err)
}
@ -60,7 +68,7 @@ func (del Delete) Do(deps Dependencies) (CommandResult, error) {
return nil, fmt.Errorf("could not find local id")
}
tsk, err := deps.TaskRepo.FindOne(id)
tsk, err := repos.Task(tx).FindOne(id)
if err != nil {
return nil, fmt.Errorf("could not get task: %v", err)
}
@ -70,15 +78,15 @@ func (del Delete) Do(deps Dependencies) (CommandResult, error) {
return nil, fmt.Errorf("could not convert task to sync item: %v", err)
}
it.Deleted = true
if err := deps.SyncRepo.Store(it); err != nil {
if err := repos.Sync(tx).Store(it); err != nil {
return nil, fmt.Errorf("could not store sync item: %v", err)
}
if err := deps.LocalIDRepo.Delete(id); err != nil {
if err := repos.LocalID(tx).Delete(id); err != nil {
return nil, fmt.Errorf("could not delete local id: %v", err)
}
if err := deps.TaskRepo.Delete(id); err != nil {
if err := repos.Task(tx).Delete(id); err != nil {
return nil, fmt.Errorf("could not delete task: %v", err)
}

View File

@ -49,13 +49,12 @@ func TestDelete(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
// setup
taskRepo := memory.NewTask()
syncRepo := memory.NewSync()
if err := taskRepo.Store(e); err != nil {
mems := memory.New()
if err := mems.Task(nil).Store(e); err != nil {
t.Errorf("exp nil, got %v", err)
}
localIDRepo := memory.NewLocalID()
if err := localIDRepo.Store(e.ID, 1); err != nil {
if err := mems.LocalID(nil).Store(e.ID, 1); err != nil {
t.Errorf("exp nil, got %v", err)
}
@ -69,11 +68,7 @@ func TestDelete(t *testing.T) {
}
// do
_, actDoErr := cmd.Do(command.Dependencies{
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
})
_, actDoErr := cmd.Do(mems, nil)
if tc.expDoErr != (actDoErr != nil) {
t.Errorf("exp false, got %v", actDoErr)
}
@ -82,18 +77,18 @@ func TestDelete(t *testing.T) {
}
// check
_, repoErr := taskRepo.FindOne(e.ID)
_, repoErr := mems.Task(nil).FindOne(e.ID)
if !errors.Is(repoErr, storage.ErrNotFound) {
t.Errorf("exp %v, got %v", storage.ErrNotFound, repoErr)
}
idMap, idErr := localIDRepo.FindAll()
idMap, idErr := mems.LocalID(nil).FindAll()
if idErr != nil {
t.Errorf("exp nil, got %v", idErr)
}
if len(idMap) != 0 {
t.Errorf("exp 0, got %v", len(idMap))
}
updated, err := syncRepo.FindAll()
updated, err := mems.Sync(nil).FindAll()
if err != nil {
t.Errorf("exp nil, got %v", err)
}

View File

@ -9,6 +9,7 @@ import (
"go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/format"
"go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
)
type ListArgs struct {
@ -96,12 +97,18 @@ type List struct {
Args ListArgs
}
func (list List) Do(deps Dependencies) (CommandResult, error) {
localIDs, err := deps.LocalIDRepo.FindAll()
func (list List) Do(repos Repositories, _ client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
localIDs, err := repos.LocalID(tx).FindAll()
if err != nil {
return nil, fmt.Errorf("could not get local ids: %v", err)
}
all, err := deps.TaskRepo.FindMany(storage.TaskListParams{
all, err := repos.Task(tx).FindMany(storage.TaskListParams{
HasRecurrer: list.Args.HasRecurrer,
From: list.Args.From,
To: list.Args.To,

View File

@ -79,8 +79,8 @@ func TestListParse(t *testing.T) {
func TestList(t *testing.T) {
t.Parallel()
taskRepo := memory.NewTask()
localRepo := memory.NewLocalID()
mems := memory.New()
e := item.Task{
ID: "id",
Date: item.NewDate(2024, 10, 7),
@ -88,10 +88,10 @@ func TestList(t *testing.T) {
Title: "name",
},
}
if err := taskRepo.Store(e); err != nil {
if err := mems.Task(nil).Store(e); err != nil {
t.Errorf("exp nil, got %v", err)
}
if err := localRepo.Store(e.ID, 1); err != nil {
if err := mems.LocalID(nil).Store(e.ID, 1); err != nil {
t.Errorf("exp nil, got %v", err)
}
@ -115,10 +115,7 @@ func TestList(t *testing.T) {
},
} {
t.Run(tc.name, func(t *testing.T) {
res, err := tc.cmd.Do(command.Dependencies{
TaskRepo: taskRepo,
LocalIDRepo: localRepo,
})
res, err := tc.cmd.Do(mems, nil)
if err != nil {
t.Errorf("exp nil, got %v", err)
}

View File

@ -5,6 +5,7 @@ import (
"sort"
"go-mod.ewintr.nl/planner/plan/format"
"go-mod.ewintr.nl/planner/sync/client"
)
type ProjectsArgs struct{}
@ -23,8 +24,14 @@ func (pa ProjectsArgs) Parse(main []string, fields map[string]string) (Command,
type Projects struct{}
func (ps Projects) Do(deps Dependencies) (CommandResult, error) {
projects, err := deps.TaskRepo.Projects()
func (ps Projects) Do(repos Repositories, _ client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
projects, err := repos.Task(tx).Projects()
if err != nil {
return nil, fmt.Errorf("could not find projects: %v", err)
}

View File

@ -8,6 +8,7 @@ import (
"go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/format"
"go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
)
type ShowArgs struct {
@ -38,8 +39,14 @@ type Show struct {
args ShowArgs
}
func (s Show) Do(deps Dependencies) (CommandResult, error) {
id, err := deps.LocalIDRepo.FindOne(s.args.localID)
func (s Show) Do(repos Repositories, _ client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
id, err := repos.LocalID(tx).FindOne(s.args.localID)
switch {
case errors.Is(err, storage.ErrNotFound):
return nil, fmt.Errorf("could not find local id")
@ -47,7 +54,7 @@ func (s Show) Do(deps Dependencies) (CommandResult, error) {
return nil, err
}
tsk, err := deps.TaskRepo.FindOne(id)
tsk, err := repos.Task(tx).FindOne(id)
if err != nil {
return nil, fmt.Errorf("could not find task")
}

View File

@ -12,8 +12,8 @@ import (
func TestShow(t *testing.T) {
t.Parallel()
taskRepo := memory.NewTask()
localRepo := memory.NewLocalID()
mems := memory.New()
tsk := item.Task{
ID: "id",
Date: item.NewDate(2024, 10, 7),
@ -21,10 +21,10 @@ func TestShow(t *testing.T) {
Title: "name",
},
}
if err := taskRepo.Store(tsk); err != nil {
if err := mems.Task(nil).Store(tsk); err != nil {
t.Errorf("exp nil, got %v", err)
}
if err := localRepo.Store(tsk.ID, 1); err != nil {
if err := mems.LocalID(nil).Store(tsk.ID, 1); err != nil {
t.Errorf("exp nil, got %v", err)
}
@ -69,10 +69,7 @@ func TestShow(t *testing.T) {
}
// do
_, actDoErr := cmd.Do(command.Dependencies{
TaskRepo: taskRepo,
LocalIDRepo: localRepo,
})
_, actDoErr := cmd.Do(mems, nil)
if tc.expDoErr != (actDoErr != nil) {
t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr != nil)
}

View File

@ -8,6 +8,7 @@ import (
"go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
)
type SyncArgs struct{}
@ -26,25 +27,31 @@ func (sa SyncArgs) Parse(main []string, flags map[string]string) (Command, error
type Sync struct{}
func (s Sync) Do(deps Dependencies) (CommandResult, error) {
func (s Sync) Do(repos Repositories, client client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
// local new and updated
sendItems, err := deps.SyncRepo.FindAll()
sendItems, err := repos.Sync(tx).FindAll()
if err != nil {
return nil, fmt.Errorf("could not get updated items: %v", err)
}
if err := deps.SyncClient.Update(sendItems); err != nil {
if err := client.Update(sendItems); err != nil {
return nil, fmt.Errorf("could not send updated items: %v", err)
}
if err := deps.SyncRepo.DeleteAll(); err != nil {
if err := repos.Sync(tx).DeleteAll(); err != nil {
return nil, fmt.Errorf("could not clear updated items: %v", err)
}
// get new/updated items
oldTS, err := deps.SyncRepo.LastUpdate()
oldTS, err := repos.Sync(tx).LastUpdate()
if err != nil {
return nil, fmt.Errorf("could not find timestamp of last update: %v", err)
}
recItems, err := deps.SyncClient.Updated([]item.Kind{item.KindTask}, oldTS)
recItems, err := client.Updated([]item.Kind{item.KindTask}, oldTS)
if err != nil {
return nil, fmt.Errorf("could not receive updates: %v", err)
}
@ -56,10 +63,10 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) {
newTS = ri.Updated
}
if ri.Deleted {
if err := deps.LocalIDRepo.Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) {
if err := repos.LocalID(tx).Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) {
return nil, fmt.Errorf("could not delete local id: %v", err)
}
if err := deps.TaskRepo.Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) {
if err := repos.Task(tx).Delete(ri.ID); err != nil && !errors.Is(err, storage.ErrNotFound) {
return nil, fmt.Errorf("could not delete task: %v", err)
}
continue
@ -67,7 +74,7 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) {
updated = append(updated, ri)
}
lidMap, err := deps.LocalIDRepo.FindAll()
lidMap, err := repos.LocalID(tx).FindAll()
if err != nil {
return nil, fmt.Errorf("could not get local ids: %v", err)
}
@ -83,23 +90,23 @@ func (s Sync) Do(deps Dependencies) (CommandResult, error) {
RecurNext: u.RecurNext,
TaskBody: tskBody,
}
if err := deps.TaskRepo.Store(tsk); err != nil {
if err := repos.Task(tx).Store(tsk); err != nil {
return nil, fmt.Errorf("could not store task: %v", err)
}
lid, ok := lidMap[u.ID]
if !ok {
lid, err = deps.LocalIDRepo.Next()
lid, err = repos.LocalID(tx).Next()
if err != nil {
return nil, fmt.Errorf("could not get next local id: %v", err)
}
if err := deps.LocalIDRepo.Store(u.ID, lid); err != nil {
if err := repos.LocalID(tx).Store(u.ID, lid); err != nil {
return nil, fmt.Errorf("could not store local id: %v", err)
}
}
}
if err := deps.SyncRepo.SetLastUpdate(newTS); err != nil {
if err := repos.Sync(tx).SetLastUpdate(newTS); err != nil {
return nil, fmt.Errorf("could not store update timestamp: %v", err)
}

View File

@ -47,9 +47,7 @@ func TestSyncSend(t *testing.T) {
t.Parallel()
syncClient := client.NewMemory()
syncRepo := memory.NewSync()
localIDRepo := memory.NewLocalID()
taskRepo := memory.NewTask()
mems := memory.New()
it := item.Item{
ID: "a",
@ -60,7 +58,7 @@ func TestSyncSend(t *testing.T) {
"duration":"1h"
}`,
}
if err := syncRepo.Store(it); err != nil {
if err := mems.Sync(nil).Store(it); err != nil {
t.Errorf("exp nil, got %v", err)
}
@ -81,12 +79,7 @@ func TestSyncSend(t *testing.T) {
if err != nil {
t.Errorf("exp nil, got %v", err)
}
if _, err := cmd.Do(command.Dependencies{
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
SyncClient: syncClient,
}); err != nil {
if _, err := cmd.Do(mems, syncClient); err != nil {
t.Errorf("exp nil, got %v", err)
}
actItems, actErr := syncClient.Updated(tc.ks, tc.ts)
@ -97,7 +90,7 @@ func TestSyncSend(t *testing.T) {
t.Errorf("(exp +, got -)\n%s", diff)
}
actLeft, actErr := syncRepo.FindAll()
actLeft, actErr := mems.Sync(nil).FindAll()
if actErr != nil {
t.Errorf("exp nil, got %v", actErr)
}
@ -185,16 +178,14 @@ func TestSyncReceive(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
syncClient := client.NewMemory()
syncRepo := memory.NewSync()
localIDRepo := memory.NewLocalID()
taskRepo := memory.NewTask()
mems := memory.New()
// setup
for i, p := range tc.present {
if err := taskRepo.Store(p); err != nil {
if err := mems.Task(nil).Store(p); err != nil {
t.Errorf("exp nil, got %v", err)
}
if err := localIDRepo.Store(p.ID, i+1); err != nil {
if err := mems.LocalID(nil).Store(p.ID, i+1); err != nil {
t.Errorf("exp nil, got %v", err)
}
}
@ -207,24 +198,19 @@ func TestSyncReceive(t *testing.T) {
if err != nil {
t.Errorf("exp nil, got %v", err)
}
if _, err := cmd.Do(command.Dependencies{
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
SyncClient: syncClient,
}); err != nil {
if _, err := cmd.Do(mems, syncClient); err != nil {
t.Errorf("exp nil, got %v", err)
}
// check result
actTasks, err := taskRepo.FindMany(storage.TaskListParams{})
actTasks, err := mems.Task(nil).FindMany(storage.TaskListParams{})
if err != nil {
t.Errorf("exp nil, got %v", err)
}
if diff := item.TaskDiffs(tc.expTask, actTasks); diff != "" {
t.Errorf("(exp +, got -)\n%s", diff)
}
actLocalIDs, err := localIDRepo.FindAll()
actLocalIDs, err := mems.LocalID(nil).FindAll()
if err != nil {
t.Errorf("exp nil, got %v", err)
}

View File

@ -10,6 +10,7 @@ import (
"go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/storage"
"go-mod.ewintr.nl/planner/sync/client"
)
type UpdateArgs struct {
@ -116,8 +117,14 @@ type Update struct {
args UpdateArgs
}
func (u Update) Do(deps Dependencies) (CommandResult, error) {
id, err := deps.LocalIDRepo.FindOne(u.args.LocalID)
func (u Update) Do(repos Repositories, _ client.Client) (CommandResult, error) {
tx, err := repos.Begin()
if err != nil {
return nil, fmt.Errorf("could not start transaction: %v", err)
}
defer tx.Rollback()
id, err := repos.LocalID(tx).FindOne(u.args.LocalID)
switch {
case errors.Is(err, storage.ErrNotFound):
return nil, fmt.Errorf("could not find local id")
@ -125,7 +132,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) {
return nil, err
}
tsk, err := deps.TaskRepo.FindOne(id)
tsk, err := repos.Task(tx).FindOne(id)
if err != nil {
return nil, fmt.Errorf("could not find task")
}
@ -154,7 +161,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) {
return nil, fmt.Errorf("task is unvalid")
}
if err := deps.TaskRepo.Store(tsk); err != nil {
if err := repos.Task(tx).Store(tsk); err != nil {
return nil, fmt.Errorf("could not store task: %v", err)
}
@ -162,7 +169,7 @@ func (u Update) Do(deps Dependencies) (CommandResult, error) {
if err != nil {
return nil, fmt.Errorf("could not convert task to sync item: %v", err)
}
if err := deps.SyncRepo.Store(it); err != nil {
if err := repos.Sync(tx).Store(it); err != nil {
return nil, fmt.Errorf("could not store sync item: %v", err)
}

View File

@ -189,10 +189,8 @@ func TestUpdateExecute(t *testing.T) {
} {
t.Run(tc.name, func(t *testing.T) {
// setup
taskRepo := memory.NewTask()
localIDRepo := memory.NewLocalID()
syncRepo := memory.NewSync()
if err := taskRepo.Store(item.Task{
mems := memory.New()
if err := mems.Task(nil).Store(item.Task{
ID: tskID,
Date: aDate,
TaskBody: item.TaskBody{
@ -204,7 +202,7 @@ func TestUpdateExecute(t *testing.T) {
}); err != nil {
t.Errorf("exp nil, got %v", err)
}
if err := localIDRepo.Store(tskID, lid); err != nil {
if err := mems.LocalID(nil).Store(tskID, lid); err != nil {
t.Errorf("exp nil, ,got %v", err)
}
@ -218,11 +216,7 @@ func TestUpdateExecute(t *testing.T) {
}
// do
_, actDoErr := cmd.Do(command.Dependencies{
TaskRepo: taskRepo,
LocalIDRepo: localIDRepo,
SyncRepo: syncRepo,
})
_, actDoErr := cmd.Do(mems, nil)
if tc.expDoErr != (actDoErr != nil) {
t.Errorf("exp %v, got %v", tc.expDoErr, actDoErr)
}
@ -231,14 +225,14 @@ func TestUpdateExecute(t *testing.T) {
}
// check
actTask, err := taskRepo.FindOne(tskID)
actTask, err := mems.Task(nil).FindOne(tskID)
if err != nil {
t.Errorf("exp nil, got %v", err)
}
if diff := item.TaskDiff(tc.expTask, actTask); diff != "" {
t.Errorf("(exp -, got +)\n%s", diff)
}
updated, err := syncRepo.FindAll()
updated, err := mems.Sync(nil).FindAll()
if err != nil {
t.Errorf("exp nil, got %v", err)
}

View File

@ -27,7 +27,7 @@ func main() {
os.Exit(1)
}
localIDRepo, taskRepo, syncRepo, err := sqlite.NewSqlites(conf.DBPath)
repos, err := sqlite.NewSqlites(conf.DBPath)
if err != nil {
fmt.Printf("could not open db file: %s\n", err)
os.Exit(1)
@ -35,12 +35,7 @@ func main() {
syncClient := client.New(conf.SyncURL, conf.ApiKey)
cli := command.NewCLI(command.Dependencies{
LocalIDRepo: localIDRepo,
TaskRepo: taskRepo,
SyncRepo: syncRepo,
SyncClient: syncClient,
})
cli := command.NewCLI(repos, syncClient)
if err := cli.Run(os.Args[1:]); err != nil {
fmt.Println(err)
os.Exit(1)

View File

@ -0,0 +1,35 @@
package memory
import (
"go-mod.ewintr.nl/planner/plan/storage"
)
type Memories struct {
localID *LocalID
sync *Sync
task *Task
}
func New() *Memories {
return &Memories{
localID: NewLocalID(),
sync: NewSync(),
task: NewTask(),
}
}
func (mems *Memories) Begin() (*storage.Tx, error) {
return &storage.Tx{}, nil
}
func (mems *Memories) LocalID(_ *storage.Tx) storage.LocalID {
return mems.localID
}
func (mems *Memories) Sync(_ *storage.Tx) storage.Sync {
return mems.sync
}
func (mems *Memories) Task(_ *storage.Tx) storage.Task {
return mems.task
}

View File

@ -9,12 +9,12 @@ import (
)
type LocalID struct {
db *sql.DB
tx *storage.Tx
}
func (l *LocalID) FindOne(lid int) (string, error) {
var id string
err := l.db.QueryRow(`
err := l.tx.QueryRow(`
SELECT id
FROM localids
WHERE local_id = ?
@ -30,7 +30,7 @@ WHERE local_id = ?
}
func (l *LocalID) FindAll() (map[string]int, error) {
rows, err := l.db.Query(`
rows, err := l.tx.Query(`
SELECT id, local_id
FROM localids
`)
@ -69,7 +69,7 @@ func (l *LocalID) Next() (int, error) {
}
func (l *LocalID) Store(id string, localID int) error {
if _, err := l.db.Exec(`
if _, err := l.tx.Exec(`
INSERT INTO localids
(id, local_id)
VALUES
@ -81,7 +81,7 @@ VALUES
}
func (l *LocalID) Delete(id string) error {
result, err := l.db.Exec(`
result, err := l.tx.Exec(`
DELETE FROM localids
WHERE id = ?`, id)
if err != nil {

View File

@ -5,6 +5,7 @@ import (
"errors"
"fmt"
"go-mod.ewintr.nl/planner/plan/storage"
_ "modernc.org/sqlite"
)
@ -19,27 +20,43 @@ var (
ErrSqliteFailure = errors.New("sqlite returned an error")
)
func NewSqlites(dbPath string) (*LocalID, *SqliteTask, *SqliteSync, error) {
type Sqlites struct {
db *sql.DB
}
func (sqs *Sqlites) Begin() (*storage.Tx, error) {
tx, err := sqs.db.Begin()
if err != nil {
return nil, err
}
return storage.NewTx(tx), nil
}
func (sqs *Sqlites) LocalID(tx *storage.Tx) storage.LocalID {
return &LocalID{tx: tx}
}
func (sqs *Sqlites) Sync(tx *storage.Tx) storage.Sync {
return &Sync{tx: tx}
}
func (sqs *Sqlites) Task(tx *storage.Tx) storage.Task {
return &SqliteTask{tx: tx}
}
func NewSqlites(dbPath string) (*Sqlites, error) {
db, err := sql.Open("sqlite", dbPath)
if err != nil {
return nil, nil, nil, fmt.Errorf("%w: %v", ErrInvalidConfiguration, err)
}
sl := &LocalID{
db: db,
}
se := &SqliteTask{
db: db,
}
ss := &SqliteSync{
db: db,
return nil, fmt.Errorf("%w: %v", ErrInvalidConfiguration, err)
}
if err := migrate(db, migrations); err != nil {
return nil, nil, nil, err
return nil, err
}
return sl, se, ss, nil
return &Sqlites{
db: db,
}, nil
}
func migrate(db *sql.DB, wanted []string) error {

View File

@ -6,18 +6,15 @@ import (
"time"
"go-mod.ewintr.nl/planner/item"
"go-mod.ewintr.nl/planner/plan/storage"
)
type SqliteSync struct {
db *sql.DB
type Sync struct {
tx *storage.Tx
}
func NewSqliteSync(db *sql.DB) *SqliteSync {
return &SqliteSync{db: db}
}
func (s *SqliteSync) FindAll() ([]item.Item, error) {
rows, err := s.db.Query("SELECT id, kind, updated, deleted, date, recurrer, recur_next, body FROM items")
func (s *Sync) FindAll() ([]item.Item, error) {
rows, err := s.tx.Query("SELECT id, kind, updated, deleted, date, recurrer, recur_next, body FROM items")
if err != nil {
return nil, fmt.Errorf("%w: failed to query items: %v", ErrSqliteFailure, err)
}
@ -49,7 +46,7 @@ func (s *SqliteSync) FindAll() ([]item.Item, error) {
return items, nil
}
func (s *SqliteSync) Store(i item.Item) error {
func (s *Sync) Store(i item.Item) error {
if i.Updated.IsZero() {
i.Updated = time.Now()
}
@ -58,7 +55,7 @@ func (s *SqliteSync) Store(i item.Item) error {
recurStr = i.Recurrer.String()
}
_, err := s.db.Exec(
_, err := s.tx.Exec(
`INSERT OR REPLACE INTO items (id, kind, updated, deleted, date, recurrer, recur_next, body)
VALUES (?, ?, ?, ?, ?, ?, ?, ?)`,
i.ID,
@ -76,24 +73,24 @@ func (s *SqliteSync) Store(i item.Item) error {
return nil
}
func (s *SqliteSync) DeleteAll() error {
_, err := s.db.Exec("DELETE FROM items")
func (s *Sync) DeleteAll() error {
_, err := s.tx.Exec("DELETE FROM items")
if err != nil {
return fmt.Errorf("%w: failed to delete all items: %v", ErrSqliteFailure, err)
}
return nil
}
func (s *SqliteSync) SetLastUpdate(ts time.Time) error {
if _, err := s.db.Exec(`UPDATE syncupdate SET timestamp = ?`, ts.Format(time.RFC3339)); err != nil {
func (s *Sync) SetLastUpdate(ts time.Time) error {
if _, err := s.tx.Exec(`UPDATE syncupdate SET timestamp = ?`, ts.Format(time.RFC3339)); err != nil {
return fmt.Errorf("%w: could not store timestamp: %v", ErrSqliteFailure, err)
}
return nil
}
func (s *SqliteSync) LastUpdate() (time.Time, error) {
func (s *Sync) LastUpdate() (time.Time, error) {
var tsStr string
if err := s.db.QueryRow("SELECT timestamp FROM syncupdate").Scan(&tsStr); err != nil {
if err := s.tx.QueryRow("SELECT timestamp FROM syncupdate").Scan(&tsStr); err != nil {
return time.Time{}, fmt.Errorf("%w: failed to get last update: %v", ErrSqliteFailure, err)
}
ts, err := time.Parse(time.RFC3339, tsStr)

View File

@ -10,7 +10,7 @@ import (
)
type SqliteTask struct {
db *sql.DB
tx *storage.Tx
}
func (t *SqliteTask) Store(tsk item.Task) error {
@ -18,7 +18,7 @@ func (t *SqliteTask) Store(tsk item.Task) error {
if tsk.Recurrer != nil {
recurStr = tsk.Recurrer.String()
}
if _, err := t.db.Exec(`
if _, err := t.tx.Exec(`
INSERT INTO tasks
(id, title, project, date, time, duration, recurrer)
VALUES
@ -42,7 +42,7 @@ recurrer=?
func (t *SqliteTask) FindOne(id string) (item.Task, error) {
var tsk item.Task
var dateStr, timeStr, recurStr, durStr string
err := t.db.QueryRow(`
err := t.tx.QueryRow(`
SELECT id, title, project, date, time, duration, recurrer
FROM tasks
WHERE id = ?`, id).Scan(&tsk.ID, &tsk.Title, &tsk.Project, &dateStr, &timeStr, &durStr, &recurStr)
@ -98,7 +98,7 @@ func (t *SqliteTask) FindMany(params storage.TaskListParams) ([]item.Task, error
}
}
rows, err := t.db.Query(query, args...)
rows, err := t.tx.Query(query, args...)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err)
}
@ -126,7 +126,7 @@ func (t *SqliteTask) FindMany(params storage.TaskListParams) ([]item.Task, error
}
func (t *SqliteTask) Delete(id string) error {
result, err := t.db.Exec(`
result, err := t.tx.Exec(`
DELETE FROM tasks
WHERE id = ?`, id)
if err != nil {
@ -146,7 +146,7 @@ WHERE id = ?`, id)
}
func (t *SqliteTask) Projects() (map[string]int, error) {
rows, err := t.db.Query(`SELECT project, count(*) FROM tasks GROUP BY project`)
rows, err := t.tx.Query(`SELECT project, count(*) FROM tasks GROUP BY project`)
if err != nil {
return nil, fmt.Errorf("%w: %v", ErrSqliteFailure, err)
}

View File

@ -12,6 +12,10 @@ var (
ErrNotFound = errors.New("not found")
)
type Txer interface {
Begin() (*Tx, error)
}
type LocalID interface {
FindOne(lid int) (string, error)
FindAll() (map[string]int, error)

View File

@ -0,0 +1,40 @@
package storage
import "database/sql"
// Tx wraps sql.Tx so transactions can be skipped for in-memory repositories
type Tx struct {
tx *sql.Tx
}
func NewTx(tx *sql.Tx) *Tx {
return &Tx{tx: tx}
}
func (tx *Tx) Rollback() error {
if tx.tx == nil {
return nil
}
return tx.tx.Rollback()
}
func (tx *Tx) Commit() error {
if tx.tx == nil {
return nil
}
return tx.tx.Commit()
}
func (tx *Tx) QueryRow(query string, args ...any) *sql.Row {
return tx.tx.QueryRow(query, args...)
}
func (tx *Tx) Query(query string, args ...any) (*sql.Rows, error) {
return tx.tx.Query(query, args...)
}
func (tx *Tx) Exec(query string, args ...any) (sql.Result, error) {
return tx.tx.Exec(query, args...)
}