diff --git a/storage/boltdb/boltdb.go b/storage/boltdb/boltdb.go index 82a6330f..45e50f91 100644 --- a/storage/boltdb/boltdb.go +++ b/storage/boltdb/boltdb.go @@ -93,21 +93,16 @@ func (s *BoltStore) SaveUser(user *storage.User) error { func (s *BoltStore) DeleteUser(user *storage.User) error { return s.db.Batch(func(tx *bolt.Tx) error { - b := tx.Bucket(bucketServers) - c := b.Cursor() - - for k, _ := c.Seek(user.IDBytes); bytes.HasPrefix(k, user.IDBytes); k, _ = c.Next() { - b.Delete(k) + err := tx.Bucket(bucketUsers).Delete(user.IDBytes) + if err != nil { + return err } - b = tx.Bucket(bucketChannels) - c = b.Cursor() - - for k, _ := c.Seek(user.IDBytes); bytes.HasPrefix(k, user.IDBytes); k, _ = c.Next() { - b.Delete(k) - } - - return tx.Bucket(bucketUsers).Delete(user.IDBytes) + return deletePrefix(user.IDBytes, + tx.Bucket(bucketServers), + tx.Bucket(bucketChannels), + tx.Bucket(bucketOpenDMs), + ) }) } @@ -161,23 +156,15 @@ func (s *BoltStore) SaveServer(user *storage.User, server *storage.Server) error func (s *BoltStore) RemoveServer(user *storage.User, address string) error { return s.db.Batch(func(tx *bolt.Tx) error { serverID := serverID(user, address) - tx.Bucket(bucketServers).Delete(serverID) - - b := tx.Bucket(bucketChannels) - c := b.Cursor() - - for k, _ := c.Seek(serverID); bytes.HasPrefix(k, serverID); k, _ = c.Next() { - b.Delete(k) + err := tx.Bucket(bucketServers).Delete(serverID) + if err != nil { + return err } - b = tx.Bucket(bucketOpenDMs) - c = b.Cursor() - - for k, _ := c.Seek(serverID); bytes.HasPrefix(k, serverID); k, _ = c.Next() { - b.Delete(k) - } - - return nil + return deletePrefix(serverID, + tx.Bucket(bucketChannels), + tx.Bucket(bucketOpenDMs), + ) }) } @@ -403,6 +390,21 @@ func (s *BoltStore) DeleteSession(key string) error { }) } +func deletePrefix(prefix []byte, buckets ...*bolt.Bucket) error { + for _, b := range buckets { + c := b.Cursor() + + for k, _ := c.Seek(prefix); bytes.HasPrefix(k, prefix); k, _ = c.Next() { + err := b.Delete(k) + if err != nil { + return err + } + } + } + + return nil +} + func serverID(user *storage.User, address string) []byte { id := make([]byte, 8+len(address)) copy(id, user.IDBytes) diff --git a/storage/user_test.go b/storage/user_test.go index 864628d0..8772c554 100644 --- a/storage/user_test.go +++ b/storage/user_test.go @@ -100,10 +100,16 @@ func TestUser(t *testing.T) { assert.Equal(t, settings, user.GetClientSettings()) assert.NotEqual(t, settings, storage.DefaultClientSettings()) + user.AddOpenDM(srv.Host, "cake") + user.Remove() _, err = os.Stat(storage.Path.User(user.Username)) assert.True(t, os.IsNotExist(err)) + openDMs, err = user.GetOpenDMs() + assert.Nil(t, err) + assert.Len(t, openDMs, 0) + users, err = storage.LoadUsers(db) assert.Nil(t, err)