diff --git a/plan/storage/memory/localid.go b/plan/storage/memory/localid.go index 40024b1..838e80a 100644 --- a/plan/storage/memory/localid.go +++ b/plan/storage/memory/localid.go @@ -1,6 +1,7 @@ package memory import ( + "errors" "sync" "go-mod.ewintr.nl/planner/plan/storage" @@ -24,6 +25,33 @@ func (ml *LocalID) FindAll() (map[string]int, error) { return ml.ids, nil } +func (ml *LocalID) Find(id string) (int, error) { + ml.mutex.RLock() + defer ml.mutex.RUnlock() + + lid, ok := ml.ids[id] + if !ok { + return 0, storage.ErrNotFound + } + + return lid, nil +} + +func (ml *LocalID) FindOrNext(id string) (int, error) { + ml.mutex.Lock() + defer ml.mutex.Unlock() + + lid, err := ml.Find(id) + switch { + case errors.Is(err, storage.ErrNotFound): + return ml.Next() + case err != nil: + return 0, err + default: + return lid, nil + } +} + func (ml *LocalID) Next() (int, error) { ml.mutex.RLock() defer ml.mutex.RUnlock() diff --git a/plan/storage/memory/localid_test.go b/plan/storage/memory/localid_test.go index 02d060e..f750e30 100644 --- a/plan/storage/memory/localid_test.go +++ b/plan/storage/memory/localid_test.go @@ -37,6 +37,15 @@ func TestLocalID(t *testing.T) { t.Errorf("exp nil, got %v", actErr) } + t.Log("retrieve known") + actLid, actErr := repo.FindOrNext("test") + if actErr != nil { + t.Errorf("exp nil, got %v", actErr) + } + if actLid != 1 { + t.Errorf("exp 1, git %v", actLid) + } + actIDs, actErr = repo.FindAll() if actErr != nil { t.Errorf("exp nil, got %v", actErr) diff --git a/plan/storage/storage.go b/plan/storage/storage.go index c7bfdd9..cdb9907 100644 --- a/plan/storage/storage.go +++ b/plan/storage/storage.go @@ -13,6 +13,7 @@ var ( type LocalID interface { FindAll() (map[string]int, error) + FindOrNext(id string) (int, error) Next() (int, error) Store(id string, localID int) error Delete(id string) error