diff --git a/Godeps/Godeps.json b/Godeps/Godeps.json index 2cde85a2..f31b0f22 100644 --- a/Godeps/Godeps.json +++ b/Godeps/Godeps.json @@ -29,8 +29,8 @@ }, { "ImportPath": "github.com/boltdb/bolt", - "Comment": "v1.0-43-gcf33c9e", - "Rev": "cf33c9e0ca0a23509b8bb8edfc63e4776bb1a330" + "Comment": "v1.1.0-61-g6465994", + "Rev": "6465994716bf6400605746e79224cf1e7ed68725" }, { "ImportPath": "github.com/coreos/go-etcd/etcd", @@ -52,7 +52,7 @@ }, { "ImportPath": "github.com/gorilla/websocket", - "Rev": "ecff5aabe41f13b4cdf897e3c0c9bbccbe552a29" + "Rev": "3986be78bf859e01f01af631ad76da5b269d270c" }, { "ImportPath": "github.com/inconshreveable/mousetrap", @@ -82,7 +82,7 @@ }, { "ImportPath": "github.com/mitchellh/go-homedir", - "Rev": "1f6da4a72e57d4e7edd4a7295a585e0a3999a2d4" + "Rev": "d682a8f0cf139663a984ff12528da460ca963de9" }, { "ImportPath": "github.com/mitchellh/mapstructure", @@ -94,11 +94,11 @@ }, { "ImportPath": "github.com/spf13/cast", - "Rev": "4d07383ffe94b5e5a6fa3af9211374a4507a0184" + "Rev": "ee7b3e0353166ab1f3a605294ac8cd2b77953778" }, { "ImportPath": "github.com/spf13/cobra", - "Rev": "3ee9552eebbb5db27cb81abcae66c6f1430cad29" + "Rev": "9c9300901990faada0c5fb3b5730f452585c7c2b" }, { "ImportPath": "github.com/spf13/jwalterweatherman", @@ -106,11 +106,11 @@ }, { "ImportPath": "github.com/spf13/pflag", - "Rev": "32bfad653e29e893e4ed3812fdc0294a05126c08" + "Rev": "7f60f83a2c81bc3c3c0d5297f61ddfa68da9d3b7" }, { "ImportPath": "github.com/spf13/viper", - "Rev": "d62d4bb4c68a773c3b5f4e72844913a2d5de0de0" + "Rev": "a212099cbe6fbe8d07476bfda8d2d39b6ff8f325" }, { "ImportPath": "github.com/square/go-jose", @@ -140,8 +140,8 @@ }, { "ImportPath": "github.com/xenolf/lego/acme", - "Comment": "v0.1.1-34-g12b5de7", - "Rev": "12b5de7e8cb451949aabad64cce93e4b846e2aa7" + "Comment": "v0.2.0-6-gdb3a956", + "Rev": "db3a956d52bf23cc5201fe98bc9c9787d3b32c2d" }, { "ImportPath": "github.com/xordataexchange/crypt/backend", diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/.gitignore b/Godeps/_workspace/src/github.com/boltdb/bolt/.gitignore index b2bb382b..c7bd2b7a 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/.gitignore +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/.gitignore @@ -1,3 +1,4 @@ *.prof *.test +*.swp /bin/ diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/Makefile b/Godeps/_workspace/src/github.com/boltdb/bolt/Makefile index cfbed514..e035e63a 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/Makefile +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/Makefile @@ -1,54 +1,18 @@ -TEST=. -BENCH=. -COVERPROFILE=/tmp/c.out BRANCH=`git rev-parse --abbrev-ref HEAD` COMMIT=`git rev-parse --short HEAD` GOLDFLAGS="-X main.branch $(BRANCH) -X main.commit $(COMMIT)" default: build -bench: - go test -v -test.run=NOTHINCONTAINSTHIS -test.bench=$(BENCH) - -# http://cloc.sourceforge.net/ -cloc: - @cloc --not-match-f='Makefile|_test.go' . - -cover: fmt - go test -coverprofile=$(COVERPROFILE) -test.run=$(TEST) $(COVERFLAG) . - go tool cover -html=$(COVERPROFILE) - rm $(COVERPROFILE) - -cpuprofile: fmt - @go test -c - @./bolt.test -test.v -test.run=$(TEST) -test.cpuprofile cpu.prof +race: + @go test -v -race -test.run="TestSimulate_(100op|1000op)" # go get github.com/kisielk/errcheck errcheck: - @echo "=== errcheck ===" - @errcheck github.com/boltdb/bolt + @errcheck -ignorepkg=bytes -ignore=os:Remove github.com/boltdb/bolt -fmt: - @go fmt ./... +test: + @go test -v -cover . + @go test -v ./cmd/bolt -get: - @go get -d ./... - -build: get - @mkdir -p bin - @go build -ldflags=$(GOLDFLAGS) -a -o bin/bolt ./cmd/bolt - -test: fmt - @go get github.com/stretchr/testify/assert - @echo "=== TESTS ===" - @go test -v -cover -test.run=$(TEST) - @echo "" - @echo "" - @echo "=== CLI ===" - @go test -v -test.run=$(TEST) ./cmd/bolt - @echo "" - @echo "" - @echo "=== RACE DETECTOR ===" - @go test -v -race -test.run="TestSimulate_(100op|1000op)" - -.PHONY: bench cloc cover cpuprofile fmt memprofile test +.PHONY: fmt test diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/README.md b/Godeps/_workspace/src/github.com/boltdb/bolt/README.md index 02a85b0a..82e85742 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/README.md +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/README.md @@ -1,8 +1,8 @@ -Bolt [![Build Status](https://drone.io/github.com/boltdb/bolt/status.png)](https://drone.io/github.com/boltdb/bolt/latest) [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.png?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.png)](https://godoc.org/github.com/boltdb/bolt) ![Version](http://img.shields.io/badge/version-1.0-green.png) +Bolt [![Build Status](https://drone.io/github.com/boltdb/bolt/status.png)](https://drone.io/github.com/boltdb/bolt/latest) [![Coverage Status](https://coveralls.io/repos/boltdb/bolt/badge.svg?branch=master)](https://coveralls.io/r/boltdb/bolt?branch=master) [![GoDoc](https://godoc.org/github.com/boltdb/bolt?status.svg)](https://godoc.org/github.com/boltdb/bolt) ![Version](https://img.shields.io/badge/version-1.0-green.svg) ==== -Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas] and -the [LMDB project][lmdb]. The goal of the project is to provide a simple, +Bolt is a pure Go key/value store inspired by [Howard Chu's][hyc_symas] +[LMDB project][lmdb]. The goal of the project is to provide a simple, fast, and reliable database for projects that don't require a full database server such as Postgres or MySQL. @@ -13,7 +13,6 @@ and setting values. That's it. [hyc_symas]: https://twitter.com/hyc_symas [lmdb]: http://symas.com/mdb/ - ## Project Status Bolt is stable and the API is fixed. Full unit test coverage and randomized @@ -22,6 +21,36 @@ Bolt is currently in high-load production environments serving databases as large as 1TB. Many companies such as Shopify and Heroku use Bolt-backed services every day. +## Table of Contents + +- [Getting Started](#getting-started) + - [Installing](#installing) + - [Opening a database](#opening-a-database) + - [Transactions](#transactions) + - [Read-write transactions](#read-write-transactions) + - [Read-only transactions](#read-only-transactions) + - [Batch read-write transactions](#batch-read-write-transactions) + - [Managing transactions manually](#managing-transactions-manually) + - [Using buckets](#using-buckets) + - [Using key/value pairs](#using-keyvalue-pairs) + - [Autoincrementing integer for the bucket](#autoincrementing-integer-for-the-bucket) + - [Iterating over keys](#iterating-over-keys) + - [Prefix scans](#prefix-scans) + - [Range scans](#range-scans) + - [ForEach()](#foreach) + - [Nested buckets](#nested-buckets) + - [Database backups](#database-backups) + - [Statistics](#statistics) + - [Read-Only Mode](#read-only-mode) + - [Mobile Use (iOS/Android)](#mobile-use-iosandroid) +- [Resources](#resources) +- [Comparison with other databases](#comparison-with-other-databases) + - [Postgres, MySQL, & other relational databases](#postgres-mysql--other-relational-databases) + - [LevelDB, RocksDB](#leveldb-rocksdb) + - [LMDB](#lmdb) +- [Caveats & Limitations](#caveats--limitations) +- [Reading the Source](#reading-the-source) +- [Other Projects Using Bolt](#other-projects-using-bolt) ## Getting Started @@ -87,6 +116,11 @@ are not thread safe. To work with data in multiple goroutines you must start a transaction for each one or use locking to ensure only one goroutine accesses a transaction at a time. Creating transaction from the `DB` is thread safe. +Read-only transactions and read-write transactions should not depend on one +another and generally shouldn't be opened simultaneously in the same goroutine. +This can cause a deadlock as the read-write transaction needs to periodically +re-map the data file but it cannot do so while a read-only transaction is open. + #### Read-write transactions @@ -125,6 +159,48 @@ no mutating operations are allowed within a read-only transaction. You can only retrieve buckets, retrieve values, and copy the database within a read-only transaction. + +#### Batch read-write transactions + +Each `DB.Update()` waits for disk to commit the writes. This overhead +can be minimized by combining multiple updates with the `DB.Batch()` +function: + +```go +err := db.Batch(func(tx *bolt.Tx) error { + ... + return nil +}) +``` + +Concurrent Batch calls are opportunistically combined into larger +transactions. Batch is only useful when there are multiple goroutines +calling it. + +The trade-off is that `Batch` can call the given +function multiple times, if parts of the transaction fail. The +function must be idempotent and side effects must take effect only +after a successful return from `DB.Batch()`. + +For example: don't display messages from inside the function, instead +set variables in the enclosing scope: + +```go +var id uint64 +err := db.Batch(func(tx *bolt.Tx) error { + // Find last key in bucket, decode as bigendian uint64, increment + // by one, encode back to []byte, and add new key. + ... + id = newValue + return nil +}) +if err != nil { + return ... +} +fmt.Println("Allocated ID %d", id) +``` + + #### Managing transactions manually The `DB.View()` and `DB.Update()` functions are wrappers around the `DB.Begin()` @@ -133,8 +209,8 @@ and then safely close your transaction if an error is returned. This is the recommended way to use Bolt transactions. However, sometimes you may want to manually start and end your transactions. -You can use the `Tx.Begin()` function directly but _please_ be sure to close the -transaction. +You can use the `Tx.Begin()` function directly but **please** be sure to close +the transaction. ```go // Start a writable transaction. @@ -209,13 +285,60 @@ db.View(func(tx *bolt.Tx) error { ``` The `Get()` function does not return an error because its operation is -guarenteed to work (unless there is some kind of system failure). If the key +guaranteed to work (unless there is some kind of system failure). If the key exists then it will return its byte slice value. If it doesn't exist then it will return `nil`. It's important to note that you can have a zero-length value set to a key which is different than the key not existing. Use the `Bucket.Delete()` function to delete a key from the bucket. +Please note that values returned from `Get()` are only valid while the +transaction is open. If you need to use a value outside of the transaction +then you must use `copy()` to copy it to another byte slice. + + +### Autoincrementing integer for the bucket +By using the `NextSequence()` function, you can let Bolt determine a sequence +which can be used as the unique identifier for your key/value pairs. See the +example below. + +```go +// CreateUser saves u to the store. The new user ID is set on u once the data is persisted. +func (s *Store) CreateUser(u *User) error { + return s.db.Update(func(tx *bolt.Tx) error { + // Retrieve the users bucket. + // This should be created when the DB is first opened. + b := tx.Bucket([]byte("users")) + + // Generate ID for the user. + // This returns an error only if the Tx is closed or not writeable. + // That can't happen in an Update() call so I ignore the error check. + id, _ = b.NextSequence() + u.ID = int(id) + + // Marshal user data into bytes. + buf, err := json.Marshal(u) + if err != nil { + return err + } + + // Persist bytes to users bucket. + return b.Put(itob(u.ID), buf) + }) +} + +// itob returns an 8-byte big endian representation of v. +func itob(v int) []byte { + b := make([]byte, 8) + binary.BigEndian.PutUint64(b, uint64(v)) + return b +} + +type User struct { + ID int + ... +} +``` ### Iterating over keys @@ -225,7 +348,9 @@ iteration over these keys extremely fast. To iterate over keys we'll use a ```go db.View(func(tx *bolt.Tx) error { + // Assume bucket exists and has keys b := tx.Bucket([]byte("MyBucket")) + c := b.Cursor() for k, v := c.First(); k != nil; k, v = c.Next() { @@ -249,10 +374,15 @@ Next() Move to the next key. Prev() Move to the previous key. ``` -When you have iterated to the end of the cursor then `Next()` will return `nil`. -You must seek to a position using `First()`, `Last()`, or `Seek()` before -calling `Next()` or `Prev()`. If you do not seek to a position then these -functions will return `nil`. +Each of those functions has a return signature of `(key []byte, value []byte)`. +When you have iterated to the end of the cursor then `Next()` will return a +`nil` key. You must seek to a position using `First()`, `Last()`, or `Seek()` +before calling `Next()` or `Prev()`. If you do not seek to a position then +these functions will return a `nil` key. + +During iteration, if the key is non-`nil` but the value is `nil`, that means +the key refers to a bucket rather than a value. Use `Bucket.Bucket()` to +access the sub-bucket. #### Prefix scans @@ -261,6 +391,7 @@ To iterate over a key prefix, you can combine `Seek()` and `bytes.HasPrefix()`: ```go db.View(func(tx *bolt.Tx) error { + // Assume bucket exists and has keys c := tx.Bucket([]byte("MyBucket")).Cursor() prefix := []byte("1234") @@ -280,7 +411,7 @@ date range like this: ```go db.View(func(tx *bolt.Tx) error { - // Assume our events bucket has RFC3339 encoded time keys. + // Assume our events bucket exists and has RFC3339 encoded time keys. c := tx.Bucket([]byte("Events")).Cursor() // Our time range spans the 90's decade. @@ -304,7 +435,9 @@ all the keys in a bucket: ```go db.View(func(tx *bolt.Tx) error { + // Assume bucket exists and has keys b := tx.Bucket([]byte("MyBucket")) + b.ForEach(func(k, v []byte) error { fmt.Printf("key=%s, value=%s\n", k, v) return nil @@ -328,22 +461,26 @@ func (*Bucket) DeleteBucket(key []byte) error ### Database backups -Bolt is a single file so it's easy to backup. You can use the `Tx.Copy()` +Bolt is a single file so it's easy to backup. You can use the `Tx.WriteTo()` function to write a consistent view of the database to a writer. If you call this from a read-only transaction, it will perform a hot backup and not block -your other database reads and writes. It will also use `O_DIRECT` when available -to prevent page cache trashing. +your other database reads and writes. + +By default, it will use a regular file handle which will utilize the operating +system's page cache. See the [`Tx`](https://godoc.org/github.com/boltdb/bolt#Tx) +documentation for information about optimizing for larger-than-RAM datasets. One common use case is to backup over HTTP so you can use tools like `cURL` to do database backups: ```go func BackupHandleFunc(w http.ResponseWriter, req *http.Request) { - err := db.View(func(tx bolt.Tx) error { + err := db.View(func(tx *bolt.Tx) error { w.Header().Set("Content-Type", "application/octet-stream") w.Header().Set("Content-Disposition", `attachment; filename="my.db"`) w.Header().Set("Content-Length", strconv.Itoa(int(tx.Size()))) - return tx.Copy(w) + _, err := tx.WriteTo(w) + return err }) if err != nil { http.Error(w, err.Error(), http.StatusInternalServerError) @@ -399,12 +536,105 @@ It's also useful to pipe these stats to a service such as statsd for monitoring or to provide an HTTP endpoint that will perform a fixed-length sample. +### Read-Only Mode + +Sometimes it is useful to create a shared, read-only Bolt database. To this, +set the `Options.ReadOnly` flag when opening your database. Read-only mode +uses a shared lock to allow multiple processes to read from the database but +it will block any processes from opening the database in read-write mode. + +```go +db, err := bolt.Open("my.db", 0666, &bolt.Options{ReadOnly: true}) +if err != nil { + log.Fatal(err) +} +``` + +### Mobile Use (iOS/Android) + +Bolt is able to run on mobile devices by leveraging the binding feature of the +[gomobile](https://github.com/golang/mobile) tool. Create a struct that will +contain your database logic and a reference to a `*bolt.DB` with a initializing +contstructor that takes in a filepath where the database file will be stored. +Neither Android nor iOS require extra permissions or cleanup from using this method. + +```go +func NewBoltDB(filepath string) *BoltDB { + db, err := bolt.Open(filepath+"/demo.db", 0600, nil) + if err != nil { + log.Fatal(err) + } + + return &BoltDB{db} +} + +type BoltDB struct { + db *bolt.DB + ... +} + +func (b *BoltDB) Path() string { + return b.db.Path() +} + +func (b *BoltDB) Close() { + b.db.Close() +} +``` + +Database logic should be defined as methods on this wrapper struct. + +To initialize this struct from the native language (both platforms now sync +their local storage to the cloud. These snippets disable that functionality for the +database file): + +#### Android + +```java +String path; +if (android.os.Build.VERSION.SDK_INT >=android.os.Build.VERSION_CODES.LOLLIPOP){ + path = getNoBackupFilesDir().getAbsolutePath(); +} else{ + path = getFilesDir().getAbsolutePath(); +} +Boltmobiledemo.BoltDB boltDB = Boltmobiledemo.NewBoltDB(path) +``` + +#### iOS + +```objc +- (void)demo { + NSString* path = [NSSearchPathForDirectoriesInDomains(NSLibraryDirectory, + NSUserDomainMask, + YES) objectAtIndex:0]; + GoBoltmobiledemoBoltDB * demo = GoBoltmobiledemoNewBoltDB(path); + [self addSkipBackupAttributeToItemAtPath:demo.path]; + //Some DB Logic would go here + [demo close]; +} + +- (BOOL)addSkipBackupAttributeToItemAtPath:(NSString *) filePathString +{ + NSURL* URL= [NSURL fileURLWithPath: filePathString]; + assert([[NSFileManager defaultManager] fileExistsAtPath: [URL path]]); + + NSError *error = nil; + BOOL success = [URL setResourceValue: [NSNumber numberWithBool: YES] + forKey: NSURLIsExcludedFromBackupKey error: &error]; + if(!success){ + NSLog(@"Error excluding %@ from backup %@", [URL lastPathComponent], error); + } + return success; +} + +``` + ## Resources For more information on getting started with Bolt, check out the following articles: * [Intro to BoltDB: Painless Performant Persistence](http://npf.io/2014/07/intro-to-boltdb-painless-performant-persistence/) by [Nate Finch](https://github.com/natefinch). - +* [Bolt -- an embedded key/value database for Go](https://www.progville.com/go/bolt-embedded-db-golang/) by Progville ## Comparison with other databases @@ -433,7 +663,7 @@ they are libraries bundled into the application, however, their underlying structure is a log-structured merge-tree (LSM tree). An LSM tree optimizes random writes by using a write ahead log and multi-tiered, sorted files called SSTables. Bolt uses a B+tree internally and only a single file. Both approaches -have trade offs. +have trade-offs. If you require a high random write throughput (>10,000 w/sec) or you need to use spinning disks then LevelDB could be a good choice. If your application is @@ -469,9 +699,8 @@ It's important to pick the right tool for the job and Bolt is no exception. Here are a few things to note when evaluating and using Bolt: * Bolt is good for read intensive workloads. Sequential write performance is - also fast but random writes can be slow. You can add a write-ahead log or - [transaction coalescer](https://github.com/boltdb/coalescer) in front of Bolt - to mitigate this issue. + also fast but random writes can be slow. You can use `DB.Batch()` or add a + write-ahead log to help mitigate this issue. * Bolt uses a B+tree internally so there can be a lot of random page access. SSDs provide a significant performance boost over spinning disks. @@ -501,7 +730,75 @@ Here are a few things to note when evaluating and using Bolt: can in memory and will release memory as needed to other processes. This means that Bolt can show very high memory usage when working with large databases. However, this is expected and the OS will release memory as needed. Bolt can - handle databases much larger than the available physical RAM. + handle databases much larger than the available physical RAM, provided its + memory-map fits in the process virtual address space. It may be problematic + on 32-bits systems. + +* The data structures in the Bolt database are memory mapped so the data file + will be endian specific. This means that you cannot copy a Bolt file from a + little endian machine to a big endian machine and have it work. For most + users this is not a concern since most modern CPUs are little endian. + +* Because of the way pages are laid out on disk, Bolt cannot truncate data files + and return free pages back to the disk. Instead, Bolt maintains a free list + of unused pages within its data file. These free pages can be reused by later + transactions. This works well for many use cases as databases generally tend + to grow. However, it's important to note that deleting large chunks of data + will not allow you to reclaim that space on disk. + + For more information on page allocation, [see this comment][page-allocation]. + +[page-allocation]: https://github.com/boltdb/bolt/issues/308#issuecomment-74811638 + + +## Reading the Source + +Bolt is a relatively small code base (<3KLOC) for an embedded, serializable, +transactional key/value database so it can be a good starting point for people +interested in how databases work. + +The best places to start are the main entry points into Bolt: + +- `Open()` - Initializes the reference to the database. It's responsible for + creating the database if it doesn't exist, obtaining an exclusive lock on the + file, reading the meta pages, & memory-mapping the file. + +- `DB.Begin()` - Starts a read-only or read-write transaction depending on the + value of the `writable` argument. This requires briefly obtaining the "meta" + lock to keep track of open transactions. Only one read-write transaction can + exist at a time so the "rwlock" is acquired during the life of a read-write + transaction. + +- `Bucket.Put()` - Writes a key/value pair into a bucket. After validating the + arguments, a cursor is used to traverse the B+tree to the page and position + where they key & value will be written. Once the position is found, the bucket + materializes the underlying page and the page's parent pages into memory as + "nodes". These nodes are where mutations occur during read-write transactions. + These changes get flushed to disk during commit. + +- `Bucket.Get()` - Retrieves a key/value pair from a bucket. This uses a cursor + to move to the page & position of a key/value pair. During a read-only + transaction, the key and value data is returned as a direct reference to the + underlying mmap file so there's no allocation overhead. For read-write + transactions, this data may reference the mmap file or one of the in-memory + node values. + +- `Cursor` - This object is simply for traversing the B+tree of on-disk pages + or in-memory nodes. It can seek to a specific key, move to the first or last + value, or it can move forward or backward. The cursor handles the movement up + and down the B+tree transparently to the end user. + +- `Tx.Commit()` - Converts the in-memory dirty nodes and the list of free pages + into pages to be written to disk. Writing to disk then occurs in two phases. + First, the dirty pages are written to disk and an `fsync()` occurs. Second, a + new meta page with an incremented transaction ID is written and another + `fsync()` occurs. This two phase write ensures that partially written data + pages are ignored in the event of a crash since the meta page pointing to them + is never written. Partially written meta pages are invalidated because they + are written with a checksum. + +If you have additional notes that could be helpful for others, please submit +them via pull request. ## Other Projects Using Bolt @@ -509,23 +806,35 @@ Here are a few things to note when evaluating and using Bolt: Below is a list of public, open source projects that use Bolt: * [Operation Go: A Routine Mission](http://gocode.io) - An online programming game for Golang using Bolt for user accounts and a leaderboard. -* [Bazil](https://github.com/bazillion/bazil) - A file system that lets your data reside where it is most convenient for it to reside. +* [Bazil](https://bazil.org/) - A file system that lets your data reside where it is most convenient for it to reside. * [DVID](https://github.com/janelia-flyem/dvid) - Added Bolt as optional storage engine and testing it against Basho-tuned leveldb. * [Skybox Analytics](https://github.com/skybox/skybox) - A standalone funnel analysis tool for web analytics. * [Scuttlebutt](https://github.com/benbjohnson/scuttlebutt) - Uses Bolt to store and process all Twitter mentions of GitHub projects. * [Wiki](https://github.com/peterhellberg/wiki) - A tiny wiki using Goji, BoltDB and Blackfriday. -* [ChainStore](https://github.com/nulayer/chainstore) - Simple key-value interface to a variety of storage engines organized as a chain of operations. +* [ChainStore](https://github.com/pressly/chainstore) - Simple key-value interface to a variety of storage engines organized as a chain of operations. * [MetricBase](https://github.com/msiebuhr/MetricBase) - Single-binary version of Graphite. * [Gitchain](https://github.com/gitchain/gitchain) - Decentralized, peer-to-peer Git repositories aka "Git meets Bitcoin". * [event-shuttle](https://github.com/sclasen/event-shuttle) - A Unix system service to collect and reliably deliver messages to Kafka. * [ipxed](https://github.com/kelseyhightower/ipxed) - Web interface and api for ipxed. * [BoltStore](https://github.com/yosssi/boltstore) - Session store using Bolt. -* [photosite/session](http://godoc.org/bitbucket.org/kardianos/photosite/session) - Sessions for a photo viewing site. +* [photosite/session](https://godoc.org/bitbucket.org/kardianos/photosite/session) - Sessions for a photo viewing site. * [LedisDB](https://github.com/siddontang/ledisdb) - A high performance NoSQL, using Bolt as optional storage. * [ipLocator](https://github.com/AndreasBriese/ipLocator) - A fast ip-geo-location-server using bolt with bloom filters. * [cayley](https://github.com/google/cayley) - Cayley is an open-source graph database using Bolt as optional backend. * [bleve](http://www.blevesearch.com/) - A pure Go search engine similar to ElasticSearch that uses Bolt as the default storage backend. * [tentacool](https://github.com/optiflows/tentacool) - REST api server to manage system stuff (IP, DNS, Gateway...) on a linux server. * [SkyDB](https://github.com/skydb/sky) - Behavioral analytics database. +* [Seaweed File System](https://github.com/chrislusf/seaweedfs) - Highly scalable distributed key~file system with O(1) disk read. +* [InfluxDB](https://influxdata.com) - Scalable datastore for metrics, events, and real-time analytics. +* [Freehold](http://tshannon.bitbucket.org/freehold/) - An open, secure, and lightweight platform for your files and data. +* [Prometheus Annotation Server](https://github.com/oliver006/prom_annotation_server) - Annotation server for PromDash & Prometheus service monitoring system. +* [Consul](https://github.com/hashicorp/consul) - Consul is service discovery and configuration made easy. Distributed, highly available, and datacenter-aware. +* [Kala](https://github.com/ajvb/kala) - Kala is a modern job scheduler optimized to run on a single node. It is persistent, JSON over HTTP API, ISO 8601 duration notation, and dependent jobs. +* [drive](https://github.com/odeke-em/drive) - drive is an unofficial Google Drive command line client for \*NIX operating systems. +* [stow](https://github.com/djherbis/stow) - a persistence manager for objects + backed by boltdb. +* [buckets](https://github.com/joyrexus/buckets) - a bolt wrapper streamlining + simple tx and key scans. +* [Request Baskets](https://github.com/darklynx/request-baskets) - A web service to collect arbitrary HTTP requests and inspect them via REST API or simple web UI, similar to [RequestBin](http://requestb.in/) service If you are using Bolt in a project please send a pull request to add it to the list. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_arm64.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_arm64.go new file mode 100644 index 00000000..6d230935 --- /dev/null +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_arm64.go @@ -0,0 +1,9 @@ +// +build arm64 + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_linux.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_linux.go index e9d1c907..2b676661 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_linux.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_linux.go @@ -4,8 +4,6 @@ import ( "syscall" ) -var odirect = syscall.O_DIRECT - // fdatasync flushes written data to a file descriptor. func fdatasync(db *DB) error { return syscall.Fdatasync(int(db.file.Fd())) diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_openbsd.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_openbsd.go index 7c1bef1a..7058c3d7 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_openbsd.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_openbsd.go @@ -11,8 +11,6 @@ const ( msInvalidate // invalidate cached data ) -var odirect int - func msync(db *DB) error { _, _, errno := syscall.Syscall(syscall.SYS_MSYNC, uintptr(unsafe.Pointer(db.data)), uintptr(db.datasz), msInvalidate) if errno != 0 { diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_ppc64le.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_ppc64le.go new file mode 100644 index 00000000..8351e129 --- /dev/null +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_ppc64le.go @@ -0,0 +1,9 @@ +// +build ppc64le + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_s390x.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_s390x.go new file mode 100644 index 00000000..f4dd26bb --- /dev/null +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_s390x.go @@ -0,0 +1,9 @@ +// +build s390x + +package bolt + +// maxMapSize represents the largest mmap size supported by Bolt. +const maxMapSize = 0xFFFFFFFFFFFF // 256TB + +// maxAllocSize is the size used when creating array pointers. +const maxAllocSize = 0x7FFFFFFF diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_test.go deleted file mode 100644 index b7bea1fc..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_test.go +++ /dev/null @@ -1,36 +0,0 @@ -package bolt_test - -import ( - "fmt" - "path/filepath" - "reflect" - "runtime" - "testing" -) - -// assert fails the test if the condition is false. -func assert(tb testing.TB, condition bool, msg string, v ...interface{}) { - if !condition { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...) - tb.FailNow() - } -} - -// ok fails the test if an err is not nil. -func ok(tb testing.TB, err error) { - if err != nil { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error()) - tb.FailNow() - } -} - -// equals fails the test if exp is not equal to act. -func equals(tb testing.TB, exp, act interface{}) { - if !reflect.DeepEqual(exp, act) { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act) - tb.FailNow() - } -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_unix.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_unix.go index e222cfdc..4b0723aa 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_unix.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_unix.go @@ -1,4 +1,4 @@ -// +build !windows,!plan9 +// +build !windows,!plan9,!solaris package bolt @@ -11,7 +11,7 @@ import ( ) // flock acquires an advisory lock on a file descriptor. -func flock(f *os.File, timeout time.Duration) error { +func flock(f *os.File, exclusive bool, timeout time.Duration) error { var t time.Time for { // If we're beyond our timeout then return an error. @@ -21,9 +21,13 @@ func flock(f *os.File, timeout time.Duration) error { } else if timeout > 0 && time.Since(t) > timeout { return ErrTimeout } + flag := syscall.LOCK_SH + if exclusive { + flag = syscall.LOCK_EX + } // Otherwise attempt to obtain an exclusive lock. - err := syscall.Flock(int(f.Fd()), syscall.LOCK_EX|syscall.LOCK_NB) + err := syscall.Flock(int(f.Fd()), flag|syscall.LOCK_NB) if err == nil { return nil } else if err != syscall.EWOULDBLOCK { @@ -42,21 +46,17 @@ func funlock(f *os.File) error { // mmap memory maps a DB's data file. func mmap(db *DB, sz int) error { - // Truncate and fsync to ensure file size metadata is flushed. - // https://github.com/boltdb/bolt/issues/284 - if err := db.file.Truncate(int64(sz)); err != nil { - return fmt.Errorf("file resize error: %s", err) - } - if err := db.file.Sync(); err != nil { - return fmt.Errorf("file sync error: %s", err) - } - // Map the data file to memory. - b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED) + b, err := syscall.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags) if err != nil { return err } + // Advise the kernel that the mmap is accessed randomly. + if err := madvise(b, syscall.MADV_RANDOM); err != nil { + return fmt.Errorf("madvise: %s", err) + } + // Save the original byte slice and convert to a byte array pointer. db.dataref = b db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0])) @@ -78,3 +78,12 @@ func munmap(db *DB) error { db.datasz = 0 return err } + +// NOTE: This function is copied from stdlib because it is not available on darwin. +func madvise(b []byte, advice int) (err error) { + _, _, e1 := syscall.Syscall(syscall.SYS_MADVISE, uintptr(unsafe.Pointer(&b[0])), uintptr(len(b)), uintptr(advice)) + if e1 != 0 { + err = e1 + } + return +} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_unix_solaris.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_unix_solaris.go new file mode 100644 index 00000000..1c4e48d6 --- /dev/null +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_unix_solaris.go @@ -0,0 +1,90 @@ +package bolt + +import ( + "fmt" + "os" + "syscall" + "time" + "unsafe" + + "golang.org/x/sys/unix" +) + +// flock acquires an advisory lock on a file descriptor. +func flock(f *os.File, exclusive bool, timeout time.Duration) error { + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. + if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + var lock syscall.Flock_t + lock.Start = 0 + lock.Len = 0 + lock.Pid = 0 + lock.Whence = 0 + lock.Pid = 0 + if exclusive { + lock.Type = syscall.F_WRLCK + } else { + lock.Type = syscall.F_RDLCK + } + err := syscall.FcntlFlock(f.Fd(), syscall.F_SETLK, &lock) + if err == nil { + return nil + } else if err != syscall.EAGAIN { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } +} + +// funlock releases an advisory lock on a file descriptor. +func funlock(f *os.File) error { + var lock syscall.Flock_t + lock.Start = 0 + lock.Len = 0 + lock.Type = syscall.F_UNLCK + lock.Whence = 0 + return syscall.FcntlFlock(uintptr(f.Fd()), syscall.F_SETLK, &lock) +} + +// mmap memory maps a DB's data file. +func mmap(db *DB, sz int) error { + // Map the data file to memory. + b, err := unix.Mmap(int(db.file.Fd()), 0, sz, syscall.PROT_READ, syscall.MAP_SHARED|db.MmapFlags) + if err != nil { + return err + } + + // Advise the kernel that the mmap is accessed randomly. + if err := unix.Madvise(b, syscall.MADV_RANDOM); err != nil { + return fmt.Errorf("madvise: %s", err) + } + + // Save the original byte slice and convert to a byte array pointer. + db.dataref = b + db.data = (*[maxMapSize]byte)(unsafe.Pointer(&b[0])) + db.datasz = sz + return nil +} + +// munmap unmaps a DB's data file from memory. +func munmap(db *DB) error { + // Ignore the unmap if we have no mapped data. + if db.dataref == nil { + return nil + } + + // Unmap using the original byte slice. + err := unix.Munmap(db.dataref) + db.dataref = nil + db.data = nil + db.datasz = 0 + return err +} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_windows.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_windows.go index c8539d40..91c4968f 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_windows.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bolt_windows.go @@ -8,7 +8,37 @@ import ( "unsafe" ) -var odirect int +// LockFileEx code derived from golang build filemutex_windows.go @ v1.5.1 +var ( + modkernel32 = syscall.NewLazyDLL("kernel32.dll") + procLockFileEx = modkernel32.NewProc("LockFileEx") + procUnlockFileEx = modkernel32.NewProc("UnlockFileEx") +) + +const ( + // see https://msdn.microsoft.com/en-us/library/windows/desktop/aa365203(v=vs.85).aspx + flagLockExclusive = 2 + flagLockFailImmediately = 1 + + // see https://msdn.microsoft.com/en-us/library/windows/desktop/ms681382(v=vs.85).aspx + errLockViolation syscall.Errno = 0x21 +) + +func lockFileEx(h syscall.Handle, flags, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { + r, _, err := procLockFileEx.Call(uintptr(h), uintptr(flags), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol))) + if r == 0 { + return err + } + return nil +} + +func unlockFileEx(h syscall.Handle, reserved, locklow, lockhigh uint32, ol *syscall.Overlapped) (err error) { + r, _, err := procUnlockFileEx.Call(uintptr(h), uintptr(reserved), uintptr(locklow), uintptr(lockhigh), uintptr(unsafe.Pointer(ol)), 0) + if r == 0 { + return err + } + return nil +} // fdatasync flushes written data to a file descriptor. func fdatasync(db *DB) error { @@ -16,21 +46,47 @@ func fdatasync(db *DB) error { } // flock acquires an advisory lock on a file descriptor. -func flock(f *os.File, _ time.Duration) error { - return nil +func flock(f *os.File, exclusive bool, timeout time.Duration) error { + var t time.Time + for { + // If we're beyond our timeout then return an error. + // This can only occur after we've attempted a flock once. + if t.IsZero() { + t = time.Now() + } else if timeout > 0 && time.Since(t) > timeout { + return ErrTimeout + } + + var flag uint32 = flagLockFailImmediately + if exclusive { + flag |= flagLockExclusive + } + + err := lockFileEx(syscall.Handle(f.Fd()), flag, 0, 1, 0, &syscall.Overlapped{}) + if err == nil { + return nil + } else if err != errLockViolation { + return err + } + + // Wait for a bit and try again. + time.Sleep(50 * time.Millisecond) + } } // funlock releases an advisory lock on a file descriptor. func funlock(f *os.File) error { - return nil + return unlockFileEx(syscall.Handle(f.Fd()), 0, 1, 0, &syscall.Overlapped{}) } // mmap memory maps a DB's data file. // Based on: https://github.com/edsrzf/mmap-go func mmap(db *DB, sz int) error { - // Truncate the database to the size of the mmap. - if err := db.file.Truncate(int64(sz)); err != nil { - return fmt.Errorf("truncate: %s", err) + if !db.readOnly { + // Truncate the database to the size of the mmap. + if err := db.file.Truncate(int64(sz)); err != nil { + return fmt.Errorf("truncate: %s", err) + } } // Open a file mapping handle. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/boltsync_unix.go b/Godeps/_workspace/src/github.com/boltdb/bolt/boltsync_unix.go index 8db89776..f5044252 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/boltsync_unix.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/boltsync_unix.go @@ -2,8 +2,6 @@ package bolt -var odirect int - // fdatasync flushes written data to a file descriptor. func fdatasync(db *DB) error { return db.file.Sync() diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bucket.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bucket.go index 470689ba..d2f8c524 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/bucket.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bucket.go @@ -11,7 +11,7 @@ const ( MaxKeySize = 32768 // MaxValueSize is the maximum length of a value, in bytes. - MaxValueSize = 4294967295 + MaxValueSize = (1 << 31) - 2 ) const ( @@ -99,6 +99,7 @@ func (b *Bucket) Cursor() *Cursor { // Bucket retrieves a nested bucket by name. // Returns nil if the bucket does not exist. +// The bucket instance is only valid for the lifetime of the transaction. func (b *Bucket) Bucket(name []byte) *Bucket { if b.buckets != nil { if child := b.buckets[string(name)]; child != nil { @@ -148,6 +149,7 @@ func (b *Bucket) openBucket(value []byte) *Bucket { // CreateBucket creates a new bucket at the given key and returns the new bucket. // Returns an error if the key already exists, if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) { if b.tx.db == nil { return nil, ErrTxClosed @@ -192,6 +194,7 @@ func (b *Bucket) CreateBucket(key []byte) (*Bucket, error) { // CreateBucketIfNotExists creates a new bucket if it doesn't already exist and returns a reference to it. // Returns an error if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. func (b *Bucket) CreateBucketIfNotExists(key []byte) (*Bucket, error) { child, err := b.CreateBucket(key) if err == ErrBucketExists { @@ -252,6 +255,7 @@ func (b *Bucket) DeleteBucket(key []byte) error { // Get retrieves the value for a key in the bucket. // Returns a nil value if the key does not exist or if the key is a nested bucket. +// The returned value is only valid for the life of the transaction. func (b *Bucket) Get(key []byte) []byte { k, v, flags := b.Cursor().seek(key) @@ -269,6 +273,7 @@ func (b *Bucket) Get(key []byte) []byte { // Put sets the value for a key in the bucket. // If the key exist then its previous value will be overwritten. +// Supplied value must remain valid for the life of the transaction. // Returns an error if the bucket was created from a read-only transaction, if the key is blank, if the key is too large, or if the value is too large. func (b *Bucket) Put(key []byte, value []byte) error { if b.tx.db == nil { @@ -345,7 +350,8 @@ func (b *Bucket) NextSequence() (uint64, error) { // ForEach executes a function for each key/value pair in a bucket. // If the provided function returns an error then the iteration is stopped and -// the error is returned to the caller. +// the error is returned to the caller. The provided function must not modify +// the bucket; this will result in undefined behavior. func (b *Bucket) ForEach(fn func(k, v []byte) error) error { if b.tx.db == nil { return ErrTxClosed diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/bucket_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/bucket_test.go index 9cbc531d..9606ec55 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/bucket_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/bucket_test.go @@ -5,6 +5,7 @@ import ( "encoding/binary" "errors" "fmt" + "log" "math/rand" "os" "strconv" @@ -17,94 +18,150 @@ import ( // Ensure that a bucket that gets a non-existent key returns nil. func TestBucket_Get_NonExistent(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) - assert(t, value == nil, "") + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); v != nil { + t.Fatal("expected nil value") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can read a value that is not flushed yet. func TestBucket_Get_FromNode(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - b.Put([]byte("foo"), []byte("bar")) - value := b.Get([]byte("foo")) - equals(t, []byte("bar"), value) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); !bytes.Equal(v, []byte("bar")) { + t.Fatalf("unexpected value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket retrieved via Get() returns a nil. func TestBucket_Get_IncompatibleValue(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")) - ok(t, err) - assert(t, tx.Bucket([]byte("widgets")).Get([]byte("foo")) == nil, "") + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + + if tx.Bucket([]byte("widgets")).Get([]byte("foo")) != nil { + t.Fatal("expected nil value") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can write a key/value. func TestBucket_Put(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - err := tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - ok(t, err) - value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) - equals(t, value, []byte("bar")) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + + v := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + if !bytes.Equal([]byte("bar"), v) { + t.Fatalf("unexpected value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can rewrite a key in the same transaction. func TestBucket_Put_Repeat(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - ok(t, b.Put([]byte("foo"), []byte("bar"))) - ok(t, b.Put([]byte("foo"), []byte("baz"))) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("baz")); err != nil { + t.Fatal(err) + } + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) - equals(t, value, []byte("baz")) + if !bytes.Equal([]byte("baz"), value) { + t.Fatalf("unexpected value: %v", value) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can write a bunch of large values. func TestBucket_Put_Large(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() count, factor := 100, 200 - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } for i := 1; i < count; i++ { - ok(t, b.Put([]byte(strings.Repeat("0", i*factor)), []byte(strings.Repeat("X", (count-i)*factor)))) + if err := b.Put([]byte(strings.Repeat("0", i*factor)), []byte(strings.Repeat("X", (count-i)*factor))); err != nil { + t.Fatal(err) + } } return nil - }) - db.View(func(tx *bolt.Tx) error { + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("widgets")) for i := 1; i < count; i++ { value := b.Get([]byte(strings.Repeat("0", i*factor))) - equals(t, []byte(strings.Repeat("X", (count-i)*factor)), value) + if !bytes.Equal(value, []byte(strings.Repeat("X", (count-i)*factor))) { + t.Fatalf("unexpected value: %v", value) + } } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a database can perform multiple large appends safely. @@ -116,104 +173,170 @@ func TestDB_Put_VeryLarge(t *testing.T) { n, batchN := 400000, 200000 ksize, vsize := 8, 500 - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() for i := 0; i < n; i += batchN { - err := db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucketIfNotExists([]byte("widgets")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("widgets")) + if err != nil { + t.Fatal(err) + } for j := 0; j < batchN; j++ { k, v := make([]byte, ksize), make([]byte, vsize) binary.BigEndian.PutUint32(k, uint32(i+j)) - ok(t, b.Put(k, v)) + if err := b.Put(k, v); err != nil { + t.Fatal(err) + } } return nil - }) - ok(t, err) + }); err != nil { + t.Fatal(err) + } } } // Ensure that a setting a value on a key with a bucket value returns an error. func TestBucket_Put_IncompatibleValue(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")) - ok(t, err) - equals(t, bolt.ErrIncompatibleValue, tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar"))) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b0, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if err := b0.Put([]byte("foo"), []byte("bar")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a setting a value while the transaction is closed returns an error. func TestBucket_Put_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - tx.Rollback() - equals(t, bolt.ErrTxClosed, b.Put([]byte("foo"), []byte("bar"))) + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + if err := b.Put([]byte("foo"), []byte("bar")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that setting a value on a read-only bucket returns an error. func TestBucket_Put_ReadOnly(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return nil - }) - db.View(func(tx *bolt.Tx) error { + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("widgets")) - err := b.Put([]byte("foo"), []byte("bar")) - equals(t, err, bolt.ErrTxNotWritable) + if err := b.Put([]byte("foo"), []byte("bar")); err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can delete an existing key. func TestBucket_Delete(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - err := tx.Bucket([]byte("widgets")).Delete([]byte("foo")) - ok(t, err) - value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) - assert(t, value == nil, "") + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); v != nil { + t.Fatalf("unexpected value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that deleting a large set of keys will work correctly. func TestBucket_Delete_Large(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - var b, _ = tx.CreateBucket([]byte("widgets")) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 100; i++ { - ok(t, b.Put([]byte(strconv.Itoa(i)), []byte(strings.Repeat("*", 1024)))) + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strings.Repeat("*", 1024))); err != nil { + t.Fatal(err) + } + } + + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 0; i < 100; i++ { + if err := b.Delete([]byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } } return nil - }) - db.Update(func(tx *bolt.Tx) error { - var b = tx.Bucket([]byte("widgets")) + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) for i := 0; i < 100; i++ { - ok(t, b.Delete([]byte(strconv.Itoa(i)))) + if v := b.Get([]byte(strconv.Itoa(i))); v != nil { + t.Fatalf("unexpected value: %v, i=%d", v, i) + } } return nil - }) - db.View(func(tx *bolt.Tx) error { - var b = tx.Bucket([]byte("widgets")) - for i := 0; i < 100; i++ { - assert(t, b.Get([]byte(strconv.Itoa(i))) == nil, "") - } - return nil - }) + }); err != nil { + t.Fatal(err) + } } // Deleting a very large list of keys will cause the freelist to use overflow. @@ -222,11 +345,12 @@ func TestBucket_Delete_FreelistOverflow(t *testing.T) { t.Skip("skipping test in short mode.") } - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() + k := make([]byte, 16) for i := uint64(0); i < 10000; i++ { - err := db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucketIfNotExists([]byte("0")) if err != nil { t.Fatalf("bucket error: %s", err) @@ -241,272 +365,450 @@ func TestBucket_Delete_FreelistOverflow(t *testing.T) { } return nil - }) - - if err != nil { - t.Fatalf("update error: %s", err) + }); err != nil { + t.Fatal(err) } } // Delete all of them in one large transaction - err := db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("0")) c := b.Cursor() for k, _ := c.First(); k != nil; k, _ = c.Next() { - b.Delete(k) + if err := c.Delete(); err != nil { + t.Fatal(err) + } } return nil - }) - - // Check that a freelist overflow occurred. - ok(t, err) + }); err != nil { + t.Fatal(err) + } } // Ensure that accessing and updating nested buckets is ok across transactions. func TestBucket_Nested(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { // Create a widgets bucket. b, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) + if err != nil { + t.Fatal(err) + } // Create a widgets/foo bucket. _, err = b.CreateBucket([]byte("foo")) - ok(t, err) + if err != nil { + t.Fatal(err) + } // Create a widgets/bar key. - ok(t, b.Put([]byte("bar"), []byte("0000"))) + if err := b.Put([]byte("bar"), []byte("0000")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } db.MustCheck() // Update widgets/bar. - db.Update(func(tx *bolt.Tx) error { - var b = tx.Bucket([]byte("widgets")) - ok(t, b.Put([]byte("bar"), []byte("xxxx"))) + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + if err := b.Put([]byte("bar"), []byte("xxxx")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } db.MustCheck() // Cause a split. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { var b = tx.Bucket([]byte("widgets")) for i := 0; i < 10000; i++ { - ok(t, b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i)))) + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } } return nil - }) + }); err != nil { + t.Fatal(err) + } db.MustCheck() // Insert into widgets/foo/baz. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { var b = tx.Bucket([]byte("widgets")) - ok(t, b.Bucket([]byte("foo")).Put([]byte("baz"), []byte("yyyy"))) + if err := b.Bucket([]byte("foo")).Put([]byte("baz"), []byte("yyyy")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } db.MustCheck() // Verify. - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { var b = tx.Bucket([]byte("widgets")) - equals(t, []byte("yyyy"), b.Bucket([]byte("foo")).Get([]byte("baz"))) - equals(t, []byte("xxxx"), b.Get([]byte("bar"))) + if v := b.Bucket([]byte("foo")).Get([]byte("baz")); !bytes.Equal(v, []byte("yyyy")) { + t.Fatalf("unexpected value: %v", v) + } + if v := b.Get([]byte("bar")); !bytes.Equal(v, []byte("xxxx")) { + t.Fatalf("unexpected value: %v", v) + } for i := 0; i < 10000; i++ { - equals(t, []byte(strconv.Itoa(i)), b.Get([]byte(strconv.Itoa(i)))) + if v := b.Get([]byte(strconv.Itoa(i))); !bytes.Equal(v, []byte(strconv.Itoa(i))) { + t.Fatalf("unexpected value: %v", v) + } } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that deleting a bucket using Delete() returns an error. func TestBucket_Delete_Bucket(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - _, err := b.CreateBucket([]byte("foo")) - ok(t, err) - equals(t, bolt.ErrIncompatibleValue, b.Delete([]byte("foo"))) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that deleting a key on a read-only bucket returns an error. func TestBucket_Delete_ReadOnly(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return nil - }) - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - err := b.Delete([]byte("foo")) - equals(t, err, bolt.ErrTxNotWritable) + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("widgets")).Delete([]byte("foo")); err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a deleting value while the transaction is closed returns an error. func TestBucket_Delete_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - tx.Rollback() - equals(t, bolt.ErrTxClosed, b.Delete([]byte("foo"))) + db := MustOpenDB() + defer db.MustClose() + + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that deleting a bucket causes nested buckets to be deleted. func TestBucket_DeleteBucket_Nested(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")) - ok(t, err) - _, err = tx.Bucket([]byte("widgets")).Bucket([]byte("foo")).CreateBucket([]byte("bar")) - ok(t, err) - ok(t, tx.Bucket([]byte("widgets")).Bucket([]byte("foo")).Bucket([]byte("bar")).Put([]byte("baz"), []byte("bat"))) - ok(t, tx.Bucket([]byte("widgets")).DeleteBucket([]byte("foo"))) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + foo, err := widgets.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + + bar, err := foo.CreateBucket([]byte("bar")) + if err != nil { + t.Fatal(err) + } + if err := bar.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + if err := tx.Bucket([]byte("widgets")).DeleteBucket([]byte("foo")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that deleting a bucket causes nested buckets to be deleted after they have been committed. func TestBucket_DeleteBucket_Nested2(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - _, err := tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")) - ok(t, err) - _, err = tx.Bucket([]byte("widgets")).Bucket([]byte("foo")).CreateBucket([]byte("bar")) - ok(t, err) - ok(t, tx.Bucket([]byte("widgets")).Bucket([]byte("foo")).Bucket([]byte("bar")).Put([]byte("baz"), []byte("bat"))) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + foo, err := widgets.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + + bar, err := foo.CreateBucket([]byte("bar")) + if err != nil { + t.Fatal(err) + } + + if err := bar.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } return nil - }) - db.Update(func(tx *bolt.Tx) error { - assert(t, tx.Bucket([]byte("widgets")) != nil, "") - assert(t, tx.Bucket([]byte("widgets")).Bucket([]byte("foo")) != nil, "") - assert(t, tx.Bucket([]byte("widgets")).Bucket([]byte("foo")).Bucket([]byte("bar")) != nil, "") - equals(t, []byte("bat"), tx.Bucket([]byte("widgets")).Bucket([]byte("foo")).Bucket([]byte("bar")).Get([]byte("baz"))) - ok(t, tx.DeleteBucket([]byte("widgets"))) + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + widgets := tx.Bucket([]byte("widgets")) + if widgets == nil { + t.Fatal("expected widgets bucket") + } + + foo := widgets.Bucket([]byte("foo")) + if foo == nil { + t.Fatal("expected foo bucket") + } + + bar := foo.Bucket([]byte("bar")) + if bar == nil { + t.Fatal("expected bar bucket") + } + + if v := bar.Get([]byte("baz")); !bytes.Equal(v, []byte("bat")) { + t.Fatalf("unexpected value: %v", v) + } + if err := tx.DeleteBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return nil - }) - db.View(func(tx *bolt.Tx) error { - assert(t, tx.Bucket([]byte("widgets")) == nil, "") + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) != nil { + t.Fatal("expected bucket to be deleted") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that deleting a child bucket with multiple pages causes all pages to get collected. +// NOTE: Consistency check in bolt_test.DB.Close() will panic if pages not freed properly. func TestBucket_DeleteBucket_Large(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) - _, err = tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")) - ok(t, err) - b := tx.Bucket([]byte("widgets")).Bucket([]byte("foo")) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + foo, err := widgets.CreateBucket([]byte("foo")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < 1000; i++ { - ok(t, b.Put([]byte(fmt.Sprintf("%d", i)), []byte(fmt.Sprintf("%0100d", i)))) + if err := foo.Put([]byte(fmt.Sprintf("%d", i)), []byte(fmt.Sprintf("%0100d", i))); err != nil { + t.Fatal(err) + } } return nil - }) - db.Update(func(tx *bolt.Tx) error { - ok(t, tx.DeleteBucket([]byte("widgets"))) - return nil - }) + }); err != nil { + t.Fatal(err) + } - // NOTE: Consistency check in TestDB.Close() will panic if pages not freed properly. + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } } // Ensure that a simple value retrieved via Bucket() returns a nil. func TestBucket_Bucket_IncompatibleValue(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - ok(t, tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar"))) - assert(t, tx.Bucket([]byte("widgets")).Bucket([]byte("foo")) == nil, "") + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := widgets.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if b := tx.Bucket([]byte("widgets")).Bucket([]byte("foo")); b != nil { + t.Fatal("expected nil bucket") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that creating a bucket on an existing non-bucket key returns an error. func TestBucket_CreateBucket_IncompatibleValue(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) - ok(t, tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar"))) - _, err = tx.Bucket([]byte("widgets")).CreateBucket([]byte("foo")) - equals(t, bolt.ErrIncompatibleValue, err) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := widgets.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if _, err := widgets.CreateBucket([]byte("foo")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that deleting a bucket on an existing non-bucket key returns an error. func TestBucket_DeleteBucket_IncompatibleValue(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) - ok(t, tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar"))) - equals(t, bolt.ErrIncompatibleValue, tx.Bucket([]byte("widgets")).DeleteBucket([]byte("foo"))) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := widgets.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := tx.Bucket([]byte("widgets")).DeleteBucket([]byte("foo")); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can return an autoincrementing sequence. func TestBucket_NextSequence(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.CreateBucket([]byte("woojits")) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + widgets, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + woojits, err := tx.CreateBucket([]byte("woojits")) + if err != nil { + t.Fatal(err) + } // Make sure sequence increments. - seq, err := tx.Bucket([]byte("widgets")).NextSequence() - ok(t, err) - equals(t, seq, uint64(1)) - seq, err = tx.Bucket([]byte("widgets")).NextSequence() - ok(t, err) - equals(t, seq, uint64(2)) + if seq, err := widgets.NextSequence(); err != nil { + t.Fatal(err) + } else if seq != 1 { + t.Fatalf("unexpecte sequence: %d", seq) + } + + if seq, err := widgets.NextSequence(); err != nil { + t.Fatal(err) + } else if seq != 2 { + t.Fatalf("unexpected sequence: %d", seq) + } // Buckets should be separate. - seq, err = tx.Bucket([]byte("woojits")).NextSequence() - ok(t, err) - equals(t, seq, uint64(1)) + if seq, err := woojits.NextSequence(); err != nil { + t.Fatal(err) + } else if seq != 1 { + t.Fatalf("unexpected sequence: %d", 1) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket will persist an autoincrementing sequence even if its // the only thing updated on the bucket. // https://github.com/boltdb/bolt/issues/296 func TestBucket_NextSequence_Persist(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - _, _ = tx.CreateBucket([]byte("widgets")) - return nil - }) + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { - _, _ = tx.Bucket([]byte("widgets")).NextSequence() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.Bucket([]byte("widgets")).NextSequence(); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { seq, err := tx.Bucket([]byte("widgets")).NextSequence() if err != nil { t.Fatalf("unexpected error: %s", err) @@ -514,183 +816,326 @@ func TestBucket_NextSequence_Persist(t *testing.T) { t.Fatalf("unexpected sequence: %d", seq) } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that retrieving the next sequence on a read-only bucket returns an error. func TestBucket_NextSequence_ReadOnly(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return nil - }) - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - i, err := b.NextSequence() - equals(t, i, uint64(0)) - equals(t, err, bolt.ErrTxNotWritable) + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + _, err := tx.Bucket([]byte("widgets")).NextSequence() + if err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that retrieving the next sequence for a bucket on a closed database return an error. func TestBucket_NextSequence_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - tx.Rollback() - _, err := b.NextSequence() - equals(t, bolt.ErrTxClosed, err) + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + if _, err := b.NextSequence(); err != bolt.ErrTxClosed { + t.Fatal(err) + } } // Ensure a user can loop over all key/value pairs in a bucket. func TestBucket_ForEach(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("0000")) - tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte("0001")) - tx.Bucket([]byte("widgets")).Put([]byte("bar"), []byte("0002")) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("0000")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("0001")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte("0002")); err != nil { + t.Fatal(err) + } var index int - err := tx.Bucket([]byte("widgets")).ForEach(func(k, v []byte) error { + if err := b.ForEach(func(k, v []byte) error { switch index { case 0: - equals(t, k, []byte("bar")) - equals(t, v, []byte("0002")) + if !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0002")) { + t.Fatalf("unexpected value: %v", v) + } case 1: - equals(t, k, []byte("baz")) - equals(t, v, []byte("0001")) + if !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0001")) { + t.Fatalf("unexpected value: %v", v) + } case 2: - equals(t, k, []byte("foo")) - equals(t, v, []byte("0000")) + if !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0000")) { + t.Fatalf("unexpected value: %v", v) + } } index++ return nil - }) - ok(t, err) - equals(t, index, 3) + }); err != nil { + t.Fatal(err) + } + + if index != 3 { + t.Fatalf("unexpected index: %d", index) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure a database can stop iteration early. func TestBucket_ForEach_ShortCircuit(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("bar"), []byte("0000")) - tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte("0000")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("0000")) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte("0000")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("0000")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("0000")); err != nil { + t.Fatal(err) + } var index int - err := tx.Bucket([]byte("widgets")).ForEach(func(k, v []byte) error { + if err := tx.Bucket([]byte("widgets")).ForEach(func(k, v []byte) error { index++ if bytes.Equal(k, []byte("baz")) { return errors.New("marker") } return nil - }) - equals(t, errors.New("marker"), err) - equals(t, 2, index) + }); err == nil || err.Error() != "marker" { + t.Fatalf("unexpected error: %s", err) + } + if index != 2 { + t.Fatalf("unexpected index: %d", index) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that looping over a bucket on a closed database returns an error. func TestBucket_ForEach_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - tx.Rollback() - err := b.ForEach(func(k, v []byte) error { return nil }) - equals(t, bolt.ErrTxClosed, err) + db := MustOpenDB() + defer db.MustClose() + + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + + if err := b.ForEach(func(k, v []byte) error { return nil }); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that an error is returned when inserting with an empty key. func TestBucket_Put_EmptyKey(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - err := tx.Bucket([]byte("widgets")).Put([]byte(""), []byte("bar")) - equals(t, err, bolt.ErrKeyRequired) - err = tx.Bucket([]byte("widgets")).Put(nil, []byte("bar")) - equals(t, err, bolt.ErrKeyRequired) + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte(""), []byte("bar")); err != bolt.ErrKeyRequired { + t.Fatalf("unexpected error: %s", err) + } + if err := b.Put(nil, []byte("bar")); err != bolt.ErrKeyRequired { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that an error is returned when inserting with a key that's too large. func TestBucket_Put_KeyTooLarge(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - err := tx.Bucket([]byte("widgets")).Put(make([]byte, 32769), []byte("bar")) - equals(t, err, bolt.ErrKeyTooLarge) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put(make([]byte, 32769), []byte("bar")); err != bolt.ErrKeyTooLarge { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that an error is returned when inserting a value that's too large. +func TestBucket_Put_ValueTooLarge(t *testing.T) { + // Skip this test on DroneCI because the machine is resource constrained. + if os.Getenv("DRONE") == "true" { + t.Skip("not enough RAM for test") + } + + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), make([]byte, bolt.MaxValueSize+1)); err != bolt.ErrValueTooLarge { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } } // Ensure a bucket can calculate stats. func TestBucket_Stats(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() // Add bucket with fewer keys but one big value. - big_key := []byte("really-big-value") + bigKey := []byte("really-big-value") for i := 0; i < 500; i++ { - db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucketIfNotExists([]byte("woojits")) - return b.Put([]byte(fmt.Sprintf("%03d", i)), []byte(strconv.Itoa(i))) - }) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("woojits")) + if err != nil { + t.Fatal(err) + } + + if err := b.Put([]byte(fmt.Sprintf("%03d", i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + } + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("woojits")).Put(bigKey, []byte(strings.Repeat("*", 10000))); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) } - db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucketIfNotExists([]byte("woojits")) - return b.Put(big_key, []byte(strings.Repeat("*", 10000))) - }) db.MustCheck() - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("woojits")) - stats := b.Stats() - equals(t, 1, stats.BranchPageN) - equals(t, 0, stats.BranchOverflowN) - equals(t, 7, stats.LeafPageN) - equals(t, 2, stats.LeafOverflowN) - equals(t, 501, stats.KeyN) - equals(t, 2, stats.Depth) + + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("woojits")).Stats() + if stats.BranchPageN != 1 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 7 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 2 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 501 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 2 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } branchInuse := 16 // branch page header branchInuse += 7 * 16 // branch elements branchInuse += 7 * 3 // branch keys (6 3-byte keys) - equals(t, branchInuse, stats.BranchInuse) + if stats.BranchInuse != branchInuse { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } leafInuse := 7 * 16 // leaf page header leafInuse += 501 * 16 // leaf elements - leafInuse += 500*3 + len(big_key) // leaf keys + leafInuse += 500*3 + len(bigKey) // leaf keys leafInuse += 1*10 + 2*90 + 3*400 + 10000 // leaf values - equals(t, leafInuse, stats.LeafInuse) - - if os.Getpagesize() == 4096 { - // Incompatible page size - equals(t, 4096, stats.BranchAlloc) - equals(t, 36864, stats.LeafAlloc) + if stats.LeafInuse != leafInuse { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } + + // Only check allocations for 4KB pages. + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 4096 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 36864 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 0 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 0 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) } - equals(t, 1, stats.BucketN) - equals(t, 0, stats.InlineBucketN) - equals(t, 0, stats.InlineBucketInuse) return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure a bucket with random insertion utilizes fill percentage correctly. @@ -701,150 +1146,251 @@ func TestBucket_Stats_RandomFill(t *testing.T) { t.Skip("invalid page size for test") } - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() // Add a set of values in random order. It will be the same random // order so we can maintain consistency between test runs. var count int - r := rand.New(rand.NewSource(42)) - for _, i := range r.Perm(1000) { - db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucketIfNotExists([]byte("woojits")) + rand := rand.New(rand.NewSource(42)) + for _, i := range rand.Perm(1000) { + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("woojits")) + if err != nil { + t.Fatal(err) + } b.FillPercent = 0.9 - for _, j := range r.Perm(100) { + for _, j := range rand.Perm(100) { index := (j * 10000) + i - b.Put([]byte(fmt.Sprintf("%d000000000000000", index)), []byte("0000000000")) + if err := b.Put([]byte(fmt.Sprintf("%d000000000000000", index)), []byte("0000000000")); err != nil { + t.Fatal(err) + } count++ } return nil - }) + }); err != nil { + t.Fatal(err) + } } + db.MustCheck() - db.View(func(tx *bolt.Tx) error { - s := tx.Bucket([]byte("woojits")).Stats() - equals(t, 100000, s.KeyN) + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("woojits")).Stats() + if stats.KeyN != 100000 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } - equals(t, 98, s.BranchPageN) - equals(t, 0, s.BranchOverflowN) - equals(t, 130984, s.BranchInuse) - equals(t, 401408, s.BranchAlloc) + if stats.BranchPageN != 98 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.BranchInuse != 130984 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.BranchAlloc != 401408 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } - equals(t, 3412, s.LeafPageN) - equals(t, 0, s.LeafOverflowN) - equals(t, 4742482, s.LeafInuse) - equals(t, 13975552, s.LeafAlloc) + if stats.LeafPageN != 3412 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.LeafInuse != 4742482 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) + } else if stats.LeafAlloc != 13975552 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure a bucket can calculate stats. func TestBucket_Stats_Small(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { // Add a bucket that fits on a single root leaf. b, err := tx.CreateBucket([]byte("whozawhats")) - ok(t, err) - b.Put([]byte("foo"), []byte("bar")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } + db.MustCheck() - db.View(func(tx *bolt.Tx) error { + + if err := db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("whozawhats")) stats := b.Stats() - equals(t, 0, stats.BranchPageN) - equals(t, 0, stats.BranchOverflowN) - equals(t, 0, stats.LeafPageN) - equals(t, 0, stats.LeafOverflowN) - equals(t, 1, stats.KeyN) - equals(t, 1, stats.Depth) - equals(t, 0, stats.BranchInuse) - equals(t, 0, stats.LeafInuse) - if os.Getpagesize() == 4096 { - // Incompatible page size - equals(t, 0, stats.BranchAlloc) - equals(t, 0, stats.LeafAlloc) + if stats.BranchPageN != 0 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 0 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 1 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 1 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 0 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.LeafInuse != 0 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) } - equals(t, 1, stats.BucketN) - equals(t, 1, stats.InlineBucketN) - equals(t, 16+16+6, stats.InlineBucketInuse) + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 0 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 0 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 1 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 16+16+6 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } } func TestBucket_Stats_EmptyBucket(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { // Add a bucket that fits on a single root leaf. - _, err := tx.CreateBucket([]byte("whozawhats")) - ok(t, err) + if _, err := tx.CreateBucket([]byte("whozawhats")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } + db.MustCheck() - db.View(func(tx *bolt.Tx) error { + + if err := db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("whozawhats")) stats := b.Stats() - equals(t, 0, stats.BranchPageN) - equals(t, 0, stats.BranchOverflowN) - equals(t, 0, stats.LeafPageN) - equals(t, 0, stats.LeafOverflowN) - equals(t, 0, stats.KeyN) - equals(t, 1, stats.Depth) - equals(t, 0, stats.BranchInuse) - equals(t, 0, stats.LeafInuse) - if os.Getpagesize() == 4096 { - // Incompatible page size - equals(t, 0, stats.BranchAlloc) - equals(t, 0, stats.LeafAlloc) + if stats.BranchPageN != 0 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 0 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 0 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 1 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 0 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.LeafInuse != 0 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) } - equals(t, 1, stats.BucketN) - equals(t, 1, stats.InlineBucketN) - equals(t, 16, stats.InlineBucketInuse) + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 0 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 0 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 1 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 16 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure a bucket can calculate stats. func TestBucket_Stats_Nested(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucket([]byte("foo")) - ok(t, err) + if err != nil { + t.Fatal(err) + } for i := 0; i < 100; i++ { - b.Put([]byte(fmt.Sprintf("%02d", i)), []byte(fmt.Sprintf("%02d", i))) + if err := b.Put([]byte(fmt.Sprintf("%02d", i)), []byte(fmt.Sprintf("%02d", i))); err != nil { + t.Fatal(err) + } } + bar, err := b.CreateBucket([]byte("bar")) - ok(t, err) - for i := 0; i < 10; i++ { - bar.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) + if err != nil { + t.Fatal(err) } + for i := 0; i < 10; i++ { + if err := bar.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + } + baz, err := bar.CreateBucket([]byte("baz")) - ok(t, err) - for i := 0; i < 10; i++ { - baz.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) + if err != nil { + t.Fatal(err) } + for i := 0; i < 10; i++ { + if err := baz.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + t.Fatal(err) + } + } + return nil - }) + }); err != nil { + t.Fatal(err) + } db.MustCheck() - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("foo")) stats := b.Stats() - equals(t, 0, stats.BranchPageN) - equals(t, 0, stats.BranchOverflowN) - equals(t, 2, stats.LeafPageN) - equals(t, 0, stats.LeafOverflowN) - equals(t, 122, stats.KeyN) - equals(t, 3, stats.Depth) - equals(t, 0, stats.BranchInuse) + if stats.BranchPageN != 0 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 2 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 122 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 3 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 0 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } foo := 16 // foo (pghdr) foo += 101 * 16 // foo leaf elements @@ -860,17 +1406,30 @@ func TestBucket_Stats_Nested(t *testing.T) { baz += 10 * 16 // baz leaf elements baz += 10 + 10 // baz leaf key/values - equals(t, foo+bar+baz, stats.LeafInuse) - if os.Getpagesize() == 4096 { - // Incompatible page size - equals(t, 0, stats.BranchAlloc) - equals(t, 8192, stats.LeafAlloc) + if stats.LeafInuse != foo+bar+baz { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) } - equals(t, 3, stats.BucketN) - equals(t, 1, stats.InlineBucketN) - equals(t, baz, stats.InlineBucketInuse) + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 0 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 8192 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 3 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 1 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != baz { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure a large bucket can calculate stats. @@ -879,44 +1438,71 @@ func TestBucket_Stats_Large(t *testing.T) { t.Skip("skipping test in short mode.") } - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() var index int for i := 0; i < 100; i++ { - db.Update(func(tx *bolt.Tx) error { - // Add bucket with lots of keys. - b, _ := tx.CreateBucketIfNotExists([]byte("widgets")) + // Add bucket with lots of keys. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucketIfNotExists([]byte("widgets")) + if err != nil { + t.Fatal(err) + } for i := 0; i < 1000; i++ { - b.Put([]byte(strconv.Itoa(index)), []byte(strconv.Itoa(index))) + if err := b.Put([]byte(strconv.Itoa(index)), []byte(strconv.Itoa(index))); err != nil { + t.Fatal(err) + } index++ } return nil - }) + }); err != nil { + t.Fatal(err) + } } + db.MustCheck() - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - stats := b.Stats() - equals(t, 13, stats.BranchPageN) - equals(t, 0, stats.BranchOverflowN) - equals(t, 1196, stats.LeafPageN) - equals(t, 0, stats.LeafOverflowN) - equals(t, 100000, stats.KeyN) - equals(t, 3, stats.Depth) - equals(t, 25257, stats.BranchInuse) - equals(t, 2596916, stats.LeafInuse) - if os.Getpagesize() == 4096 { - // Incompatible page size - equals(t, 53248, stats.BranchAlloc) - equals(t, 4898816, stats.LeafAlloc) + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("widgets")).Stats() + if stats.BranchPageN != 13 { + t.Fatalf("unexpected BranchPageN: %d", stats.BranchPageN) + } else if stats.BranchOverflowN != 0 { + t.Fatalf("unexpected BranchOverflowN: %d", stats.BranchOverflowN) + } else if stats.LeafPageN != 1196 { + t.Fatalf("unexpected LeafPageN: %d", stats.LeafPageN) + } else if stats.LeafOverflowN != 0 { + t.Fatalf("unexpected LeafOverflowN: %d", stats.LeafOverflowN) + } else if stats.KeyN != 100000 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } else if stats.Depth != 3 { + t.Fatalf("unexpected Depth: %d", stats.Depth) + } else if stats.BranchInuse != 25257 { + t.Fatalf("unexpected BranchInuse: %d", stats.BranchInuse) + } else if stats.LeafInuse != 2596916 { + t.Fatalf("unexpected LeafInuse: %d", stats.LeafInuse) } - equals(t, 1, stats.BucketN) - equals(t, 0, stats.InlineBucketN) - equals(t, 0, stats.InlineBucketInuse) + + if os.Getpagesize() == 4096 { + if stats.BranchAlloc != 53248 { + t.Fatalf("unexpected BranchAlloc: %d", stats.BranchAlloc) + } else if stats.LeafAlloc != 4898816 { + t.Fatalf("unexpected LeafAlloc: %d", stats.LeafAlloc) + } + } + + if stats.BucketN != 1 { + t.Fatalf("unexpected BucketN: %d", stats.BucketN) + } else if stats.InlineBucketN != 0 { + t.Fatalf("unexpected InlineBucketN: %d", stats.InlineBucketN) + } else if stats.InlineBucketInuse != 0 { + t.Fatalf("unexpected InlineBucketInuse: %d", stats.InlineBucketInuse) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can write random keys and values across multiple transactions. @@ -926,27 +1512,34 @@ func TestBucket_Put_Single(t *testing.T) { } index := 0 - f := func(items testdata) bool { - db := NewTestDB() - defer db.Close() + if err := quick.Check(func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() m := make(map[string][]byte) - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - return err - }) + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + for _, item := range items { - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { if err := tx.Bucket([]byte("widgets")).Put(item.Key, item.Value); err != nil { panic("put error: " + err.Error()) } m[string(item.Key)] = item.Value return nil - }) + }); err != nil { + t.Fatal(err) + } // Verify all key/values so far. - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { i := 0 for k, v := range m { value := tx.Bucket([]byte("widgets")).Get([]byte(k)) @@ -958,13 +1551,14 @@ func TestBucket_Put_Single(t *testing.T) { i++ } return nil - }) + }); err != nil { + t.Fatal(err) + } } index++ return true - } - if err := quick.Check(f, qconfig()); err != nil { + }, nil); err != nil { t.Error(err) } } @@ -975,25 +1569,34 @@ func TestBucket_Put_Multiple(t *testing.T) { t.Skip("skipping test in short mode.") } - f := func(items testdata) bool { - db := NewTestDB() - defer db.Close() + if err := quick.Check(func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() + // Bulk insert all values. - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - return err - }) - err := db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - for _, item := range items { - ok(t, b.Put(item.Key, item.Value)) + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) } return nil - }) - ok(t, err) + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } // Verify all items exist. - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("widgets")) for _, item := range items { value := b.Get(item.Key) @@ -1003,10 +1606,12 @@ func TestBucket_Put_Multiple(t *testing.T) { } } return nil - }) + }); err != nil { + t.Fatal(err) + } + return true - } - if err := quick.Check(f, qconfig()); err != nil { + }, qconfig()); err != nil { t.Error(err) } } @@ -1017,68 +1622,98 @@ func TestBucket_Delete_Quick(t *testing.T) { t.Skip("skipping test in short mode.") } - f := func(items testdata) bool { - db := NewTestDB() - defer db.Close() + if err := quick.Check(func(items testdata) bool { + db := MustOpenDB() + defer db.MustClose() + // Bulk insert all values. - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - return err - }) - err := db.Update(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - for _, item := range items { - ok(t, b.Put(item.Key, item.Value)) + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) } return nil - }) - ok(t, err) + }); err != nil { + t.Fatal(err) + } + + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + return nil + }); err != nil { + t.Fatal(err) + } // Remove items one at a time and check consistency. for _, item := range items { - err := db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { return tx.Bucket([]byte("widgets")).Delete(item.Key) - }) - ok(t, err) + }); err != nil { + t.Fatal(err) + } } // Anything before our deletion index should be nil. - db.View(func(tx *bolt.Tx) error { - tx.Bucket([]byte("widgets")).ForEach(func(k, v []byte) error { + if err := db.View(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("widgets")).ForEach(func(k, v []byte) error { t.Fatalf("bucket should be empty; found: %06x", trunc(k, 3)) return nil - }) + }); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } + return true - } - if err := quick.Check(f, qconfig()); err != nil { + }, qconfig()); err != nil { t.Error(err) } } func ExampleBucket_Put() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Start a write transaction. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { // Create a bucket. - tx.CreateBucket([]byte("widgets")) + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + return err + } // Set the value "bar" for the key "foo". - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + return err + } return nil - }) + }); err != nil { + log.Fatal(err) + } // Read value back in a different read-only transaction. - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) fmt.Printf("The value of 'foo' is: %s\n", value) return nil - }) + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // The value of 'foo' is: bar @@ -1086,38 +1721,56 @@ func ExampleBucket_Put() { func ExampleBucket_Delete() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Start a write transaction. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { // Create a bucket. - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + return err + } // Set the value "bar" for the key "foo". - b.Put([]byte("foo"), []byte("bar")) + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + return err + } // Retrieve the key back from the database and verify it. value := b.Get([]byte("foo")) fmt.Printf("The value of 'foo' was: %s\n", value) + return nil - }) + }); err != nil { + log.Fatal(err) + } // Delete the key in a different write transaction. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { return tx.Bucket([]byte("widgets")).Delete([]byte("foo")) - }) + }); err != nil { + log.Fatal(err) + } // Retrieve the key again. - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) if value == nil { fmt.Printf("The value of 'foo' is now: nil\n") } return nil - }) + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // The value of 'foo' was: bar @@ -1126,25 +1779,46 @@ func ExampleBucket_Delete() { func ExampleBucket_ForEach() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Insert data into a bucket. - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("animals")) - b := tx.Bucket([]byte("animals")) - b.Put([]byte("dog"), []byte("fun")) - b.Put([]byte("cat"), []byte("lame")) - b.Put([]byte("liger"), []byte("awesome")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("animals")) + if err != nil { + return err + } + + if err := b.Put([]byte("dog"), []byte("fun")); err != nil { + return err + } + if err := b.Put([]byte("cat"), []byte("lame")); err != nil { + return err + } + if err := b.Put([]byte("liger"), []byte("awesome")); err != nil { + return err + } // Iterate over items in sorted key order. - b.ForEach(func(k, v []byte) error { + if err := b.ForEach(func(k, v []byte) error { fmt.Printf("A %s is %s.\n", k, v) return nil - }) + }); err != nil { + return err + } + return nil - }) + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // A cat is lame. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/bench.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/bench.go deleted file mode 100644 index 80901ab6..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/bench.go +++ /dev/null @@ -1,421 +0,0 @@ -package main - -import ( - "encoding/binary" - "encoding/json" - "errors" - "fmt" - "io/ioutil" - "math/rand" - "os" - "runtime" - "runtime/pprof" - "time" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// File handlers for the various profiles. -var cpuprofile, memprofile, blockprofile *os.File - -var benchBucketName = []byte("bench") - -// Bench executes a customizable, synthetic benchmark against Bolt. -func Bench(options *BenchOptions) { - var results BenchResults - - // Validate options. - if options.BatchSize == 0 { - options.BatchSize = options.Iterations - } else if options.Iterations%options.BatchSize != 0 { - fatal("number of iterations must be divisible by the batch size") - } - - // Find temporary location. - path := tempfile() - - if options.Clean { - defer os.Remove(path) - } else { - println("work:", path) - } - - // Create database. - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - db.NoSync = options.NoSync - defer db.Close() - - // Enable streaming stats. - if options.StatsInterval > 0 { - go printStats(db, options.StatsInterval) - } - - // Start profiling for writes. - if options.ProfileMode == "rw" || options.ProfileMode == "w" { - benchStartProfiling(options) - } - - // Write to the database. - if err := benchWrite(db, options, &results); err != nil { - fatal("bench: write: ", err) - } - - // Stop profiling for writes only. - if options.ProfileMode == "w" { - benchStopProfiling() - } - - // Start profiling for reads. - if options.ProfileMode == "r" { - benchStartProfiling(options) - } - - // Read from the database. - if err := benchRead(db, options, &results); err != nil { - fatal("bench: read: ", err) - } - - // Stop profiling for writes only. - if options.ProfileMode == "rw" || options.ProfileMode == "r" { - benchStopProfiling() - } - - // Print results. - fmt.Fprintf(os.Stderr, "# Write\t%v\t(%v/op)\t(%v op/sec)\n", results.WriteDuration, results.WriteOpDuration(), results.WriteOpsPerSecond()) - fmt.Fprintf(os.Stderr, "# Read\t%v\t(%v/op)\t(%v op/sec)\n", results.ReadDuration, results.ReadOpDuration(), results.ReadOpsPerSecond()) - fmt.Fprintln(os.Stderr, "") -} - -// Writes to the database. -func benchWrite(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - var err error - var t = time.Now() - - switch options.WriteMode { - case "seq": - err = benchWriteSequential(db, options, results) - case "rnd": - err = benchWriteRandom(db, options, results) - case "seq-nest": - err = benchWriteSequentialNested(db, options, results) - case "rnd-nest": - err = benchWriteRandomNested(db, options, results) - default: - return fmt.Errorf("invalid write mode: %s", options.WriteMode) - } - - results.WriteDuration = time.Since(t) - - return err -} - -func benchWriteSequential(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - var i = uint32(0) - return benchWriteWithSource(db, options, results, func() uint32 { i++; return i }) -} - -func benchWriteRandom(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - return benchWriteWithSource(db, options, results, func() uint32 { return r.Uint32() }) -} - -func benchWriteSequentialNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - var i = uint32(0) - return benchWriteNestedWithSource(db, options, results, func() uint32 { i++; return i }) -} - -func benchWriteRandomNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - r := rand.New(rand.NewSource(time.Now().UnixNano())) - return benchWriteNestedWithSource(db, options, results, func() uint32 { return r.Uint32() }) -} - -func benchWriteWithSource(db *bolt.DB, options *BenchOptions, results *BenchResults, keySource func() uint32) error { - results.WriteOps = options.Iterations - - for i := 0; i < options.Iterations; i += options.BatchSize { - err := db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucketIfNotExists(benchBucketName) - b.FillPercent = options.FillPercent - - for j := 0; j < options.BatchSize; j++ { - var key = make([]byte, options.KeySize) - var value = make([]byte, options.ValueSize) - binary.BigEndian.PutUint32(key, keySource()) - if err := b.Put(key, value); err != nil { - return err - } - } - - return nil - }) - if err != nil { - return err - } - } - return nil -} - -func benchWriteNestedWithSource(db *bolt.DB, options *BenchOptions, results *BenchResults, keySource func() uint32) error { - results.WriteOps = options.Iterations - - for i := 0; i < options.Iterations; i += options.BatchSize { - err := db.Update(func(tx *bolt.Tx) error { - top, _ := tx.CreateBucketIfNotExists(benchBucketName) - top.FillPercent = options.FillPercent - - var name = make([]byte, options.KeySize) - binary.BigEndian.PutUint32(name, keySource()) - b, _ := top.CreateBucketIfNotExists(name) - b.FillPercent = options.FillPercent - - for j := 0; j < options.BatchSize; j++ { - var key = make([]byte, options.KeySize) - var value = make([]byte, options.ValueSize) - binary.BigEndian.PutUint32(key, keySource()) - if err := b.Put(key, value); err != nil { - return err - } - } - - return nil - }) - if err != nil { - return err - } - } - return nil -} - -// Reads from the database. -func benchRead(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - var err error - var t = time.Now() - - switch options.ReadMode { - case "seq": - if options.WriteMode == "seq-nest" || options.WriteMode == "rnd-nest" { - err = benchReadSequentialNested(db, options, results) - } else { - err = benchReadSequential(db, options, results) - } - default: - return fmt.Errorf("invalid read mode: %s", options.ReadMode) - } - - results.ReadDuration = time.Since(t) - - return err -} - -func benchReadSequential(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - return db.View(func(tx *bolt.Tx) error { - var t = time.Now() - - for { - c := tx.Bucket(benchBucketName).Cursor() - var count int - for k, v := c.First(); k != nil; k, v = c.Next() { - if v == nil { - return errors.New("invalid value") - } - count++ - } - - if options.WriteMode == "seq" && count != options.Iterations { - return fmt.Errorf("read seq: iter mismatch: expected %d, got %d", options.Iterations, count) - } - - results.ReadOps += count - - // Make sure we do this for at least a second. - if time.Since(t) >= time.Second { - break - } - } - - return nil - }) -} - -func benchReadSequentialNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { - return db.View(func(tx *bolt.Tx) error { - var t = time.Now() - - for { - var count int - var top = tx.Bucket(benchBucketName) - top.ForEach(func(name, _ []byte) error { - c := top.Bucket(name).Cursor() - for k, v := c.First(); k != nil; k, v = c.Next() { - if v == nil { - return errors.New("invalid value") - } - count++ - } - return nil - }) - - if options.WriteMode == "seq-nest" && count != options.Iterations { - return fmt.Errorf("read seq-nest: iter mismatch: expected %d, got %d", options.Iterations, count) - } - - results.ReadOps += count - - // Make sure we do this for at least a second. - if time.Since(t) >= time.Second { - break - } - } - - return nil - }) -} - -// Starts all profiles set on the options. -func benchStartProfiling(options *BenchOptions) { - var err error - - // Start CPU profiling. - if options.CPUProfile != "" { - cpuprofile, err = os.Create(options.CPUProfile) - if err != nil { - fatalf("bench: could not create cpu profile %q: %v", options.CPUProfile, err) - } - pprof.StartCPUProfile(cpuprofile) - } - - // Start memory profiling. - if options.MemProfile != "" { - memprofile, err = os.Create(options.MemProfile) - if err != nil { - fatalf("bench: could not create memory profile %q: %v", options.MemProfile, err) - } - runtime.MemProfileRate = 4096 - } - - // Start fatal profiling. - if options.BlockProfile != "" { - blockprofile, err = os.Create(options.BlockProfile) - if err != nil { - fatalf("bench: could not create block profile %q: %v", options.BlockProfile, err) - } - runtime.SetBlockProfileRate(1) - } -} - -// Stops all profiles. -func benchStopProfiling() { - if cpuprofile != nil { - pprof.StopCPUProfile() - cpuprofile.Close() - cpuprofile = nil - } - - if memprofile != nil { - pprof.Lookup("heap").WriteTo(memprofile, 0) - memprofile.Close() - memprofile = nil - } - - if blockprofile != nil { - pprof.Lookup("block").WriteTo(blockprofile, 0) - blockprofile.Close() - blockprofile = nil - runtime.SetBlockProfileRate(0) - } -} - -// Continuously prints stats on the database at given intervals. -func printStats(db *bolt.DB, interval time.Duration) { - var prevStats = db.Stats() - var encoder = json.NewEncoder(os.Stdout) - - for { - // Wait for the stats interval. - time.Sleep(interval) - - // Retrieve new stats and find difference from previous iteration. - var stats = db.Stats() - var diff = stats.Sub(&prevStats) - - // Print as JSON to STDOUT. - if err := encoder.Encode(diff); err != nil { - fatal(err) - } - - // Save stats for next iteration. - prevStats = stats - } -} - -// BenchOptions represents the set of options that can be passed to Bench(). -type BenchOptions struct { - ProfileMode string - WriteMode string - ReadMode string - Iterations int - BatchSize int - KeySize int - ValueSize int - CPUProfile string - MemProfile string - BlockProfile string - StatsInterval time.Duration - FillPercent float64 - NoSync bool - Clean bool -} - -// BenchResults represents the performance results of the benchmark. -type BenchResults struct { - WriteOps int - WriteDuration time.Duration - ReadOps int - ReadDuration time.Duration -} - -// Returns the duration for a single write operation. -func (r *BenchResults) WriteOpDuration() time.Duration { - if r.WriteOps == 0 { - return 0 - } - return r.WriteDuration / time.Duration(r.WriteOps) -} - -// Returns average number of write operations that can be performed per second. -func (r *BenchResults) WriteOpsPerSecond() int { - var op = r.WriteOpDuration() - if op == 0 { - return 0 - } - return int(time.Second) / int(op) -} - -// Returns the duration for a single read operation. -func (r *BenchResults) ReadOpDuration() time.Duration { - if r.ReadOps == 0 { - return 0 - } - return r.ReadDuration / time.Duration(r.ReadOps) -} - -// Returns average number of read operations that can be performed per second. -func (r *BenchResults) ReadOpsPerSecond() int { - var op = r.ReadOpDuration() - if op == 0 { - return 0 - } - return int(time.Second) / int(op) -} - -// tempfile returns a temporary file path. -func tempfile() string { - f, _ := ioutil.TempFile("", "bolt-bench-") - f.Close() - os.Remove(f.Name()) - return f.Name() -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/buckets.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/buckets.go deleted file mode 100644 index 71acabd8..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/buckets.go +++ /dev/null @@ -1,33 +0,0 @@ -package main - -import ( - "os" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// Buckets prints a list of all buckets. -func Buckets(path string) { - if _, err := os.Stat(path); os.IsNotExist(err) { - fatal(err) - return - } - - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - defer db.Close() - - err = db.View(func(tx *bolt.Tx) error { - return tx.ForEach(func(name []byte, _ *bolt.Bucket) error { - println(string(name)) - return nil - }) - }) - if err != nil { - fatal(err) - return - } -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/buckets_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/buckets_test.go deleted file mode 100644 index ff099254..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/buckets_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package main_test - -import ( - "testing" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt" -) - -// Ensure that a list of buckets can be retrieved. -func TestBuckets(t *testing.T) { - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("woojits")) - tx.CreateBucket([]byte("widgets")) - tx.CreateBucket([]byte("whatchits")) - return nil - }) - db.Close() - output := run("buckets", path) - equals(t, "whatchits\nwidgets\nwoojits", output) - }) -} - -// Ensure that an error is reported if the database is not found. -func TestBucketsDBNotFound(t *testing.T) { - SetTestMode(true) - output := run("buckets", "no/such/db") - equals(t, "stat no/such/db: no such file or directory", output) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/check.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/check.go deleted file mode 100644 index 96975023..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/check.go +++ /dev/null @@ -1,47 +0,0 @@ -package main - -import ( - "os" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// Check performs a consistency check on the database and prints any errors found. -func Check(path string) { - if _, err := os.Stat(path); os.IsNotExist(err) { - fatal(err) - return - } - - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - defer db.Close() - - // Perform consistency check. - _ = db.View(func(tx *bolt.Tx) error { - var count int - ch := tx.Check() - loop: - for { - select { - case err, ok := <-ch: - if !ok { - break loop - } - println(err) - count++ - } - } - - // Print summary of errors. - if count > 0 { - fatalf("%d errors found", count) - } else { - println("OK") - } - return nil - }) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/get.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/get.go deleted file mode 100644 index 464fa167..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/get.go +++ /dev/null @@ -1,45 +0,0 @@ -package main - -import ( - "os" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// Get retrieves the value for a given bucket/key. -func Get(path, name, key string) { - if _, err := os.Stat(path); os.IsNotExist(err) { - fatal(err) - return - } - - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - defer db.Close() - - err = db.View(func(tx *bolt.Tx) error { - // Find bucket. - b := tx.Bucket([]byte(name)) - if b == nil { - fatalf("bucket not found: %s", name) - return nil - } - - // Find value for a given key. - value := b.Get([]byte(key)) - if value == nil { - fatalf("key not found: %s", key) - return nil - } - - println(string(value)) - return nil - }) - if err != nil { - fatal(err) - return - } -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/get_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/get_test.go deleted file mode 100644 index 1e09a237..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/get_test.go +++ /dev/null @@ -1,54 +0,0 @@ -package main_test - -import ( - "testing" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt" -) - -// Ensure that a value can be retrieved from the CLI. -func TestGet(t *testing.T) { - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - return nil - }) - db.Close() - output := run("get", path, "widgets", "foo") - equals(t, "bar", output) - }) -} - -// Ensure that an error is reported if the database is not found. -func TestGetDBNotFound(t *testing.T) { - SetTestMode(true) - output := run("get", "no/such/db", "widgets", "foo") - equals(t, "stat no/such/db: no such file or directory", output) -} - -// Ensure that an error is reported if the bucket is not found. -func TestGetBucketNotFound(t *testing.T) { - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Close() - output := run("get", path, "widgets", "foo") - equals(t, "bucket not found: widgets", output) - }) -} - -// Ensure that an error is reported if the key is not found. -func TestGetKeyNotFound(t *testing.T) { - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - return err - }) - db.Close() - output := run("get", path, "widgets", "foo") - equals(t, "key not found: foo", output) - }) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/info.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/info.go deleted file mode 100644 index eb07f7f2..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/info.go +++ /dev/null @@ -1,26 +0,0 @@ -package main - -import ( - "os" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// Info prints basic information about a database. -func Info(path string) { - if _, err := os.Stat(path); os.IsNotExist(err) { - fatal(err) - return - } - - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - defer db.Close() - - // Print basic database info. - var info = db.Info() - printf("Page Size: %d\n", info.PageSize) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/info_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/info_test.go deleted file mode 100644 index 87d2664f..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/info_test.go +++ /dev/null @@ -1,31 +0,0 @@ -package main_test - -import ( - "testing" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt" -) - -func // Ensure that a database info can be printed. -TestInfo(t *testing.T) { - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - b.Put([]byte("foo"), []byte("0000")) - return nil - }) - db.Close() - output := run("info", path) - equals(t, `Page Size: 4096`, output) - }) -} - -// Ensure that an error is reported if the database is not found. -func TestInfo_NotFound(t *testing.T) { - SetTestMode(true) - output := run("info", "no/such/db") - equals(t, "stat no/such/db: no such file or directory", output) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/keys.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/keys.go deleted file mode 100644 index f9b2a587..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/keys.go +++ /dev/null @@ -1,41 +0,0 @@ -package main - -import ( - "os" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// Keys retrieves a list of keys for a given bucket. -func Keys(path, name string) { - if _, err := os.Stat(path); os.IsNotExist(err) { - fatal(err) - return - } - - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - defer db.Close() - - err = db.View(func(tx *bolt.Tx) error { - // Find bucket. - b := tx.Bucket([]byte(name)) - if b == nil { - fatalf("bucket not found: %s", name) - return nil - } - - // Iterate over each key. - return b.ForEach(func(key, _ []byte) error { - println(string(key)) - return nil - }) - }) - if err != nil { - fatal(err) - return - } -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/keys_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/keys_test.go deleted file mode 100644 index 17a2ad65..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/keys_test.go +++ /dev/null @@ -1,42 +0,0 @@ -package main_test - -import ( - "testing" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt" -) - -// Ensure that a list of keys can be retrieved for a given bucket. -func TestKeys(t *testing.T) { - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("0002"), []byte("")) - tx.Bucket([]byte("widgets")).Put([]byte("0001"), []byte("")) - tx.Bucket([]byte("widgets")).Put([]byte("0003"), []byte("")) - return nil - }) - db.Close() - output := run("keys", path, "widgets") - equals(t, "0001\n0002\n0003", output) - }) -} - -// Ensure that an error is reported if the database is not found. -func TestKeysDBNotFound(t *testing.T) { - SetTestMode(true) - output := run("keys", "no/such/db", "widgets") - equals(t, "stat no/such/db: no such file or directory", output) -} - -// Ensure that an error is reported if the bucket is not found. -func TestKeysBucketNotFound(t *testing.T) { - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Close() - output := run("keys", path, "widgets") - equals(t, "bucket not found: widgets", output) - }) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main.go index 71068451..9c46a3b7 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main.go @@ -2,197 +2,1531 @@ package main import ( "bytes" + "encoding/binary" + "errors" + "flag" "fmt" - "log" + "io" + "io/ioutil" + "math/rand" "os" + "runtime" + "runtime/pprof" + "strconv" + "strings" "time" + "unicode" + "unicode/utf8" + "unsafe" - "github.com/codegangsta/cli" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" ) -var branch, commit string +var ( + // ErrUsage is returned when a usage message was printed and the process + // should simply exit with an error. + ErrUsage = errors.New("usage") + + // ErrUnknownCommand is returned when a CLI command is not specified. + ErrUnknownCommand = errors.New("unknown command") + + // ErrPathRequired is returned when the path to a Bolt database is not specified. + ErrPathRequired = errors.New("path required") + + // ErrFileNotFound is returned when a Bolt database does not exist. + ErrFileNotFound = errors.New("file not found") + + // ErrInvalidValue is returned when a benchmark reads an unexpected value. + ErrInvalidValue = errors.New("invalid value") + + // ErrCorrupt is returned when a checking a data file finds errors. + ErrCorrupt = errors.New("invalid value") + + // ErrNonDivisibleBatchSize is returned when the batch size can't be evenly + // divided by the iteration count. + ErrNonDivisibleBatchSize = errors.New("number of iterations must be divisible by the batch size") + + // ErrPageIDRequired is returned when a required page id is not specified. + ErrPageIDRequired = errors.New("page id required") + + // ErrPageNotFound is returned when specifying a page above the high water mark. + ErrPageNotFound = errors.New("page not found") + + // ErrPageFreed is returned when reading a page that has already been freed. + ErrPageFreed = errors.New("page freed") +) + +// PageHeaderSize represents the size of the bolt.page header. +const PageHeaderSize = 16 func main() { - log.SetFlags(0) - NewApp().Run(os.Args) + m := NewMain() + if err := m.Run(os.Args[1:]...); err == ErrUsage { + os.Exit(2) + } else if err != nil { + fmt.Println(err.Error()) + os.Exit(1) + } } -// NewApp creates an Application instance. -func NewApp() *cli.App { - app := cli.NewApp() - app.Name = "bolt" - app.Usage = "BoltDB toolkit" - app.Version = fmt.Sprintf("0.1.0 (%s %s)", branch, commit) - app.Commands = []cli.Command{ - { - Name: "info", - Usage: "Print basic information about a database", - Action: func(c *cli.Context) { - path := c.Args().Get(0) - Info(path) - }, - }, - { - Name: "get", - Usage: "Retrieve a value for given key in a bucket", - Action: func(c *cli.Context) { - path, name, key := c.Args().Get(0), c.Args().Get(1), c.Args().Get(2) - Get(path, name, key) - }, - }, - { - Name: "keys", - Usage: "Retrieve a list of all keys in a bucket", - Action: func(c *cli.Context) { - path, name := c.Args().Get(0), c.Args().Get(1) - Keys(path, name) - }, - }, - { - Name: "buckets", - Usage: "Retrieves a list of all buckets", - Action: func(c *cli.Context) { - path := c.Args().Get(0) - Buckets(path) - }, - }, - { - Name: "pages", - Usage: "Dumps page information for a database", - Action: func(c *cli.Context) { - path := c.Args().Get(0) - Pages(path) - }, - }, - { - Name: "check", - Usage: "Performs a consistency check on the database", - Action: func(c *cli.Context) { - path := c.Args().Get(0) - Check(path) - }, - }, - { - Name: "stats", - Usage: "Aggregate statistics for all buckets matching specified prefix", - Action: func(c *cli.Context) { - path, name := c.Args().Get(0), c.Args().Get(1) - Stats(path, name) - }, - }, - { - Name: "bench", - Usage: "Performs a synthetic benchmark", - Flags: []cli.Flag{ - &cli.StringFlag{Name: "profile-mode", Value: "rw", Usage: "Profile mode"}, - &cli.StringFlag{Name: "write-mode", Value: "seq", Usage: "Write mode"}, - &cli.StringFlag{Name: "read-mode", Value: "seq", Usage: "Read mode"}, - &cli.IntFlag{Name: "count", Value: 1000, Usage: "Item count"}, - &cli.IntFlag{Name: "batch-size", Usage: "Write batch size"}, - &cli.IntFlag{Name: "key-size", Value: 8, Usage: "Key size"}, - &cli.IntFlag{Name: "value-size", Value: 32, Usage: "Value size"}, - &cli.StringFlag{Name: "cpuprofile", Usage: "CPU profile output path"}, - &cli.StringFlag{Name: "memprofile", Usage: "Memory profile output path"}, - &cli.StringFlag{Name: "blockprofile", Usage: "Block profile output path"}, - &cli.StringFlag{Name: "stats-interval", Value: "0s", Usage: "Continuous stats interval"}, - &cli.Float64Flag{Name: "fill-percent", Value: bolt.DefaultFillPercent, Usage: "Fill percentage"}, - &cli.BoolFlag{Name: "no-sync", Usage: "Skip fsync on every commit"}, - &cli.BoolFlag{Name: "work", Usage: "Print the temp db and do not delete on exit"}, - }, - Action: func(c *cli.Context) { - statsInterval, err := time.ParseDuration(c.String("stats-interval")) - if err != nil { - fatal(err) +// Main represents the main program execution. +type Main struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewMain returns a new instance of Main connect to the standard input/output. +func NewMain() *Main { + return &Main{ + Stdin: os.Stdin, + Stdout: os.Stdout, + Stderr: os.Stderr, + } +} + +// Run executes the program. +func (m *Main) Run(args ...string) error { + // Require a command at the beginning. + if len(args) == 0 || strings.HasPrefix(args[0], "-") { + fmt.Fprintln(m.Stderr, m.Usage()) + return ErrUsage + } + + // Execute command. + switch args[0] { + case "help": + fmt.Fprintln(m.Stderr, m.Usage()) + return ErrUsage + case "bench": + return newBenchCommand(m).Run(args[1:]...) + case "check": + return newCheckCommand(m).Run(args[1:]...) + case "dump": + return newDumpCommand(m).Run(args[1:]...) + case "info": + return newInfoCommand(m).Run(args[1:]...) + case "page": + return newPageCommand(m).Run(args[1:]...) + case "pages": + return newPagesCommand(m).Run(args[1:]...) + case "stats": + return newStatsCommand(m).Run(args[1:]...) + default: + return ErrUnknownCommand + } +} + +// Usage returns the help message. +func (m *Main) Usage() string { + return strings.TrimLeft(` +Bolt is a tool for inspecting bolt databases. + +Usage: + + bolt command [arguments] + +The commands are: + + bench run synthetic benchmark against bolt + check verifies integrity of bolt database + info print basic info + help print this screen + pages print list of pages with their types + stats iterate over all pages and generate usage stats + +Use "bolt [command] -h" for more information about a command. +`, "\n") +} + +// CheckCommand represents the "check" command execution. +type CheckCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewCheckCommand returns a CheckCommand. +func newCheckCommand(m *Main) *CheckCommand { + return &CheckCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *CheckCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer db.Close() + + // Perform consistency check. + return db.View(func(tx *bolt.Tx) error { + var count int + ch := tx.Check() + loop: + for { + select { + case err, ok := <-ch: + if !ok { + break loop } + fmt.Fprintln(cmd.Stdout, err) + count++ + } + } - Bench(&BenchOptions{ - ProfileMode: c.String("profile-mode"), - WriteMode: c.String("write-mode"), - ReadMode: c.String("read-mode"), - Iterations: c.Int("count"), - BatchSize: c.Int("batch-size"), - KeySize: c.Int("key-size"), - ValueSize: c.Int("value-size"), - CPUProfile: c.String("cpuprofile"), - MemProfile: c.String("memprofile"), - BlockProfile: c.String("blockprofile"), - StatsInterval: statsInterval, - FillPercent: c.Float64("fill-percent"), - NoSync: c.Bool("no-sync"), - Clean: !c.Bool("work"), - }) - }, - }} - return app + // Print summary of errors. + if count > 0 { + fmt.Fprintf(cmd.Stdout, "%d errors found\n", count) + return ErrCorrupt + } + + // Notify user that database is valid. + fmt.Fprintln(cmd.Stdout, "OK") + return nil + }) } -var logger = log.New(os.Stderr, "", 0) -var logBuffer *bytes.Buffer +// Usage returns the help message. +func (cmd *CheckCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt check PATH -func print(v ...interface{}) { - if testMode { - logger.Print(v...) +Check opens a database at PATH and runs an exhaustive check to verify that +all pages are accessible or are marked as freed. It also verifies that no +pages are double referenced. + +Verification errors will stream out as they are found and the process will +return after all pages have been checked. +`, "\n") +} + +// InfoCommand represents the "info" command execution. +type InfoCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewInfoCommand returns a InfoCommand. +func newInfoCommand(m *Main) *InfoCommand { + return &InfoCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *InfoCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open the database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer db.Close() + + // Print basic database info. + info := db.Info() + fmt.Fprintf(cmd.Stdout, "Page Size: %d\n", info.PageSize) + + return nil +} + +// Usage returns the help message. +func (cmd *InfoCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt info PATH + +Info prints basic information about the Bolt database at PATH. +`, "\n") +} + +// DumpCommand represents the "dump" command execution. +type DumpCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// newDumpCommand returns a DumpCommand. +func newDumpCommand(m *Main) *DumpCommand { + return &DumpCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *DumpCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path and page id. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Read page ids. + pageIDs, err := atois(fs.Args()[1:]) + if err != nil { + return err + } else if len(pageIDs) == 0 { + return ErrPageIDRequired + } + + // Open database to retrieve page size. + pageSize, err := ReadPageSize(path) + if err != nil { + return err + } + + // Open database file handler. + f, err := os.Open(path) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + // Print each page listed. + for i, pageID := range pageIDs { + // Print a separator. + if i > 0 { + fmt.Fprintln(cmd.Stdout, "===============================================") + } + + // Print page to stdout. + if err := cmd.PrintPage(cmd.Stdout, f, pageID, pageSize); err != nil { + return err + } + } + + return nil +} + +// PrintPage prints a given page as hexidecimal. +func (cmd *DumpCommand) PrintPage(w io.Writer, r io.ReaderAt, pageID int, pageSize int) error { + const bytesPerLineN = 16 + + // Read page into buffer. + buf := make([]byte, pageSize) + addr := pageID * pageSize + if n, err := r.ReadAt(buf, int64(addr)); err != nil { + return err + } else if n != pageSize { + return io.ErrUnexpectedEOF + } + + // Write out to writer in 16-byte lines. + var prev []byte + var skipped bool + for offset := 0; offset < pageSize; offset += bytesPerLineN { + // Retrieve current 16-byte line. + line := buf[offset : offset+bytesPerLineN] + isLastLine := (offset == (pageSize - bytesPerLineN)) + + // If it's the same as the previous line then print a skip. + if bytes.Equal(line, prev) && !isLastLine { + if !skipped { + fmt.Fprintf(w, "%07x *\n", addr+offset) + skipped = true + } + } else { + // Print line as hexadecimal in 2-byte groups. + fmt.Fprintf(w, "%07x %04x %04x %04x %04x %04x %04x %04x %04x\n", addr+offset, + line[0:2], line[2:4], line[4:6], line[6:8], + line[8:10], line[10:12], line[12:14], line[14:16], + ) + + skipped = false + } + + // Save the previous line. + prev = line + } + fmt.Fprint(w, "\n") + + return nil +} + +// Usage returns the help message. +func (cmd *DumpCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt dump -page PAGEID PATH + +Dump prints a hexidecimal dump of a single page. +`, "\n") +} + +// PageCommand represents the "page" command execution. +type PageCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// newPageCommand returns a PageCommand. +func newPageCommand(m *Main) *PageCommand { + return &PageCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *PageCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path and page id. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Read page ids. + pageIDs, err := atois(fs.Args()[1:]) + if err != nil { + return err + } else if len(pageIDs) == 0 { + return ErrPageIDRequired + } + + // Open database file handler. + f, err := os.Open(path) + if err != nil { + return err + } + defer func() { _ = f.Close() }() + + // Print each page listed. + for i, pageID := range pageIDs { + // Print a separator. + if i > 0 { + fmt.Fprintln(cmd.Stdout, "===============================================") + } + + // Retrieve page info and page size. + p, buf, err := ReadPage(path, pageID) + if err != nil { + return err + } + + // Print basic page info. + fmt.Fprintf(cmd.Stdout, "Page ID: %d\n", p.id) + fmt.Fprintf(cmd.Stdout, "Page Type: %s\n", p.Type()) + fmt.Fprintf(cmd.Stdout, "Total Size: %d bytes\n", len(buf)) + + // Print type-specific data. + switch p.Type() { + case "meta": + err = cmd.PrintMeta(cmd.Stdout, buf) + case "leaf": + err = cmd.PrintLeaf(cmd.Stdout, buf) + case "branch": + err = cmd.PrintBranch(cmd.Stdout, buf) + case "freelist": + err = cmd.PrintFreelist(cmd.Stdout, buf) + } + if err != nil { + return err + } + } + + return nil +} + +// PrintMeta prints the data from the meta page. +func (cmd *PageCommand) PrintMeta(w io.Writer, buf []byte) error { + m := (*meta)(unsafe.Pointer(&buf[PageHeaderSize])) + fmt.Fprintf(w, "Version: %d\n", m.version) + fmt.Fprintf(w, "Page Size: %d bytes\n", m.pageSize) + fmt.Fprintf(w, "Flags: %08x\n", m.flags) + fmt.Fprintf(w, "Root: \n", m.root.root) + fmt.Fprintf(w, "Freelist: \n", m.freelist) + fmt.Fprintf(w, "HWM: \n", m.pgid) + fmt.Fprintf(w, "Txn ID: %d\n", m.txid) + fmt.Fprintf(w, "Checksum: %016x\n", m.checksum) + fmt.Fprintf(w, "\n") + return nil +} + +// PrintLeaf prints the data for a leaf page. +func (cmd *PageCommand) PrintLeaf(w io.Writer, buf []byte) error { + p := (*page)(unsafe.Pointer(&buf[0])) + + // Print number of items. + fmt.Fprintf(w, "Item Count: %d\n", p.count) + fmt.Fprintf(w, "\n") + + // Print each key/value. + for i := uint16(0); i < p.count; i++ { + e := p.leafPageElement(i) + + // Format key as string. + var k string + if isPrintable(string(e.key())) { + k = fmt.Sprintf("%q", string(e.key())) + } else { + k = fmt.Sprintf("%x", string(e.key())) + } + + // Format value as string. + var v string + if (e.flags & uint32(bucketLeafFlag)) != 0 { + b := (*bucket)(unsafe.Pointer(&e.value()[0])) + v = fmt.Sprintf("", b.root, b.sequence) + } else if isPrintable(string(e.value())) { + k = fmt.Sprintf("%q", string(e.value())) + } else { + k = fmt.Sprintf("%x", string(e.value())) + } + + fmt.Fprintf(w, "%s: %s\n", k, v) + } + fmt.Fprintf(w, "\n") + return nil +} + +// PrintBranch prints the data for a leaf page. +func (cmd *PageCommand) PrintBranch(w io.Writer, buf []byte) error { + p := (*page)(unsafe.Pointer(&buf[0])) + + // Print number of items. + fmt.Fprintf(w, "Item Count: %d\n", p.count) + fmt.Fprintf(w, "\n") + + // Print each key/value. + for i := uint16(0); i < p.count; i++ { + e := p.branchPageElement(i) + + // Format key as string. + var k string + if isPrintable(string(e.key())) { + k = fmt.Sprintf("%q", string(e.key())) + } else { + k = fmt.Sprintf("%x", string(e.key())) + } + + fmt.Fprintf(w, "%s: \n", k, e.pgid) + } + fmt.Fprintf(w, "\n") + return nil +} + +// PrintFreelist prints the data for a freelist page. +func (cmd *PageCommand) PrintFreelist(w io.Writer, buf []byte) error { + p := (*page)(unsafe.Pointer(&buf[0])) + + // Print number of items. + fmt.Fprintf(w, "Item Count: %d\n", p.count) + fmt.Fprintf(w, "\n") + + // Print each page in the freelist. + ids := (*[maxAllocSize]pgid)(unsafe.Pointer(&p.ptr)) + for i := uint16(0); i < p.count; i++ { + fmt.Fprintf(w, "%d\n", ids[i]) + } + fmt.Fprintf(w, "\n") + return nil +} + +// PrintPage prints a given page as hexidecimal. +func (cmd *PageCommand) PrintPage(w io.Writer, r io.ReaderAt, pageID int, pageSize int) error { + const bytesPerLineN = 16 + + // Read page into buffer. + buf := make([]byte, pageSize) + addr := pageID * pageSize + if n, err := r.ReadAt(buf, int64(addr)); err != nil { + return err + } else if n != pageSize { + return io.ErrUnexpectedEOF + } + + // Write out to writer in 16-byte lines. + var prev []byte + var skipped bool + for offset := 0; offset < pageSize; offset += bytesPerLineN { + // Retrieve current 16-byte line. + line := buf[offset : offset+bytesPerLineN] + isLastLine := (offset == (pageSize - bytesPerLineN)) + + // If it's the same as the previous line then print a skip. + if bytes.Equal(line, prev) && !isLastLine { + if !skipped { + fmt.Fprintf(w, "%07x *\n", addr+offset) + skipped = true + } + } else { + // Print line as hexadecimal in 2-byte groups. + fmt.Fprintf(w, "%07x %04x %04x %04x %04x %04x %04x %04x %04x\n", addr+offset, + line[0:2], line[2:4], line[4:6], line[6:8], + line[8:10], line[10:12], line[12:14], line[14:16], + ) + + skipped = false + } + + // Save the previous line. + prev = line + } + fmt.Fprint(w, "\n") + + return nil +} + +// Usage returns the help message. +func (cmd *PageCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt page -page PATH pageid [pageid...] + +Page prints one or more pages in human readable format. +`, "\n") +} + +// PagesCommand represents the "pages" command execution. +type PagesCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewPagesCommand returns a PagesCommand. +func newPagesCommand(m *Main) *PagesCommand { + return &PagesCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *PagesCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path := fs.Arg(0) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer func() { _ = db.Close() }() + + // Write header. + fmt.Fprintln(cmd.Stdout, "ID TYPE ITEMS OVRFLW") + fmt.Fprintln(cmd.Stdout, "======== ========== ====== ======") + + return db.Update(func(tx *bolt.Tx) error { + var id int + for { + p, err := tx.Page(id) + if err != nil { + return &PageError{ID: id, Err: err} + } else if p == nil { + break + } + + // Only display count and overflow if this is a non-free page. + var count, overflow string + if p.Type != "free" { + count = strconv.Itoa(p.Count) + if p.OverflowCount > 0 { + overflow = strconv.Itoa(p.OverflowCount) + } + } + + // Print table row. + fmt.Fprintf(cmd.Stdout, "%-8d %-10s %-6s %-6s\n", p.ID, p.Type, count, overflow) + + // Move to the next non-overflow page. + id += 1 + if p.Type != "free" { + id += p.OverflowCount + } + } + return nil + }) +} + +// Usage returns the help message. +func (cmd *PagesCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt pages PATH + +Pages prints a table of pages with their type (meta, leaf, branch, freelist). +Leaf and branch pages will show a key count in the "items" column while the +freelist will show the number of free pages in the "items" column. + +The "overflow" column shows the number of blocks that the page spills over +into. Normally there is no overflow but large keys and values can cause +a single page to take up multiple blocks. +`, "\n") +} + +// StatsCommand represents the "stats" command execution. +type StatsCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewStatsCommand returns a StatsCommand. +func newStatsCommand(m *Main) *StatsCommand { + return &StatsCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the command. +func (cmd *StatsCommand) Run(args ...string) error { + // Parse flags. + fs := flag.NewFlagSet("", flag.ContinueOnError) + help := fs.Bool("h", false, "") + if err := fs.Parse(args); err != nil { + return err + } else if *help { + fmt.Fprintln(cmd.Stderr, cmd.Usage()) + return ErrUsage + } + + // Require database path. + path, prefix := fs.Arg(0), fs.Arg(1) + if path == "" { + return ErrPathRequired + } else if _, err := os.Stat(path); os.IsNotExist(err) { + return ErrFileNotFound + } + + // Open database. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + return err + } + defer db.Close() + + return db.View(func(tx *bolt.Tx) error { + var s bolt.BucketStats + var count int + if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { + if bytes.HasPrefix(name, []byte(prefix)) { + s.Add(b.Stats()) + count += 1 + } + return nil + }); err != nil { + return err + } + + fmt.Fprintf(cmd.Stdout, "Aggregate statistics for %d buckets\n\n", count) + + fmt.Fprintln(cmd.Stdout, "Page count statistics") + fmt.Fprintf(cmd.Stdout, "\tNumber of logical branch pages: %d\n", s.BranchPageN) + fmt.Fprintf(cmd.Stdout, "\tNumber of physical branch overflow pages: %d\n", s.BranchOverflowN) + fmt.Fprintf(cmd.Stdout, "\tNumber of logical leaf pages: %d\n", s.LeafPageN) + fmt.Fprintf(cmd.Stdout, "\tNumber of physical leaf overflow pages: %d\n", s.LeafOverflowN) + + fmt.Fprintln(cmd.Stdout, "Tree statistics") + fmt.Fprintf(cmd.Stdout, "\tNumber of keys/value pairs: %d\n", s.KeyN) + fmt.Fprintf(cmd.Stdout, "\tNumber of levels in B+tree: %d\n", s.Depth) + + fmt.Fprintln(cmd.Stdout, "Page size utilization") + fmt.Fprintf(cmd.Stdout, "\tBytes allocated for physical branch pages: %d\n", s.BranchAlloc) + var percentage int + if s.BranchAlloc != 0 { + percentage = int(float32(s.BranchInuse) * 100.0 / float32(s.BranchAlloc)) + } + fmt.Fprintf(cmd.Stdout, "\tBytes actually used for branch data: %d (%d%%)\n", s.BranchInuse, percentage) + fmt.Fprintf(cmd.Stdout, "\tBytes allocated for physical leaf pages: %d\n", s.LeafAlloc) + percentage = 0 + if s.LeafAlloc != 0 { + percentage = int(float32(s.LeafInuse) * 100.0 / float32(s.LeafAlloc)) + } + fmt.Fprintf(cmd.Stdout, "\tBytes actually used for leaf data: %d (%d%%)\n", s.LeafInuse, percentage) + + fmt.Fprintln(cmd.Stdout, "Bucket statistics") + fmt.Fprintf(cmd.Stdout, "\tTotal number of buckets: %d\n", s.BucketN) + percentage = 0 + if s.BucketN != 0 { + percentage = int(float32(s.InlineBucketN) * 100.0 / float32(s.BucketN)) + } + fmt.Fprintf(cmd.Stdout, "\tTotal number on inlined buckets: %d (%d%%)\n", s.InlineBucketN, percentage) + percentage = 0 + if s.LeafInuse != 0 { + percentage = int(float32(s.InlineBucketInuse) * 100.0 / float32(s.LeafInuse)) + } + fmt.Fprintf(cmd.Stdout, "\tBytes used for inlined buckets: %d (%d%%)\n", s.InlineBucketInuse, percentage) + + return nil + }) +} + +// Usage returns the help message. +func (cmd *StatsCommand) Usage() string { + return strings.TrimLeft(` +usage: bolt stats PATH + +Stats performs an extensive search of the database to track every page +reference. It starts at the current meta page and recursively iterates +through every accessible bucket. + +The following errors can be reported: + + already freed + The page is referenced more than once in the freelist. + + unreachable unfreed + The page is not referenced by a bucket or in the freelist. + + reachable freed + The page is referenced by a bucket but is also in the freelist. + + out of bounds + A page is referenced that is above the high water mark. + + multiple references + A page is referenced by more than one other page. + + invalid type + The page type is not "meta", "leaf", "branch", or "freelist". + +No errors should occur in your database. However, if for some reason you +experience corruption, please submit a ticket to the Bolt project page: + + https://github.com/boltdb/bolt/issues +`, "\n") +} + +var benchBucketName = []byte("bench") + +// BenchCommand represents the "bench" command execution. +type BenchCommand struct { + Stdin io.Reader + Stdout io.Writer + Stderr io.Writer +} + +// NewBenchCommand returns a BenchCommand using the +func newBenchCommand(m *Main) *BenchCommand { + return &BenchCommand{ + Stdin: m.Stdin, + Stdout: m.Stdout, + Stderr: m.Stderr, + } +} + +// Run executes the "bench" command. +func (cmd *BenchCommand) Run(args ...string) error { + // Parse CLI arguments. + options, err := cmd.ParseFlags(args) + if err != nil { + return err + } + + // Remove path if "-work" is not set. Otherwise keep path. + if options.Work { + fmt.Fprintf(cmd.Stdout, "work: %s\n", options.Path) } else { - fmt.Print(v...) + defer os.Remove(options.Path) + } + + // Create database. + db, err := bolt.Open(options.Path, 0666, nil) + if err != nil { + return err + } + db.NoSync = options.NoSync + defer db.Close() + + // Write to the database. + var results BenchResults + if err := cmd.runWrites(db, options, &results); err != nil { + return fmt.Errorf("write: %v", err) + } + + // Read from the database. + if err := cmd.runReads(db, options, &results); err != nil { + return fmt.Errorf("bench: read: %s", err) + } + + // Print results. + fmt.Fprintf(os.Stderr, "# Write\t%v\t(%v/op)\t(%v op/sec)\n", results.WriteDuration, results.WriteOpDuration(), results.WriteOpsPerSecond()) + fmt.Fprintf(os.Stderr, "# Read\t%v\t(%v/op)\t(%v op/sec)\n", results.ReadDuration, results.ReadOpDuration(), results.ReadOpsPerSecond()) + fmt.Fprintln(os.Stderr, "") + return nil +} + +// ParseFlags parses the command line flags. +func (cmd *BenchCommand) ParseFlags(args []string) (*BenchOptions, error) { + var options BenchOptions + + // Parse flagset. + fs := flag.NewFlagSet("", flag.ContinueOnError) + fs.StringVar(&options.ProfileMode, "profile-mode", "rw", "") + fs.StringVar(&options.WriteMode, "write-mode", "seq", "") + fs.StringVar(&options.ReadMode, "read-mode", "seq", "") + fs.IntVar(&options.Iterations, "count", 1000, "") + fs.IntVar(&options.BatchSize, "batch-size", 0, "") + fs.IntVar(&options.KeySize, "key-size", 8, "") + fs.IntVar(&options.ValueSize, "value-size", 32, "") + fs.StringVar(&options.CPUProfile, "cpuprofile", "", "") + fs.StringVar(&options.MemProfile, "memprofile", "", "") + fs.StringVar(&options.BlockProfile, "blockprofile", "", "") + fs.Float64Var(&options.FillPercent, "fill-percent", bolt.DefaultFillPercent, "") + fs.BoolVar(&options.NoSync, "no-sync", false, "") + fs.BoolVar(&options.Work, "work", false, "") + fs.StringVar(&options.Path, "path", "", "") + fs.SetOutput(cmd.Stderr) + if err := fs.Parse(args); err != nil { + return nil, err + } + + // Set batch size to iteration size if not set. + // Require that batch size can be evenly divided by the iteration count. + if options.BatchSize == 0 { + options.BatchSize = options.Iterations + } else if options.Iterations%options.BatchSize != 0 { + return nil, ErrNonDivisibleBatchSize + } + + // Generate temp path if one is not passed in. + if options.Path == "" { + f, err := ioutil.TempFile("", "bolt-bench-") + if err != nil { + return nil, fmt.Errorf("temp file: %s", err) + } + f.Close() + os.Remove(f.Name()) + options.Path = f.Name() + } + + return &options, nil +} + +// Writes to the database. +func (cmd *BenchCommand) runWrites(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + // Start profiling for writes. + if options.ProfileMode == "rw" || options.ProfileMode == "w" { + cmd.startProfiling(options) + } + + t := time.Now() + + var err error + switch options.WriteMode { + case "seq": + err = cmd.runWritesSequential(db, options, results) + case "rnd": + err = cmd.runWritesRandom(db, options, results) + case "seq-nest": + err = cmd.runWritesSequentialNested(db, options, results) + case "rnd-nest": + err = cmd.runWritesRandomNested(db, options, results) + default: + return fmt.Errorf("invalid write mode: %s", options.WriteMode) + } + + // Save time to write. + results.WriteDuration = time.Since(t) + + // Stop profiling for writes only. + if options.ProfileMode == "w" { + cmd.stopProfiling() + } + + return err +} + +func (cmd *BenchCommand) runWritesSequential(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + var i = uint32(0) + return cmd.runWritesWithSource(db, options, results, func() uint32 { i++; return i }) +} + +func (cmd *BenchCommand) runWritesRandom(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return cmd.runWritesWithSource(db, options, results, func() uint32 { return r.Uint32() }) +} + +func (cmd *BenchCommand) runWritesSequentialNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + var i = uint32(0) + return cmd.runWritesWithSource(db, options, results, func() uint32 { i++; return i }) +} + +func (cmd *BenchCommand) runWritesRandomNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + r := rand.New(rand.NewSource(time.Now().UnixNano())) + return cmd.runWritesWithSource(db, options, results, func() uint32 { return r.Uint32() }) +} + +func (cmd *BenchCommand) runWritesWithSource(db *bolt.DB, options *BenchOptions, results *BenchResults, keySource func() uint32) error { + results.WriteOps = options.Iterations + + for i := 0; i < options.Iterations; i += options.BatchSize { + if err := db.Update(func(tx *bolt.Tx) error { + b, _ := tx.CreateBucketIfNotExists(benchBucketName) + b.FillPercent = options.FillPercent + + for j := 0; j < options.BatchSize; j++ { + key := make([]byte, options.KeySize) + value := make([]byte, options.ValueSize) + + // Write key as uint32. + binary.BigEndian.PutUint32(key, keySource()) + + // Insert key/value. + if err := b.Put(key, value); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + } + return nil +} + +func (cmd *BenchCommand) runWritesNestedWithSource(db *bolt.DB, options *BenchOptions, results *BenchResults, keySource func() uint32) error { + results.WriteOps = options.Iterations + + for i := 0; i < options.Iterations; i += options.BatchSize { + if err := db.Update(func(tx *bolt.Tx) error { + top, err := tx.CreateBucketIfNotExists(benchBucketName) + if err != nil { + return err + } + top.FillPercent = options.FillPercent + + // Create bucket key. + name := make([]byte, options.KeySize) + binary.BigEndian.PutUint32(name, keySource()) + + // Create bucket. + b, err := top.CreateBucketIfNotExists(name) + if err != nil { + return err + } + b.FillPercent = options.FillPercent + + for j := 0; j < options.BatchSize; j++ { + var key = make([]byte, options.KeySize) + var value = make([]byte, options.ValueSize) + + // Generate key as uint32. + binary.BigEndian.PutUint32(key, keySource()) + + // Insert value into subbucket. + if err := b.Put(key, value); err != nil { + return err + } + } + + return nil + }); err != nil { + return err + } + } + return nil +} + +// Reads from the database. +func (cmd *BenchCommand) runReads(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + // Start profiling for reads. + if options.ProfileMode == "r" { + cmd.startProfiling(options) + } + + t := time.Now() + + var err error + switch options.ReadMode { + case "seq": + switch options.WriteMode { + case "seq-nest", "rnd-nest": + err = cmd.runReadsSequentialNested(db, options, results) + default: + err = cmd.runReadsSequential(db, options, results) + } + default: + return fmt.Errorf("invalid read mode: %s", options.ReadMode) + } + + // Save read time. + results.ReadDuration = time.Since(t) + + // Stop profiling for reads. + if options.ProfileMode == "rw" || options.ProfileMode == "r" { + cmd.stopProfiling() + } + + return err +} + +func (cmd *BenchCommand) runReadsSequential(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + return db.View(func(tx *bolt.Tx) error { + t := time.Now() + + for { + var count int + + c := tx.Bucket(benchBucketName).Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if v == nil { + return errors.New("invalid value") + } + count++ + } + + if options.WriteMode == "seq" && count != options.Iterations { + return fmt.Errorf("read seq: iter mismatch: expected %d, got %d", options.Iterations, count) + } + + results.ReadOps += count + + // Make sure we do this for at least a second. + if time.Since(t) >= time.Second { + break + } + } + + return nil + }) +} + +func (cmd *BenchCommand) runReadsSequentialNested(db *bolt.DB, options *BenchOptions, results *BenchResults) error { + return db.View(func(tx *bolt.Tx) error { + t := time.Now() + + for { + var count int + var top = tx.Bucket(benchBucketName) + if err := top.ForEach(func(name, _ []byte) error { + c := top.Bucket(name).Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + if v == nil { + return ErrInvalidValue + } + count++ + } + return nil + }); err != nil { + return err + } + + if options.WriteMode == "seq-nest" && count != options.Iterations { + return fmt.Errorf("read seq-nest: iter mismatch: expected %d, got %d", options.Iterations, count) + } + + results.ReadOps += count + + // Make sure we do this for at least a second. + if time.Since(t) >= time.Second { + break + } + } + + return nil + }) +} + +// File handlers for the various profiles. +var cpuprofile, memprofile, blockprofile *os.File + +// Starts all profiles set on the options. +func (cmd *BenchCommand) startProfiling(options *BenchOptions) { + var err error + + // Start CPU profiling. + if options.CPUProfile != "" { + cpuprofile, err = os.Create(options.CPUProfile) + if err != nil { + fmt.Fprintf(cmd.Stderr, "bench: could not create cpu profile %q: %v\n", options.CPUProfile, err) + os.Exit(1) + } + pprof.StartCPUProfile(cpuprofile) + } + + // Start memory profiling. + if options.MemProfile != "" { + memprofile, err = os.Create(options.MemProfile) + if err != nil { + fmt.Fprintf(cmd.Stderr, "bench: could not create memory profile %q: %v\n", options.MemProfile, err) + os.Exit(1) + } + runtime.MemProfileRate = 4096 + } + + // Start fatal profiling. + if options.BlockProfile != "" { + blockprofile, err = os.Create(options.BlockProfile) + if err != nil { + fmt.Fprintf(cmd.Stderr, "bench: could not create block profile %q: %v\n", options.BlockProfile, err) + os.Exit(1) + } + runtime.SetBlockProfileRate(1) } } -func printf(format string, v ...interface{}) { - if testMode { - logger.Printf(format, v...) - } else { - fmt.Printf(format, v...) +// Stops all profiles. +func (cmd *BenchCommand) stopProfiling() { + if cpuprofile != nil { + pprof.StopCPUProfile() + cpuprofile.Close() + cpuprofile = nil + } + + if memprofile != nil { + pprof.Lookup("heap").WriteTo(memprofile, 0) + memprofile.Close() + memprofile = nil + } + + if blockprofile != nil { + pprof.Lookup("block").WriteTo(blockprofile, 0) + blockprofile.Close() + blockprofile = nil + runtime.SetBlockProfileRate(0) } } -func println(v ...interface{}) { - if testMode { - logger.Println(v...) - } else { - fmt.Println(v...) - } +// BenchOptions represents the set of options that can be passed to "bolt bench". +type BenchOptions struct { + ProfileMode string + WriteMode string + ReadMode string + Iterations int + BatchSize int + KeySize int + ValueSize int + CPUProfile string + MemProfile string + BlockProfile string + StatsInterval time.Duration + FillPercent float64 + NoSync bool + Work bool + Path string } -func fatal(v ...interface{}) { - logger.Print(v...) - if !testMode { - os.Exit(1) - } +// BenchResults represents the performance results of the benchmark. +type BenchResults struct { + WriteOps int + WriteDuration time.Duration + ReadOps int + ReadDuration time.Duration } -func fatalf(format string, v ...interface{}) { - logger.Printf(format, v...) - if !testMode { - os.Exit(1) +// Returns the duration for a single write operation. +func (r *BenchResults) WriteOpDuration() time.Duration { + if r.WriteOps == 0 { + return 0 } + return r.WriteDuration / time.Duration(r.WriteOps) } -func fatalln(v ...interface{}) { - logger.Println(v...) - if !testMode { - os.Exit(1) +// Returns average number of write operations that can be performed per second. +func (r *BenchResults) WriteOpsPerSecond() int { + var op = r.WriteOpDuration() + if op == 0 { + return 0 } + return int(time.Second) / int(op) } -// LogBuffer returns the contents of the log. -// This only works while the CLI is in test mode. -func LogBuffer() string { - if logBuffer != nil { - return logBuffer.String() +// Returns the duration for a single read operation. +func (r *BenchResults) ReadOpDuration() time.Duration { + if r.ReadOps == 0 { + return 0 } - return "" + return r.ReadDuration / time.Duration(r.ReadOps) } -var testMode bool - -// SetTestMode sets whether the CLI is running in test mode and resets the logger. -func SetTestMode(value bool) { - testMode = value - if testMode { - logBuffer = bytes.NewBuffer(nil) - logger = log.New(logBuffer, "", 0) - } else { - logger = log.New(os.Stderr, "", 0) +// Returns average number of read operations that can be performed per second. +func (r *BenchResults) ReadOpsPerSecond() int { + var op = r.ReadOpDuration() + if op == 0 { + return 0 } + return int(time.Second) / int(op) +} + +type PageError struct { + ID int + Err error +} + +func (e *PageError) Error() string { + return fmt.Sprintf("page error: id=%d, err=%s", e.ID, e.Err) +} + +// isPrintable returns true if the string is valid unicode and contains only printable runes. +func isPrintable(s string) bool { + if !utf8.ValidString(s) { + return false + } + for _, ch := range s { + if !unicode.IsPrint(ch) { + return false + } + } + return true +} + +// ReadPage reads page info & full page data from a path. +// This is not transactionally safe. +func ReadPage(path string, pageID int) (*page, []byte, error) { + // Find page size. + pageSize, err := ReadPageSize(path) + if err != nil { + return nil, nil, fmt.Errorf("read page size: %s", err) + } + + // Open database file. + f, err := os.Open(path) + if err != nil { + return nil, nil, err + } + defer f.Close() + + // Read one block into buffer. + buf := make([]byte, pageSize) + if n, err := f.ReadAt(buf, int64(pageID*pageSize)); err != nil { + return nil, nil, err + } else if n != len(buf) { + return nil, nil, io.ErrUnexpectedEOF + } + + // Determine total number of blocks. + p := (*page)(unsafe.Pointer(&buf[0])) + overflowN := p.overflow + + // Re-read entire page (with overflow) into buffer. + buf = make([]byte, (int(overflowN)+1)*pageSize) + if n, err := f.ReadAt(buf, int64(pageID*pageSize)); err != nil { + return nil, nil, err + } else if n != len(buf) { + return nil, nil, io.ErrUnexpectedEOF + } + p = (*page)(unsafe.Pointer(&buf[0])) + + return p, buf, nil +} + +// ReadPageSize reads page size a path. +// This is not transactionally safe. +func ReadPageSize(path string) (int, error) { + // Open database file. + f, err := os.Open(path) + if err != nil { + return 0, err + } + defer f.Close() + + // Read 4KB chunk. + buf := make([]byte, 4096) + if _, err := io.ReadFull(f, buf); err != nil { + return 0, err + } + + // Read page size from metadata. + m := (*meta)(unsafe.Pointer(&buf[PageHeaderSize])) + return int(m.pageSize), nil +} + +// atois parses a slice of strings into integers. +func atois(strs []string) ([]int, error) { + var a []int + for _, str := range strs { + i, err := strconv.Atoi(str) + if err != nil { + return nil, err + } + a = append(a, i) + } + return a, nil +} + +// DO NOT EDIT. Copied from the "bolt" package. +const maxAllocSize = 0xFFFFFFF + +// DO NOT EDIT. Copied from the "bolt" package. +const ( + branchPageFlag = 0x01 + leafPageFlag = 0x02 + metaPageFlag = 0x04 + freelistPageFlag = 0x10 +) + +// DO NOT EDIT. Copied from the "bolt" package. +const bucketLeafFlag = 0x01 + +// DO NOT EDIT. Copied from the "bolt" package. +type pgid uint64 + +// DO NOT EDIT. Copied from the "bolt" package. +type txid uint64 + +// DO NOT EDIT. Copied from the "bolt" package. +type meta struct { + magic uint32 + version uint32 + pageSize uint32 + flags uint32 + root bucket + freelist pgid + pgid pgid + txid txid + checksum uint64 +} + +// DO NOT EDIT. Copied from the "bolt" package. +type bucket struct { + root pgid + sequence uint64 +} + +// DO NOT EDIT. Copied from the "bolt" package. +type page struct { + id pgid + flags uint16 + count uint16 + overflow uint32 + ptr uintptr +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (p *page) Type() string { + if (p.flags & branchPageFlag) != 0 { + return "branch" + } else if (p.flags & leafPageFlag) != 0 { + return "leaf" + } else if (p.flags & metaPageFlag) != 0 { + return "meta" + } else if (p.flags & freelistPageFlag) != 0 { + return "freelist" + } + return fmt.Sprintf("unknown<%02x>", p.flags) +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (p *page) leafPageElement(index uint16) *leafPageElement { + n := &((*[0x7FFFFFF]leafPageElement)(unsafe.Pointer(&p.ptr)))[index] + return n +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (p *page) branchPageElement(index uint16) *branchPageElement { + return &((*[0x7FFFFFF]branchPageElement)(unsafe.Pointer(&p.ptr)))[index] +} + +// DO NOT EDIT. Copied from the "bolt" package. +type branchPageElement struct { + pos uint32 + ksize uint32 + pgid pgid +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (n *branchPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return buf[n.pos : n.pos+n.ksize] +} + +// DO NOT EDIT. Copied from the "bolt" package. +type leafPageElement struct { + flags uint32 + pos uint32 + ksize uint32 + vsize uint32 +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (n *leafPageElement) key() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return buf[n.pos : n.pos+n.ksize] +} + +// DO NOT EDIT. Copied from the "bolt" package. +func (n *leafPageElement) value() []byte { + buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) + return buf[n.pos+n.ksize : n.pos+n.ksize+n.vsize] } diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main_test.go index d9684349..7ddf1cc2 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/main_test.go @@ -1,69 +1,185 @@ package main_test import ( - "fmt" + "bytes" "io/ioutil" "os" - "path/filepath" - "reflect" - "runtime" - "strings" + "strconv" "testing" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt" ) -// open creates and opens a Bolt database in the temp directory. -func open(fn func(*bolt.DB, string)) { - path := tempfile() - defer os.RemoveAll(path) +// Ensure the "info" command can print information about a database. +func TestInfoCommand_Run(t *testing.T) { + db := MustOpen(0666, nil) + db.DB.Close() + defer db.Close() - db, err := bolt.Open(path, 0600, nil) - if err != nil { - panic("db open error: " + err.Error()) + // Run the info command. + m := NewMain() + if err := m.Run("info", db.Path); err != nil { + t.Fatal(err) } - fn(db, path) } -// run executes a command against the CLI and returns the output. -func run(args ...string) string { - args = append([]string{"bolt"}, args...) - NewApp().Run(args) - return strings.TrimSpace(LogBuffer()) +// Ensure the "stats" command executes correctly with an empty database. +func TestStatsCommand_Run_EmptyDatabase(t *testing.T) { + // Ignore + if os.Getpagesize() != 4096 { + t.Skip("system does not use 4KB page size") + } + + db := MustOpen(0666, nil) + defer db.Close() + db.DB.Close() + + // Generate expected result. + exp := "Aggregate statistics for 0 buckets\n\n" + + "Page count statistics\n" + + "\tNumber of logical branch pages: 0\n" + + "\tNumber of physical branch overflow pages: 0\n" + + "\tNumber of logical leaf pages: 0\n" + + "\tNumber of physical leaf overflow pages: 0\n" + + "Tree statistics\n" + + "\tNumber of keys/value pairs: 0\n" + + "\tNumber of levels in B+tree: 0\n" + + "Page size utilization\n" + + "\tBytes allocated for physical branch pages: 0\n" + + "\tBytes actually used for branch data: 0 (0%)\n" + + "\tBytes allocated for physical leaf pages: 0\n" + + "\tBytes actually used for leaf data: 0 (0%)\n" + + "Bucket statistics\n" + + "\tTotal number of buckets: 0\n" + + "\tTotal number on inlined buckets: 0 (0%)\n" + + "\tBytes used for inlined buckets: 0 (0%)\n" + + // Run the command. + m := NewMain() + if err := m.Run("stats", db.Path); err != nil { + t.Fatal(err) + } else if m.Stdout.String() != exp { + t.Fatalf("unexpected stdout:\n\n%s", m.Stdout.String()) + } } -// tempfile returns a temporary file path. -func tempfile() string { +// Ensure the "stats" command can execute correctly. +func TestStatsCommand_Run(t *testing.T) { + // Ignore + if os.Getpagesize() != 4096 { + t.Skip("system does not use 4KB page size") + } + + db := MustOpen(0666, nil) + defer db.Close() + + if err := db.Update(func(tx *bolt.Tx) error { + // Create "foo" bucket. + b, err := tx.CreateBucket([]byte("foo")) + if err != nil { + return err + } + for i := 0; i < 10; i++ { + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + return err + } + } + + // Create "bar" bucket. + b, err = tx.CreateBucket([]byte("bar")) + if err != nil { + return err + } + for i := 0; i < 100; i++ { + if err := b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))); err != nil { + return err + } + } + + // Create "baz" bucket. + b, err = tx.CreateBucket([]byte("baz")) + if err != nil { + return err + } + if err := b.Put([]byte("key"), []byte("value")); err != nil { + return err + } + + return nil + }); err != nil { + t.Fatal(err) + } + db.DB.Close() + + // Generate expected result. + exp := "Aggregate statistics for 3 buckets\n\n" + + "Page count statistics\n" + + "\tNumber of logical branch pages: 0\n" + + "\tNumber of physical branch overflow pages: 0\n" + + "\tNumber of logical leaf pages: 1\n" + + "\tNumber of physical leaf overflow pages: 0\n" + + "Tree statistics\n" + + "\tNumber of keys/value pairs: 111\n" + + "\tNumber of levels in B+tree: 1\n" + + "Page size utilization\n" + + "\tBytes allocated for physical branch pages: 0\n" + + "\tBytes actually used for branch data: 0 (0%)\n" + + "\tBytes allocated for physical leaf pages: 4096\n" + + "\tBytes actually used for leaf data: 1996 (48%)\n" + + "Bucket statistics\n" + + "\tTotal number of buckets: 3\n" + + "\tTotal number on inlined buckets: 2 (66%)\n" + + "\tBytes used for inlined buckets: 236 (11%)\n" + + // Run the command. + m := NewMain() + if err := m.Run("stats", db.Path); err != nil { + t.Fatal(err) + } else if m.Stdout.String() != exp { + t.Fatalf("unexpected stdout:\n\n%s", m.Stdout.String()) + } +} + +// Main represents a test wrapper for main.Main that records output. +type Main struct { + *main.Main + Stdin bytes.Buffer + Stdout bytes.Buffer + Stderr bytes.Buffer +} + +// NewMain returns a new instance of Main. +func NewMain() *Main { + m := &Main{Main: main.NewMain()} + m.Main.Stdin = &m.Stdin + m.Main.Stdout = &m.Stdout + m.Main.Stderr = &m.Stderr + return m +} + +// MustOpen creates a Bolt database in a temporary location. +func MustOpen(mode os.FileMode, options *bolt.Options) *DB { + // Create temporary path. f, _ := ioutil.TempFile("", "bolt-") f.Close() os.Remove(f.Name()) - return f.Name() -} -// assert fails the test if the condition is false. -func assert(tb testing.TB, condition bool, msg string, v ...interface{}) { - if !condition { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d: "+msg+"\033[39m\n\n", append([]interface{}{filepath.Base(file), line}, v...)...) - tb.FailNow() - } -} - -// ok fails the test if an err is not nil. -func ok(tb testing.TB, err error) { + db, err := bolt.Open(f.Name(), mode, options) if err != nil { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d: unexpected error: %s\033[39m\n\n", filepath.Base(file), line, err.Error()) - tb.FailNow() + panic(err.Error()) } + return &DB{DB: db, Path: f.Name()} } -// equals fails the test if exp is not equal to act. -func equals(tb testing.TB, exp, act interface{}) { - if !reflect.DeepEqual(exp, act) { - _, file, line, _ := runtime.Caller(1) - fmt.Printf("\033[31m%s:%d:\n\n\texp: %#v\n\n\tgot: %#v\033[39m\n\n", filepath.Base(file), line, exp, act) - tb.FailNow() - } +// DB is a test wrapper for bolt.DB. +type DB struct { + *bolt.DB + Path string +} + +// Close closes and removes the database. +func (db *DB) Close() error { + defer os.Remove(db.Path) + return db.DB.Close() } diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/pages.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/pages.go deleted file mode 100644 index a8acf068..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/pages.go +++ /dev/null @@ -1,57 +0,0 @@ -package main - -import ( - "os" - "strconv" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// Pages prints a list of all pages in a database. -func Pages(path string) { - if _, err := os.Stat(path); os.IsNotExist(err) { - fatal(err) - return - } - - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - defer db.Close() - - println("ID TYPE ITEMS OVRFLW") - println("======== ========== ====== ======") - - db.Update(func(tx *bolt.Tx) error { - var id int - for { - p, err := tx.Page(id) - if err != nil { - fatalf("page error: %d: %s", id, err) - } else if p == nil { - break - } - - // Only display count and overflow if this is a non-free page. - var count, overflow string - if p.Type != "free" { - count = strconv.Itoa(p.Count) - if p.OverflowCount > 0 { - overflow = strconv.Itoa(p.OverflowCount) - } - } - - // Print table row. - printf("%-8d %-10s %-6s %-6s\n", p.ID, p.Type, count, overflow) - - // Move to the next non-overflow page. - id += 1 - if p.Type != "free" { - id += p.OverflowCount - } - } - return nil - }) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/stats.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/stats.go deleted file mode 100644 index b6805e21..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/stats.go +++ /dev/null @@ -1,77 +0,0 @@ -package main - -import ( - "bytes" - "os" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" -) - -// Collect stats for all top level buckets matching the prefix. -func Stats(path, prefix string) { - if _, err := os.Stat(path); os.IsNotExist(err) { - fatal(err) - return - } - - db, err := bolt.Open(path, 0600, nil) - if err != nil { - fatal(err) - return - } - defer db.Close() - - err = db.View(func(tx *bolt.Tx) error { - var s bolt.BucketStats - var count int - var prefix = []byte(prefix) - tx.ForEach(func(name []byte, b *bolt.Bucket) error { - if bytes.HasPrefix(name, prefix) { - s.Add(b.Stats()) - count += 1 - } - return nil - }) - printf("Aggregate statistics for %d buckets\n\n", count) - - println("Page count statistics") - printf("\tNumber of logical branch pages: %d\n", s.BranchPageN) - printf("\tNumber of physical branch overflow pages: %d\n", s.BranchOverflowN) - printf("\tNumber of logical leaf pages: %d\n", s.LeafPageN) - printf("\tNumber of physical leaf overflow pages: %d\n", s.LeafOverflowN) - - println("Tree statistics") - printf("\tNumber of keys/value pairs: %d\n", s.KeyN) - printf("\tNumber of levels in B+tree: %d\n", s.Depth) - - println("Page size utilization") - printf("\tBytes allocated for physical branch pages: %d\n", s.BranchAlloc) - var percentage int - if s.BranchAlloc != 0 { - percentage = int(float32(s.BranchInuse) * 100.0 / float32(s.BranchAlloc)) - } - printf("\tBytes actually used for branch data: %d (%d%%)\n", s.BranchInuse, percentage) - printf("\tBytes allocated for physical leaf pages: %d\n", s.LeafAlloc) - percentage = 0 - if s.LeafAlloc != 0 { - percentage = int(float32(s.LeafInuse) * 100.0 / float32(s.LeafAlloc)) - } - printf("\tBytes actually used for leaf data: %d (%d%%)\n", s.LeafInuse, percentage) - - println("Bucket statistics") - printf("\tTotal number of buckets: %d\n", s.BucketN) - percentage = int(float32(s.InlineBucketN) * 100.0 / float32(s.BucketN)) - printf("\tTotal number on inlined buckets: %d (%d%%)\n", s.InlineBucketN, percentage) - percentage = 0 - if s.LeafInuse != 0 { - percentage = int(float32(s.InlineBucketInuse) * 100.0 / float32(s.LeafInuse)) - } - printf("\tBytes used for inlined buckets: %d (%d%%)\n", s.InlineBucketInuse, percentage) - - return nil - }) - if err != nil { - fatal(err) - return - } -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/stats_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/stats_test.go deleted file mode 100644 index 244c4776..00000000 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt/stats_test.go +++ /dev/null @@ -1,61 +0,0 @@ -package main_test - -import ( - "os" - "strconv" - "testing" - - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt/cmd/bolt" -) - -func TestStats(t *testing.T) { - if os.Getpagesize() != 4096 { - t.Skip() - } - SetTestMode(true) - open(func(db *bolt.DB, path string) { - db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucket([]byte("foo")) - if err != nil { - return err - } - for i := 0; i < 10; i++ { - b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) - } - b, err = tx.CreateBucket([]byte("bar")) - if err != nil { - return err - } - for i := 0; i < 100; i++ { - b.Put([]byte(strconv.Itoa(i)), []byte(strconv.Itoa(i))) - } - b, err = tx.CreateBucket([]byte("baz")) - if err != nil { - return err - } - b.Put([]byte("key"), []byte("value")) - return nil - }) - db.Close() - output := run("stats", path, "b") - equals(t, "Aggregate statistics for 2 buckets\n\n"+ - "Page count statistics\n"+ - "\tNumber of logical branch pages: 0\n"+ - "\tNumber of physical branch overflow pages: 0\n"+ - "\tNumber of logical leaf pages: 1\n"+ - "\tNumber of physical leaf overflow pages: 0\n"+ - "Tree statistics\n"+ - "\tNumber of keys/value pairs: 101\n"+ - "\tNumber of levels in B+tree: 1\n"+ - "Page size utilization\n"+ - "\tBytes allocated for physical branch pages: 0\n"+ - "\tBytes actually used for branch data: 0 (0%)\n"+ - "\tBytes allocated for physical leaf pages: 4096\n"+ - "\tBytes actually used for leaf data: 1996 (48%)\n"+ - "Bucket statistics\n"+ - "\tTotal number of buckets: 2\n"+ - "\tTotal number on inlined buckets: 1 (50%)\n"+ - "\tBytes used for inlined buckets: 40 (2%)", output) - }) -} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cursor.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cursor.go index 0d8ed165..1be9f35e 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cursor.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/cursor.go @@ -10,6 +10,8 @@ import ( // Cursors see nested buckets with value == nil. // Cursors can be obtained from a transaction and are valid as long as the transaction is open. // +// Keys and values returned from the cursor are only valid for the life of the transaction. +// // Changing data while traversing with a cursor may cause it to be invalidated // and return unexpected keys and/or values. You must reposition your cursor // after mutating data. @@ -25,12 +27,20 @@ func (c *Cursor) Bucket() *Bucket { // First moves the cursor to the first item in the bucket and returns its key and value. // If the bucket is empty then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. func (c *Cursor) First() (key []byte, value []byte) { _assert(c.bucket.tx.db != nil, "tx closed") c.stack = c.stack[:0] p, n := c.bucket.pageNode(c.bucket.root) c.stack = append(c.stack, elemRef{page: p, node: n, index: 0}) c.first() + + // If we land on an empty page then move to the next value. + // https://github.com/boltdb/bolt/issues/450 + if c.stack[len(c.stack)-1].count() == 0 { + c.next() + } + k, v, flags := c.keyValue() if (flags & uint32(bucketLeafFlag)) != 0 { return k, nil @@ -41,6 +51,7 @@ func (c *Cursor) First() (key []byte, value []byte) { // Last moves the cursor to the last item in the bucket and returns its key and value. // If the bucket is empty then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. func (c *Cursor) Last() (key []byte, value []byte) { _assert(c.bucket.tx.db != nil, "tx closed") c.stack = c.stack[:0] @@ -58,6 +69,7 @@ func (c *Cursor) Last() (key []byte, value []byte) { // Next moves the cursor to the next item in the bucket and returns its key and value. // If the cursor is at the end of the bucket then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. func (c *Cursor) Next() (key []byte, value []byte) { _assert(c.bucket.tx.db != nil, "tx closed") k, v, flags := c.next() @@ -69,6 +81,7 @@ func (c *Cursor) Next() (key []byte, value []byte) { // Prev moves the cursor to the previous item in the bucket and returns its key and value. // If the cursor is at the beginning of the bucket then a nil key and value are returned. +// The returned key and value are only valid for the life of the transaction. func (c *Cursor) Prev() (key []byte, value []byte) { _assert(c.bucket.tx.db != nil, "tx closed") @@ -100,6 +113,7 @@ func (c *Cursor) Prev() (key []byte, value []byte) { // Seek moves the cursor to a given key and returns it. // If the key does not exist then the next key is used. If no keys // follow, a nil key is returned. +// The returned key and value are only valid for the life of the transaction. func (c *Cursor) Seek(seek []byte) (key []byte, value []byte) { k, v, flags := c.seek(seek) @@ -202,28 +216,37 @@ func (c *Cursor) last() { // next moves to the next leaf element and returns the key and value. // If the cursor is at the last leaf element then it stays there and returns nil. func (c *Cursor) next() (key []byte, value []byte, flags uint32) { - // Attempt to move over one element until we're successful. - // Move up the stack as we hit the end of each page in our stack. - var i int - for i = len(c.stack) - 1; i >= 0; i-- { - elem := &c.stack[i] - if elem.index < elem.count()-1 { - elem.index++ - break + for { + // Attempt to move over one element until we're successful. + // Move up the stack as we hit the end of each page in our stack. + var i int + for i = len(c.stack) - 1; i >= 0; i-- { + elem := &c.stack[i] + if elem.index < elem.count()-1 { + elem.index++ + break + } } - } - // If we've hit the root page then stop and return. This will leave the - // cursor on the last element of the last page. - if i == -1 { - return nil, nil, 0 - } + // If we've hit the root page then stop and return. This will leave the + // cursor on the last element of the last page. + if i == -1 { + return nil, nil, 0 + } - // Otherwise start from where we left off in the stack and find the - // first element of the first leaf page. - c.stack = c.stack[:i+1] - c.first() - return c.keyValue() + // Otherwise start from where we left off in the stack and find the + // first element of the first leaf page. + c.stack = c.stack[:i+1] + c.first() + + // If this is an empty page then restart and move back up the stack. + // https://github.com/boltdb/bolt/issues/450 + if c.stack[len(c.stack)-1].count() == 0 { + continue + } + + return c.keyValue() + } } // search recursively performs a binary search against a given page/node until it finds a given key. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/cursor_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/cursor_test.go index 2b5648de..ac1922da 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/cursor_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/cursor_test.go @@ -4,7 +4,9 @@ import ( "bytes" "encoding/binary" "fmt" + "log" "os" + "reflect" "sort" "testing" "testing/quick" @@ -14,100 +16,149 @@ import ( // Ensure that a cursor can return a reference to the bucket that created it. func TestCursor_Bucket(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucket([]byte("widgets")) - c := b.Cursor() - equals(t, b, c.Bucket()) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if cb := b.Cursor().Bucket(); !reflect.DeepEqual(cb, b) { + t.Fatal("cursor bucket mismatch") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a Tx cursor can seek to the appropriate keys. func TestCursor_Seek(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) - ok(t, b.Put([]byte("foo"), []byte("0001"))) - ok(t, b.Put([]byte("bar"), []byte("0002"))) - ok(t, b.Put([]byte("baz"), []byte("0003"))) - _, err = b.CreateBucket([]byte("bkt")) - ok(t, err) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("0001")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte("0002")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("0003")); err != nil { + t.Fatal(err) + } + + if _, err := b.CreateBucket([]byte("bkt")); err != nil { + t.Fatal(err) + } return nil - }) - db.View(func(tx *bolt.Tx) error { + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { c := tx.Bucket([]byte("widgets")).Cursor() // Exact match should go to the key. - k, v := c.Seek([]byte("bar")) - equals(t, []byte("bar"), k) - equals(t, []byte("0002"), v) + if k, v := c.Seek([]byte("bar")); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0002")) { + t.Fatalf("unexpected value: %v", v) + } // Inexact match should go to the next key. - k, v = c.Seek([]byte("bas")) - equals(t, []byte("baz"), k) - equals(t, []byte("0003"), v) + if k, v := c.Seek([]byte("bas")); !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0003")) { + t.Fatalf("unexpected value: %v", v) + } // Low key should go to the first key. - k, v = c.Seek([]byte("")) - equals(t, []byte("bar"), k) - equals(t, []byte("0002"), v) + if k, v := c.Seek([]byte("")); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte("0002")) { + t.Fatalf("unexpected value: %v", v) + } // High key should return no key. - k, v = c.Seek([]byte("zzz")) - assert(t, k == nil, "") - assert(t, v == nil, "") + if k, v := c.Seek([]byte("zzz")); k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } // Buckets should return their key but no value. - k, v = c.Seek([]byte("bkt")) - equals(t, []byte("bkt"), k) - assert(t, v == nil, "") + if k, v := c.Seek([]byte("bkt")); !bytes.Equal(k, []byte("bkt")) { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } func TestCursor_Delete(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - var count = 1000 + const count = 1000 // Insert every other key between 0 and $count. - db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucket([]byte("widgets")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } for i := 0; i < count; i += 1 { k := make([]byte, 8) binary.BigEndian.PutUint64(k, uint64(i)) - b.Put(k, make([]byte, 100)) + if err := b.Put(k, make([]byte, 100)); err != nil { + t.Fatal(err) + } + } + if _, err := b.CreateBucket([]byte("sub")); err != nil { + t.Fatal(err) } - b.CreateBucket([]byte("sub")) return nil - }) + }); err != nil { + t.Fatal(err) + } - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { c := tx.Bucket([]byte("widgets")).Cursor() bound := make([]byte, 8) binary.BigEndian.PutUint64(bound, uint64(count/2)) for key, _ := c.First(); bytes.Compare(key, bound) < 0; key, _ = c.Next() { if err := c.Delete(); err != nil { - return err + t.Fatal(err) } } - c.Seek([]byte("sub")) - err := c.Delete() - equals(t, err, bolt.ErrIncompatibleValue) - return nil - }) - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - equals(t, b.Stats().KeyN, count/2+1) + c.Seek([]byte("sub")) + if err := c.Delete(); err != bolt.ErrIncompatibleValue { + t.Fatalf("unexpected error: %s", err) + } + return nil - }) + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { + stats := tx.Bucket([]byte("widgets")).Stats() + if stats.KeyN != count/2+1 { + t.Fatalf("unexpected KeyN: %d", stats.KeyN) + } + return nil + }); err != nil { + t.Fatal(err) + } } // Ensure that a Tx cursor can seek to the appropriate keys when there are a @@ -116,25 +167,33 @@ func TestCursor_Delete(t *testing.T) { // // Related: https://github.com/boltdb/bolt/pull/187 func TestCursor_Seek_Large(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() var count = 10000 // Insert every other key between 0 and $count. - db.Update(func(tx *bolt.Tx) error { - b, _ := tx.CreateBucket([]byte("widgets")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for i := 0; i < count; i += 100 { for j := i; j < i+100; j += 2 { k := make([]byte, 8) binary.BigEndian.PutUint64(k, uint64(j)) - b.Put(k, make([]byte, 100)) + if err := b.Put(k, make([]byte, 100)); err != nil { + t.Fatal(err) + } } } return nil - }) + }); err != nil { + t.Fatal(err) + } - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { c := tx.Bucket([]byte("widgets")).Cursor() for i := 0; i < count; i++ { seek := make([]byte, 8) @@ -145,193 +204,359 @@ func TestCursor_Seek_Large(t *testing.T) { // The last seek is beyond the end of the the range so // it should return nil. if i == count-1 { - assert(t, k == nil, "") + if k != nil { + t.Fatal("expected nil key") + } continue } // Otherwise we should seek to the exact key or the next key. num := binary.BigEndian.Uint64(k) if i%2 == 0 { - equals(t, uint64(i), num) + if num != uint64(i) { + t.Fatalf("unexpected num: %d", num) + } } else { - equals(t, uint64(i+1), num) + if num != uint64(i+1) { + t.Fatalf("unexpected num: %d", num) + } } } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a cursor can iterate over an empty bucket without error. func TestCursor_EmptyBucket(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucket([]byte("widgets")) return err - }) - db.View(func(tx *bolt.Tx) error { + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { c := tx.Bucket([]byte("widgets")).Cursor() k, v := c.First() - assert(t, k == nil, "") - assert(t, v == nil, "") + if k != nil { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a Tx cursor can reverse iterate over an empty bucket without error. func TestCursor_EmptyBucketReverse(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucket([]byte("widgets")) return err - }) - db.View(func(tx *bolt.Tx) error { + }); err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { c := tx.Bucket([]byte("widgets")).Cursor() k, v := c.Last() - assert(t, k == nil, "") - assert(t, v == nil, "") + if k != nil { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a Tx cursor can iterate over a single root with a couple elements. func TestCursor_Iterate_Leaf(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte{}) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte{0}) - tx.Bucket([]byte("widgets")).Put([]byte("bar"), []byte{1}) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte{}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte{0}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte{1}); err != nil { + t.Fatal(err) + } return nil - }) - tx, _ := db.Begin(false) + }); err != nil { + t.Fatal(err) + } + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + defer func() { _ = tx.Rollback() }() + c := tx.Bucket([]byte("widgets")).Cursor() k, v := c.First() - equals(t, string(k), "bar") - equals(t, v, []byte{1}) + if !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{1}) { + t.Fatalf("unexpected value: %v", v) + } k, v = c.Next() - equals(t, string(k), "baz") - equals(t, v, []byte{}) + if !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{}) { + t.Fatalf("unexpected value: %v", v) + } k, v = c.Next() - equals(t, string(k), "foo") - equals(t, v, []byte{0}) + if !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{0}) { + t.Fatalf("unexpected value: %v", v) + } k, v = c.Next() - assert(t, k == nil, "") - assert(t, v == nil, "") + if k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } k, v = c.Next() - assert(t, k == nil, "") - assert(t, v == nil, "") + if k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } - tx.Rollback() + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } } // Ensure that a Tx cursor can iterate in reverse over a single root with a couple elements. func TestCursor_LeafRootReverse(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte{}) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte{0}) - tx.Bucket([]byte("widgets")).Put([]byte("bar"), []byte{1}) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte{}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte{0}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte{1}); err != nil { + t.Fatal(err) + } return nil - }) - tx, _ := db.Begin(false) + }); err != nil { + t.Fatal(err) + } + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } c := tx.Bucket([]byte("widgets")).Cursor() - k, v := c.Last() - equals(t, string(k), "foo") - equals(t, v, []byte{0}) + if k, v := c.Last(); !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{0}) { + t.Fatalf("unexpected value: %v", v) + } - k, v = c.Prev() - equals(t, string(k), "baz") - equals(t, v, []byte{}) + if k, v := c.Prev(); !bytes.Equal(k, []byte("baz")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{}) { + t.Fatalf("unexpected value: %v", v) + } - k, v = c.Prev() - equals(t, string(k), "bar") - equals(t, v, []byte{1}) + if k, v := c.Prev(); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, []byte{1}) { + t.Fatalf("unexpected value: %v", v) + } - k, v = c.Prev() - assert(t, k == nil, "") - assert(t, v == nil, "") + if k, v := c.Prev(); k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } - k, v = c.Prev() - assert(t, k == nil, "") - assert(t, v == nil, "") + if k, v := c.Prev(); k != nil { + t.Fatalf("expected nil key: %v", k) + } else if v != nil { + t.Fatalf("expected nil value: %v", v) + } - tx.Rollback() + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } } // Ensure that a Tx cursor can restart from the beginning. func TestCursor_Restart(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("bar"), []byte{}) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte{}) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("bar"), []byte{}); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte{}); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } - tx, _ := db.Begin(false) + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } c := tx.Bucket([]byte("widgets")).Cursor() - k, _ := c.First() - equals(t, string(k), "bar") + if k, _ := c.First(); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } + if k, _ := c.Next(); !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } - k, _ = c.Next() - equals(t, string(k), "foo") + if k, _ := c.First(); !bytes.Equal(k, []byte("bar")) { + t.Fatalf("unexpected key: %v", k) + } + if k, _ := c.Next(); !bytes.Equal(k, []byte("foo")) { + t.Fatalf("unexpected key: %v", k) + } - k, _ = c.First() - equals(t, string(k), "bar") + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } +} - k, _ = c.Next() - equals(t, string(k), "foo") +// Ensure that a cursor can skip over empty pages that have been deleted. +func TestCursor_First_EmptyPages(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() - tx.Rollback() + // Create 1000 keys in the "widgets" bucket. + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + for i := 0; i < 1000; i++ { + if err := b.Put(u64tob(uint64(i)), []byte{}); err != nil { + t.Fatal(err) + } + } + + return nil + }); err != nil { + t.Fatal(err) + } + + // Delete half the keys and then try to iterate. + if err := db.Update(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 0; i < 600; i++ { + if err := b.Delete(u64tob(uint64(i))); err != nil { + t.Fatal(err) + } + } + + c := b.Cursor() + var n int + for k, _ := c.First(); k != nil; k, _ = c.Next() { + n++ + } + if n != 400 { + t.Fatalf("unexpected key count: %d", n) + } + + return nil + }); err != nil { + t.Fatal(err) + } } // Ensure that a Tx can iterate over all elements in a bucket. func TestCursor_QuickCheck(t *testing.T) { f := func(items testdata) bool { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() // Bulk insert all values. - tx, _ := db.Begin(true) - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - for _, item := range items { - ok(t, b.Put(item.Key, item.Value)) + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + if err := tx.Commit(); err != nil { + t.Fatal(err) } - ok(t, tx.Commit()) // Sort test data. sort.Sort(items) // Iterate over all items and check consistency. var index = 0 - tx, _ = db.Begin(false) + tx, err = db.Begin(false) + if err != nil { + t.Fatal(err) + } + c := tx.Bucket([]byte("widgets")).Cursor() for k, v := c.First(); k != nil && index < len(items); k, v = c.Next() { - equals(t, k, items[index].Key) - equals(t, v, items[index].Value) + if !bytes.Equal(k, items[index].Key) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, items[index].Value) { + t.Fatalf("unexpected value: %v", v) + } index++ } - equals(t, len(items), index) - tx.Rollback() + if len(items) != index { + t.Fatalf("unexpected item count: %v, expected %v", len(items), index) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } return true } @@ -343,32 +568,52 @@ func TestCursor_QuickCheck(t *testing.T) { // Ensure that a transaction can iterate over all elements in a bucket in reverse. func TestCursor_QuickCheck_Reverse(t *testing.T) { f := func(items testdata) bool { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() // Bulk insert all values. - tx, _ := db.Begin(true) - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - for _, item := range items { - ok(t, b.Put(item.Key, item.Value)) + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + for _, item := range items { + if err := b.Put(item.Key, item.Value); err != nil { + t.Fatal(err) + } + } + if err := tx.Commit(); err != nil { + t.Fatal(err) } - ok(t, tx.Commit()) // Sort test data. sort.Sort(revtestdata(items)) // Iterate over all items and check consistency. var index = 0 - tx, _ = db.Begin(false) + tx, err = db.Begin(false) + if err != nil { + t.Fatal(err) + } c := tx.Bucket([]byte("widgets")).Cursor() for k, v := c.Last(); k != nil && index < len(items); k, v = c.Prev() { - equals(t, k, items[index].Key) - equals(t, v, items[index].Value) + if !bytes.Equal(k, items[index].Key) { + t.Fatalf("unexpected key: %v", k) + } else if !bytes.Equal(v, items[index].Value) { + t.Fatalf("unexpected value: %v", v) + } index++ } - equals(t, len(items), index) - tx.Rollback() + if len(items) != index { + t.Fatalf("unexpected item count: %v, expected %v", len(items), index) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } return true } @@ -379,76 +624,114 @@ func TestCursor_QuickCheck_Reverse(t *testing.T) { // Ensure that a Tx cursor can iterate over subbuckets. func TestCursor_QuickCheck_BucketsOnly(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) - _, err = b.CreateBucket([]byte("foo")) - ok(t, err) - _, err = b.CreateBucket([]byte("bar")) - ok(t, err) - _, err = b.CreateBucket([]byte("baz")) - ok(t, err) + if err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("bar")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("baz")); err != nil { + t.Fatal(err) + } return nil - }) - db.View(func(tx *bolt.Tx) error { + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { var names []string c := tx.Bucket([]byte("widgets")).Cursor() for k, v := c.First(); k != nil; k, v = c.Next() { names = append(names, string(k)) - assert(t, v == nil, "") + if v != nil { + t.Fatalf("unexpected value: %v", v) + } + } + if !reflect.DeepEqual(names, []string{"bar", "baz", "foo"}) { + t.Fatalf("unexpected names: %+v", names) } - equals(t, names, []string{"bar", "baz", "foo"}) return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a Tx cursor can reverse iterate over subbuckets. func TestCursor_QuickCheck_BucketsOnly_Reverse(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucket([]byte("widgets")) - ok(t, err) - _, err = b.CreateBucket([]byte("foo")) - ok(t, err) - _, err = b.CreateBucket([]byte("bar")) - ok(t, err) - _, err = b.CreateBucket([]byte("baz")) - ok(t, err) + if err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("bar")); err != nil { + t.Fatal(err) + } + if _, err := b.CreateBucket([]byte("baz")); err != nil { + t.Fatal(err) + } return nil - }) - db.View(func(tx *bolt.Tx) error { + }); err != nil { + t.Fatal(err) + } + + if err := db.View(func(tx *bolt.Tx) error { var names []string c := tx.Bucket([]byte("widgets")).Cursor() for k, v := c.Last(); k != nil; k, v = c.Prev() { names = append(names, string(k)) - assert(t, v == nil, "") + if v != nil { + t.Fatalf("unexpected value: %v", v) + } + } + if !reflect.DeepEqual(names, []string{"foo", "baz", "bar"}) { + t.Fatalf("unexpected names: %+v", names) } - equals(t, names, []string{"foo", "baz", "bar"}) return nil - }) + }); err != nil { + t.Fatal(err) + } } func ExampleCursor() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Start a read-write transaction. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { // Create a new bucket. - tx.CreateBucket([]byte("animals")) + b, err := tx.CreateBucket([]byte("animals")) + if err != nil { + return err + } // Insert data into a bucket. - b := tx.Bucket([]byte("animals")) - b.Put([]byte("dog"), []byte("fun")) - b.Put([]byte("cat"), []byte("lame")) - b.Put([]byte("liger"), []byte("awesome")) + if err := b.Put([]byte("dog"), []byte("fun")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("cat"), []byte("lame")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("liger"), []byte("awesome")); err != nil { + log.Fatal(err) + } // Create a cursor for iteration. c := b.Cursor() @@ -463,7 +746,13 @@ func ExampleCursor() { } return nil - }) + }); err != nil { + log.Fatal(err) + } + + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // A cat is lame. @@ -473,20 +762,30 @@ func ExampleCursor() { func ExampleCursor_reverse() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Start a read-write transaction. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { // Create a new bucket. - tx.CreateBucket([]byte("animals")) + b, err := tx.CreateBucket([]byte("animals")) + if err != nil { + return err + } // Insert data into a bucket. - b := tx.Bucket([]byte("animals")) - b.Put([]byte("dog"), []byte("fun")) - b.Put([]byte("cat"), []byte("lame")) - b.Put([]byte("liger"), []byte("awesome")) + if err := b.Put([]byte("dog"), []byte("fun")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("cat"), []byte("lame")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("liger"), []byte("awesome")); err != nil { + log.Fatal(err) + } // Create a cursor for iteration. c := b.Cursor() @@ -502,7 +801,14 @@ func ExampleCursor_reverse() { } return nil - }) + }); err != nil { + log.Fatal(err) + } + + // Close the database to release the file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // A liger is awesome. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/db.go b/Godeps/_workspace/src/github.com/boltdb/bolt/db.go index 4775850f..0f1e1bc3 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/db.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/db.go @@ -1,8 +1,10 @@ package bolt import ( + "errors" "fmt" "hash/fnv" + "log" "os" "runtime" "runtime/debug" @@ -24,9 +26,16 @@ const magic uint32 = 0xED0CDAED // IgnoreNoSync specifies whether the NoSync field of a DB is ignored when // syncing changes to a file. This is required as some operating systems, // such as OpenBSD, do not have a unified buffer cache (UBC) and writes -// must be synchronzied using the msync(2) syscall. +// must be synchronized using the msync(2) syscall. const IgnoreNoSync = runtime.GOOS == "openbsd" +// Default values if not set in a DB instance. +const ( + DefaultMaxBatchSize int = 1000 + DefaultMaxBatchDelay = 10 * time.Millisecond + DefaultAllocSize = 16 * 1024 * 1024 +) + // DB represents a collection of buckets persisted to a file on disk. // All data access is performed through transactions which can be obtained through the DB. // All the functions on DB will return a ErrDatabaseNotOpen if accessed before Open() is called. @@ -49,11 +58,45 @@ type DB struct { // THIS IS UNSAFE. PLEASE USE WITH CAUTION. NoSync bool + // When true, skips the truncate call when growing the database. + // Setting this to true is only safe on non-ext3/ext4 systems. + // Skipping truncation avoids preallocation of hard drive space and + // bypasses a truncate() and fsync() syscall on remapping. + // + // https://github.com/boltdb/bolt/issues/284 + NoGrowSync bool + + // If you want to read the entire database fast, you can set MmapFlag to + // syscall.MAP_POPULATE on Linux 2.6.23+ for sequential read-ahead. + MmapFlags int + + // MaxBatchSize is the maximum size of a batch. Default value is + // copied from DefaultMaxBatchSize in Open. + // + // If <=0, disables batching. + // + // Do not change concurrently with calls to Batch. + MaxBatchSize int + + // MaxBatchDelay is the maximum delay before a batch starts. + // Default value is copied from DefaultMaxBatchDelay in Open. + // + // If <=0, effectively disables batching. + // + // Do not change concurrently with calls to Batch. + MaxBatchDelay time.Duration + + // AllocSize is the amount of space allocated when the database + // needs to create new pages. This is done to amortize the cost + // of truncate() and fsync() when growing the data file. + AllocSize int + path string file *os.File - dataref []byte + dataref []byte // mmap'ed readonly, write throws SEGV data *[maxMapSize]byte datasz int + filesz int // current on disk file size meta0 *meta meta1 *meta pageSize int @@ -63,6 +106,9 @@ type DB struct { freelist *freelist stats Stats + batchMu sync.Mutex + batch *batch + rwlock sync.Mutex // Allows only one writer at a time. metalock sync.Mutex // Protects meta page access. mmaplock sync.RWMutex // Protects mmap access during remapping. @@ -71,6 +117,10 @@ type DB struct { ops struct { writeAt func(b []byte, off int64) (n int, err error) } + + // Read only mode. + // When true, Update() and Begin(true) return ErrDatabaseReadOnly immediately. + readOnly bool } // Path returns the path to currently open database file. @@ -98,20 +148,36 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) { if options == nil { options = DefaultOptions } + db.NoGrowSync = options.NoGrowSync + db.MmapFlags = options.MmapFlags + + // Set default values for later DB operations. + db.MaxBatchSize = DefaultMaxBatchSize + db.MaxBatchDelay = DefaultMaxBatchDelay + db.AllocSize = DefaultAllocSize + + flag := os.O_RDWR + if options.ReadOnly { + flag = os.O_RDONLY + db.readOnly = true + } // Open data file and separate sync handler for metadata writes. db.path = path - var err error - if db.file, err = os.OpenFile(db.path, os.O_RDWR|os.O_CREATE, mode); err != nil { + if db.file, err = os.OpenFile(db.path, flag|os.O_CREATE, mode); err != nil { _ = db.close() return nil, err } - // Lock file so that other processes using Bolt cannot use the database - // at the same time. This would cause corruption since the two processes - // would write meta pages and free pages separately. - if err := flock(db.file, options.Timeout); err != nil { + // Lock file so that other processes using Bolt in read-write mode cannot + // use the database at the same time. This would cause corruption since + // the two processes would write meta pages and free pages separately. + // The database file is locked exclusively (only one process can grab the lock) + // if !options.ReadOnly. + // The database file is locked using the shared lock (more than one process may + // hold a lock at the same time) otherwise (options.ReadOnly is set). + if err := flock(db.file, !db.readOnly, options.Timeout); err != nil { _ = db.close() return nil, err } @@ -121,7 +187,7 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) { // Initialize the database if it doesn't exist. if info, err := db.file.Stat(); err != nil { - return nil, fmt.Errorf("stat error: %s", err) + return nil, err } else if info.Size() == 0 { // Initialize new files with meta pages. if err := db.init(); err != nil { @@ -133,14 +199,14 @@ func Open(path string, mode os.FileMode, options *Options) (*DB, error) { if _, err := db.file.ReadAt(buf[:], 0); err == nil { m := db.pageInBuffer(buf[:], 0).meta() if err := m.validate(); err != nil { - return nil, fmt.Errorf("meta0 error: %s", err) + return nil, err } db.pageSize = int(m.pageSize) } } // Memory map the data file. - if err := db.mmap(0); err != nil { + if err := db.mmap(options.InitialMmapSize); err != nil { _ = db.close() return nil, err } @@ -197,10 +263,10 @@ func (db *DB) mmap(minsz int) error { // Validate the meta pages. if err := db.meta0.validate(); err != nil { - return fmt.Errorf("meta0 error: %s", err) + return err } if err := db.meta1.validate(); err != nil { - return fmt.Errorf("meta1 error: %s", err) + return err } return nil @@ -215,11 +281,11 @@ func (db *DB) munmap() error { } // mmapSize determines the appropriate size for the mmap given the current size -// of the database. The minimum size is 4MB and doubles until it reaches 1GB. +// of the database. The minimum size is 32KB and doubles until it reaches 1GB. // Returns an error if the new mmap size is greater than the max allowed. func (db *DB) mmapSize(size int) (int, error) { - // Double the size from 1MB until 1GB. - for i := uint(20); i <= 30; i++ { + // Double the size from 32KB until 1GB. + for i := uint(15); i <= 30; i++ { if size <= 1<= db.MaxBatchSize) { + // There is no existing batch, or the existing batch is full; start a new one. + db.batch = &batch{ + db: db, + } + db.batch.timer = time.AfterFunc(db.MaxBatchDelay, db.batch.trigger) + } + db.batch.calls = append(db.batch.calls, call{fn: fn, err: errCh}) + if len(db.batch.calls) >= db.MaxBatchSize { + // wake up batch, it's ready to run + go db.batch.trigger() + } + db.batchMu.Unlock() + + err := <-errCh + if err == trySolo { + err = db.Update(fn) + } + return err +} + +type call struct { + fn func(*Tx) error + err chan<- error +} + +type batch struct { + db *DB + timer *time.Timer + start sync.Once + calls []call +} + +// trigger runs the batch if it hasn't already been run. +func (b *batch) trigger() { + b.start.Do(b.run) +} + +// run performs the transactions in the batch and communicates results +// back to DB.Batch. +func (b *batch) run() { + b.db.batchMu.Lock() + b.timer.Stop() + // Make sure no new work is added to this batch, but don't break + // other batches. + if b.db.batch == b { + b.db.batch = nil + } + b.db.batchMu.Unlock() + +retry: + for len(b.calls) > 0 { + var failIdx = -1 + err := b.db.Update(func(tx *Tx) error { + for i, c := range b.calls { + if err := safelyCall(c.fn, tx); err != nil { + failIdx = i + return err + } + } + return nil + }) + + if failIdx >= 0 { + // take the failing transaction out of the batch. it's + // safe to shorten b.calls here because db.batch no longer + // points to us, and we hold the mutex anyway. + c := b.calls[failIdx] + b.calls[failIdx], b.calls = b.calls[len(b.calls)-1], b.calls[:len(b.calls)-1] + // tell the submitter re-run it solo, continue with the rest of the batch + c.err <- trySolo + continue retry + } + + // pass success, or bolt internal errors, to all callers + for _, c := range b.calls { + if c.err != nil { + c.err <- err + } + } + break retry + } +} + +// trySolo is a special sentinel error value used for signaling that a +// transaction function should be re-run. It should never be seen by +// callers. +var trySolo = errors.New("batch function returned an error and should be re-run solo") + +type panicked struct { + reason interface{} +} + +func (p panicked) Error() string { + if err, ok := p.reason.(error); ok { + return err.Error() + } + return fmt.Sprintf("panic: %v", p.reason) +} + +func safelyCall(fn func(*Tx) error, tx *Tx) (err error) { + defer func() { + if p := recover(); p != nil { + err = panicked{p} + } + }() + return fn(tx) +} + +// Sync executes fdatasync() against the database file handle. +// +// This is not necessary under normal operation, however, if you use NoSync +// then it allows you to force the database file to sync against the disk. +func (db *DB) Sync() error { return fdatasync(db) } + // Stats retrieves ongoing performance stats for the database. // This is only updated when a transaction closes. func (db *DB) Stats() Stats { @@ -578,18 +806,73 @@ func (db *DB) allocate(count int) (*page, error) { return p, nil } +// grow grows the size of the database to the given sz. +func (db *DB) grow(sz int) error { + // Ignore if the new size is less than available file size. + if sz <= db.filesz { + return nil + } + + // If the data is smaller than the alloc size then only allocate what's needed. + // Once it goes over the allocation size then allocate in chunks. + if db.datasz < db.AllocSize { + sz = db.datasz + } else { + sz += db.AllocSize + } + + // Truncate and fsync to ensure file size metadata is flushed. + // https://github.com/boltdb/bolt/issues/284 + if !db.NoGrowSync && !db.readOnly { + if err := db.file.Truncate(int64(sz)); err != nil { + return fmt.Errorf("file resize error: %s", err) + } + if err := db.file.Sync(); err != nil { + return fmt.Errorf("file sync error: %s", err) + } + } + + db.filesz = sz + return nil +} + +func (db *DB) IsReadOnly() bool { + return db.readOnly +} + // Options represents the options that can be set when opening a database. type Options struct { // Timeout is the amount of time to wait to obtain a file lock. // When set to zero it will wait indefinitely. This option is only // available on Darwin and Linux. Timeout time.Duration + + // Sets the DB.NoGrowSync flag before memory mapping the file. + NoGrowSync bool + + // Open database in read-only mode. Uses flock(..., LOCK_SH |LOCK_NB) to + // grab a shared lock (UNIX). + ReadOnly bool + + // Sets the DB.MmapFlags flag before memory mapping the file. + MmapFlags int + + // InitialMmapSize is the initial mmap size of the database + // in bytes. Read transactions won't block write transaction + // if the InitialMmapSize is large enough to hold database mmap + // size. (See DB.Begin for more information) + // + // If <=0, the initial map size is 0. + // If initialMmapSize is smaller than the previous database size, + // it takes no effect. + InitialMmapSize int } // DefaultOptions represent the options used if nil options are passed into Open(). // No timeout is used which will cause Bolt to wait indefinitely for a lock. var DefaultOptions = &Options{ - Timeout: 0, + Timeout: 0, + NoGrowSync: false, } // Stats represents statistics about the database. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/db_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/db_test.go index 9ca7193a..b535fa76 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/db_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/db_test.go @@ -1,110 +1,253 @@ package bolt_test import ( + "bytes" "encoding/binary" "errors" "flag" "fmt" + "hash/fnv" "io/ioutil" + "log" "os" + "path/filepath" "regexp" "runtime" "sort" "strings" + "sync" "testing" "time" + "unsafe" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" ) var statsFlag = flag.Bool("stats", false, "show performance stats") -// Ensure that opening a database with a bad path returns an error. -func TestOpen_BadPath(t *testing.T) { - db, err := bolt.Open("", 0666, nil) - assert(t, err != nil, "err: %s", err) - assert(t, db == nil, "") +// version is the data file format version. +const version = 2 + +// magic is the marker value to indicate that a file is a Bolt DB. +const magic uint32 = 0xED0CDAED + +// pageSize is the size of one page in the data file. +const pageSize = 4096 + +// pageHeaderSize is the size of a page header. +const pageHeaderSize = 16 + +// meta represents a simplified version of a database meta page for testing. +type meta struct { + magic uint32 + version uint32 + _ uint32 + _ uint32 + _ [16]byte + _ uint64 + _ uint64 + _ uint64 + checksum uint64 } // Ensure that a database can be opened without error. func TestOpen(t *testing.T) { path := tempfile() - defer os.Remove(path) db, err := bolt.Open(path, 0666, nil) - assert(t, db != nil, "") - ok(t, err) - equals(t, db.Path(), path) - ok(t, db.Close()) + if err != nil { + t.Fatal(err) + } else if db == nil { + t.Fatal("expected db") + } + + if s := db.Path(); s != path { + t.Fatalf("unexpected path: %s", s) + } + + if err := db.Close(); err != nil { + t.Fatal(err) + } +} + +// Ensure that opening a database with a blank path returns an error. +func TestOpen_ErrPathRequired(t *testing.T) { + _, err := bolt.Open("", 0666, nil) + if err == nil { + t.Fatalf("expected error") + } +} + +// Ensure that opening a database with a bad path returns an error. +func TestOpen_ErrNotExists(t *testing.T) { + _, err := bolt.Open(filepath.Join(tempfile(), "bad-path"), 0666, nil) + if err == nil { + t.Fatal("expected error") + } +} + +// Ensure that opening a file with wrong checksum returns ErrChecksum. +func TestOpen_ErrChecksum(t *testing.T) { + buf := make([]byte, pageSize) + meta := (*meta)(unsafe.Pointer(&buf[0])) + meta.magic = magic + meta.version = version + meta.checksum = 123 + + path := tempfile() + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteAt(buf, pageHeaderSize); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + defer os.Remove(path) + + if _, err := bolt.Open(path, 0666, nil); err != bolt.ErrChecksum { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that opening a file that is not a Bolt database returns ErrInvalid. +func TestOpen_ErrInvalid(t *testing.T) { + path := tempfile() + + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + if _, err := fmt.Fprintln(f, "this is not a bolt database"); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + defer os.Remove(path) + + if _, err := bolt.Open(path, 0666, nil); err != bolt.ErrInvalid { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that opening a file created with a different version of Bolt returns +// ErrVersionMismatch. +func TestOpen_ErrVersionMismatch(t *testing.T) { + buf := make([]byte, pageSize) + meta := (*meta)(unsafe.Pointer(&buf[0])) + meta.magic = magic + meta.version = version + 100 + + path := tempfile() + f, err := os.Create(path) + if err != nil { + t.Fatal(err) + } + if _, err := f.WriteAt(buf, pageHeaderSize); err != nil { + t.Fatal(err) + } + if err := f.Close(); err != nil { + t.Fatal(err) + } + defer os.Remove(path) + + if _, err := bolt.Open(path, 0666, nil); err != bolt.ErrVersionMismatch { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that opening an already open database file will timeout. func TestOpen_Timeout(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("timeout not supported on windows") + if runtime.GOOS == "solaris" { + t.Skip("solaris fcntl locks don't support intra-process locking") } path := tempfile() - defer os.Remove(path) // Open a data file. db0, err := bolt.Open(path, 0666, nil) - assert(t, db0 != nil, "") - ok(t, err) + if err != nil { + t.Fatal(err) + } else if db0 == nil { + t.Fatal("expected database") + } // Attempt to open the database again. start := time.Now() db1, err := bolt.Open(path, 0666, &bolt.Options{Timeout: 100 * time.Millisecond}) - assert(t, db1 == nil, "") - equals(t, bolt.ErrTimeout, err) - assert(t, time.Since(start) > 100*time.Millisecond, "") + if err != bolt.ErrTimeout { + t.Fatalf("unexpected timeout: %s", err) + } else if db1 != nil { + t.Fatal("unexpected database") + } else if time.Since(start) <= 100*time.Millisecond { + t.Fatal("expected to wait at least timeout duration") + } - db0.Close() + if err := db0.Close(); err != nil { + t.Fatal(err) + } } // Ensure that opening an already open database file will wait until its closed. func TestOpen_Wait(t *testing.T) { - if runtime.GOOS == "windows" { - t.Skip("timeout not supported on windows") + if runtime.GOOS == "solaris" { + t.Skip("solaris fcntl locks don't support intra-process locking") } path := tempfile() - defer os.Remove(path) // Open a data file. db0, err := bolt.Open(path, 0666, nil) - assert(t, db0 != nil, "") - ok(t, err) + if err != nil { + t.Fatal(err) + } // Close it in just a bit. - time.AfterFunc(100*time.Millisecond, func() { db0.Close() }) + time.AfterFunc(100*time.Millisecond, func() { _ = db0.Close() }) // Attempt to open the database again. start := time.Now() db1, err := bolt.Open(path, 0666, &bolt.Options{Timeout: 200 * time.Millisecond}) - assert(t, db1 != nil, "") - ok(t, err) - assert(t, time.Since(start) > 100*time.Millisecond, "") + if err != nil { + t.Fatal(err) + } else if time.Since(start) <= 100*time.Millisecond { + t.Fatal("expected to wait at least timeout duration") + } + + if err := db1.Close(); err != nil { + t.Fatal(err) + } } // Ensure that opening a database does not increase its size. // https://github.com/boltdb/bolt/issues/291 func TestOpen_Size(t *testing.T) { // Open a data file. - db := NewTestDB() + db := MustOpenDB() path := db.Path() - defer db.Close() + defer db.MustClose() + + pagesize := db.Info().PageSize // Insert until we get above the minimum 4MB size. - ok(t, db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b, _ := tx.CreateBucketIfNotExists([]byte("data")) for i := 0; i < 10000; i++ { - ok(t, b.Put([]byte(fmt.Sprintf("%04d", i)), make([]byte, 1000))) + if err := b.Put([]byte(fmt.Sprintf("%04d", i)), make([]byte, 1000)); err != nil { + t.Fatal(err) + } } return nil - })) + }); err != nil { + t.Fatal(err) + } // Close database and grab the size. - db.DB.Close() + if err := db.DB.Close(); err != nil { + t.Fatal(err) + } sz := fileSize(path) if sz == 0 { t.Fatalf("unexpected new file size: %d", sz) @@ -112,16 +255,28 @@ func TestOpen_Size(t *testing.T) { // Reopen database, update, and check size again. db0, err := bolt.Open(path, 0666, nil) - ok(t, err) - ok(t, db0.Update(func(tx *bolt.Tx) error { return tx.Bucket([]byte("data")).Put([]byte{0}, []byte{0}) })) - ok(t, db0.Close()) + if err != nil { + t.Fatal(err) + } + if err := db0.Update(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("data")).Put([]byte{0}, []byte{0}); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + if err := db0.Close(); err != nil { + t.Fatal(err) + } newSz := fileSize(path) if newSz == 0 { t.Fatalf("unexpected new file size: %d", newSz) } // Compare the original size with the new size. - if sz != newSz { + // db size might increase by a few page sizes due to the new small update. + if sz < newSz-5*int64(pagesize) { t.Fatalf("unexpected file growth: %d => %d", sz, newSz) } } @@ -134,25 +289,33 @@ func TestOpen_Size_Large(t *testing.T) { } // Open a data file. - db := NewTestDB() + db := MustOpenDB() path := db.Path() - defer db.Close() + defer db.MustClose() + + pagesize := db.Info().PageSize // Insert until we get above the minimum 4MB size. var index uint64 for i := 0; i < 10000; i++ { - ok(t, db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b, _ := tx.CreateBucketIfNotExists([]byte("data")) for j := 0; j < 1000; j++ { - ok(t, b.Put(u64tob(index), make([]byte, 50))) + if err := b.Put(u64tob(index), make([]byte, 50)); err != nil { + t.Fatal(err) + } index++ } return nil - })) + }); err != nil { + t.Fatal(err) + } } // Close database and grab the size. - db.DB.Close() + if err := db.DB.Close(); err != nil { + t.Fatal(err) + } sz := fileSize(path) if sz == 0 { t.Fatalf("unexpected new file size: %d", sz) @@ -162,16 +325,26 @@ func TestOpen_Size_Large(t *testing.T) { // Reopen database, update, and check size again. db0, err := bolt.Open(path, 0666, nil) - ok(t, err) - ok(t, db0.Update(func(tx *bolt.Tx) error { return tx.Bucket([]byte("data")).Put([]byte{0}, []byte{0}) })) - ok(t, db0.Close()) + if err != nil { + t.Fatal(err) + } + if err := db0.Update(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("data")).Put([]byte{0}, []byte{0}) + }); err != nil { + t.Fatal(err) + } + if err := db0.Close(); err != nil { + t.Fatal(err) + } + newSz := fileSize(path) if newSz == 0 { t.Fatalf("unexpected new file size: %d", newSz) } // Compare the original size with the new size. - if sz != newSz { + // db size might increase by a few page sizes due to the new small update. + if sz < newSz-5*int64(pagesize) { t.Fatalf("unexpected file growth: %d => %d", sz, newSz) } } @@ -179,327 +352,641 @@ func TestOpen_Size_Large(t *testing.T) { // Ensure that a re-opened database is consistent. func TestOpen_Check(t *testing.T) { path := tempfile() - defer os.Remove(path) db, err := bolt.Open(path, 0666, nil) - ok(t, err) - ok(t, db.View(func(tx *bolt.Tx) error { return <-tx.Check() })) - db.Close() + if err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { return <-tx.Check() }); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } db, err = bolt.Open(path, 0666, nil) - ok(t, err) - ok(t, db.View(func(tx *bolt.Tx) error { return <-tx.Check() })) - db.Close() -} - -// Ensure that the database returns an error if the file handle cannot be open. -func TestDB_Open_FileError(t *testing.T) { - path := tempfile() - defer os.Remove(path) - - _, err := bolt.Open(path+"/youre-not-my-real-parent", 0666, nil) - assert(t, err.(*os.PathError) != nil, "") - equals(t, path+"/youre-not-my-real-parent", err.(*os.PathError).Path) - equals(t, "open", err.(*os.PathError).Op) + if err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { return <-tx.Check() }); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } } // Ensure that write errors to the meta file handler during initialization are returned. -func TestDB_Open_MetaInitWriteError(t *testing.T) { +func TestOpen_MetaInitWriteError(t *testing.T) { t.Skip("pending") } // Ensure that a database that is too small returns an error. -func TestDB_Open_FileTooSmall(t *testing.T) { +func TestOpen_FileTooSmall(t *testing.T) { + path := tempfile() + + db, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } + + // corrupt the database + if err := os.Truncate(path, int64(os.Getpagesize())); err != nil { + t.Fatal(err) + } + + db, err = bolt.Open(path, 0666, nil) + if err == nil || err.Error() != "file size too small" { + t.Fatalf("unexpected error: %s", err) + } +} + +// Ensure that a database can be opened in read-only mode by multiple processes +// and that a database can not be opened in read-write mode and in read-only +// mode at the same time. +func TestOpen_ReadOnly(t *testing.T) { + if runtime.GOOS == "solaris" { + t.Skip("solaris fcntl locks don't support intra-process locking") + } + + bucket, key, value := []byte(`bucket`), []byte(`key`), []byte(`value`) + + path := tempfile() + + // Open in read-write mode. + db, err := bolt.Open(path, 0666, nil) + if err != nil { + t.Fatal(err) + } else if db.IsReadOnly() { + t.Fatal("db should not be in read only mode") + } + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket(bucket) + if err != nil { + return err + } + if err := b.Put(key, value); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + if err := db.Close(); err != nil { + t.Fatal(err) + } + + // Open in read-only mode. + db0, err := bolt.Open(path, 0666, &bolt.Options{ReadOnly: true}) + if err != nil { + t.Fatal(err) + } + + // Opening in read-write mode should return an error. + if _, err = bolt.Open(path, 0666, &bolt.Options{Timeout: time.Millisecond * 100}); err == nil { + t.Fatal("expected error") + } + + // And again (in read-only mode). + db1, err := bolt.Open(path, 0666, &bolt.Options{ReadOnly: true}) + if err != nil { + t.Fatal(err) + } + + // Verify both read-only databases are accessible. + for _, db := range []*bolt.DB{db0, db1} { + // Verify is is in read only mode indeed. + if !db.IsReadOnly() { + t.Fatal("expected read only mode") + } + + // Read-only databases should not allow updates. + if err := db.Update(func(*bolt.Tx) error { + panic(`should never get here`) + }); err != bolt.ErrDatabaseReadOnly { + t.Fatalf("unexpected error: %s", err) + } + + // Read-only databases should not allow beginning writable txns. + if _, err := db.Begin(true); err != bolt.ErrDatabaseReadOnly { + t.Fatalf("unexpected error: %s", err) + } + + // Verify the data. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket(bucket) + if b == nil { + return fmt.Errorf("expected bucket `%s`", string(bucket)) + } + + got := string(b.Get(key)) + expected := string(value) + if got != expected { + return fmt.Errorf("expected `%s`, got `%s`", expected, got) + } + return nil + }); err != nil { + t.Fatal(err) + } + } + + if err := db0.Close(); err != nil { + t.Fatal(err) + } + if err := db1.Close(); err != nil { + t.Fatal(err) + } +} + +// TestDB_Open_InitialMmapSize tests if having InitialMmapSize large enough +// to hold data from concurrent write transaction resolves the issue that +// read transaction blocks the write transaction and causes deadlock. +// This is a very hacky test since the mmap size is not exposed. +func TestDB_Open_InitialMmapSize(t *testing.T) { path := tempfile() defer os.Remove(path) - db, err := bolt.Open(path, 0666, nil) - ok(t, err) - db.Close() + initMmapSize := 1 << 31 // 2GB + testWriteSize := 1 << 27 // 134MB - // corrupt the database - ok(t, os.Truncate(path, int64(os.Getpagesize()))) + db, err := bolt.Open(path, 0666, &bolt.Options{InitialMmapSize: initMmapSize}) + if err != nil { + t.Fatal(err) + } - db, err = bolt.Open(path, 0666, nil) - equals(t, errors.New("file size too small"), err) + // create a long-running read transaction + // that never gets closed while writing + rtx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + + // create a write transaction + wtx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + b, err := wtx.CreateBucket([]byte("test")) + if err != nil { + t.Fatal(err) + } + + // and commit a large write + err = b.Put([]byte("foo"), make([]byte, testWriteSize)) + if err != nil { + t.Fatal(err) + } + + done := make(chan struct{}) + + go func() { + if err := wtx.Commit(); err != nil { + t.Fatal(err) + } + done <- struct{}{} + }() + + select { + case <-time.After(5 * time.Second): + t.Errorf("unexpected that the reader blocks writer") + case <-done: + } + + if err := rtx.Rollback(); err != nil { + t.Fatal(err) + } } -// TODO(benbjohnson): Test corruption at every byte of the first two pages. - // Ensure that a database cannot open a transaction when it's not open. -func TestDB_Begin_DatabaseNotOpen(t *testing.T) { +func TestDB_Begin_ErrDatabaseNotOpen(t *testing.T) { var db bolt.DB - tx, err := db.Begin(false) - assert(t, tx == nil, "") - equals(t, err, bolt.ErrDatabaseNotOpen) + if _, err := db.Begin(false); err != bolt.ErrDatabaseNotOpen { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that a read-write transaction can be retrieved. func TestDB_BeginRW(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) - assert(t, tx != nil, "") - ok(t, err) - assert(t, tx.DB() == db.DB, "") - equals(t, tx.Writable(), true) - ok(t, tx.Commit()) + if err != nil { + t.Fatal(err) + } else if tx == nil { + t.Fatal("expected tx") + } + + if tx.DB() != db.DB { + t.Fatal("unexpected tx database") + } else if !tx.Writable() { + t.Fatal("expected writable tx") + } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } } // Ensure that opening a transaction while the DB is closed returns an error. func TestDB_BeginRW_Closed(t *testing.T) { var db bolt.DB + if _, err := db.Begin(true); err != bolt.ErrDatabaseNotOpen { + t.Fatalf("unexpected error: %s", err) + } +} + +func TestDB_Close_PendingTx_RW(t *testing.T) { testDB_Close_PendingTx(t, true) } +func TestDB_Close_PendingTx_RO(t *testing.T) { testDB_Close_PendingTx(t, false) } + +// Ensure that a database cannot close while transactions are open. +func testDB_Close_PendingTx(t *testing.T, writable bool) { + db := MustOpenDB() + defer db.MustClose() + + // Start transaction. tx, err := db.Begin(true) - equals(t, err, bolt.ErrDatabaseNotOpen) - assert(t, tx == nil, "") + if err != nil { + t.Fatal(err) + } + + // Open update in separate goroutine. + done := make(chan struct{}) + go func() { + if err := db.Close(); err != nil { + t.Fatal(err) + } + close(done) + }() + + // Ensure database hasn't closed. + time.Sleep(100 * time.Millisecond) + select { + case <-done: + t.Fatal("database closed too early") + default: + } + + // Commit transaction. + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + // Ensure database closed now. + time.Sleep(100 * time.Millisecond) + select { + case <-done: + default: + t.Fatal("database did not close") + } } // Ensure a database can provide a transactional block. func TestDB_Update(t *testing.T) { - db := NewTestDB() - defer db.Close() - err := db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } + if err := b.Delete([]byte("foo")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + if err := db.View(func(tx *bolt.Tx) error { b := tx.Bucket([]byte("widgets")) - b.Put([]byte("foo"), []byte("bar")) - b.Put([]byte("baz"), []byte("bat")) - b.Delete([]byte("foo")) + if v := b.Get([]byte("foo")); v != nil { + t.Fatalf("expected nil value, got: %v", v) + } + if v := b.Get([]byte("baz")); !bytes.Equal(v, []byte("bat")) { + t.Fatalf("unexpected value: %v", v) + } return nil - }) - ok(t, err) - err = db.View(func(tx *bolt.Tx) error { - assert(t, tx.Bucket([]byte("widgets")).Get([]byte("foo")) == nil, "") - equals(t, []byte("bat"), tx.Bucket([]byte("widgets")).Get([]byte("baz"))) - return nil - }) - ok(t, err) + }); err != nil { + t.Fatal(err) + } } // Ensure a closed database returns an error while running a transaction block func TestDB_Update_Closed(t *testing.T) { var db bolt.DB - err := db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return nil - }) - equals(t, err, bolt.ErrDatabaseNotOpen) + }); err != bolt.ErrDatabaseNotOpen { + t.Fatalf("unexpected error: %s", err) + } } // Ensure a panic occurs while trying to commit a managed transaction. func TestDB_Update_ManualCommit(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - var ok bool - db.Update(func(tx *bolt.Tx) error { + var panicked bool + if err := db.Update(func(tx *bolt.Tx) error { func() { defer func() { if r := recover(); r != nil { - ok = true + panicked = true } }() - tx.Commit() + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } }() return nil - }) - assert(t, ok, "expected panic") + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } } // Ensure a panic occurs while trying to rollback a managed transaction. func TestDB_Update_ManualRollback(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - var ok bool - db.Update(func(tx *bolt.Tx) error { + var panicked bool + if err := db.Update(func(tx *bolt.Tx) error { func() { defer func() { if r := recover(); r != nil { - ok = true + panicked = true } }() - tx.Rollback() + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } }() return nil - }) - assert(t, ok, "expected panic") + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } } // Ensure a panic occurs while trying to commit a managed transaction. func TestDB_View_ManualCommit(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - var ok bool - db.Update(func(tx *bolt.Tx) error { + var panicked bool + if err := db.View(func(tx *bolt.Tx) error { func() { defer func() { if r := recover(); r != nil { - ok = true + panicked = true } }() - tx.Commit() + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } }() return nil - }) - assert(t, ok, "expected panic") + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } } // Ensure a panic occurs while trying to rollback a managed transaction. func TestDB_View_ManualRollback(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() - var ok bool - db.Update(func(tx *bolt.Tx) error { + var panicked bool + if err := db.View(func(tx *bolt.Tx) error { func() { defer func() { if r := recover(); r != nil { - ok = true + panicked = true } }() - tx.Rollback() + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } }() return nil - }) - assert(t, ok, "expected panic") + }); err != nil { + t.Fatal(err) + } else if !panicked { + t.Fatal("expected panic") + } } // Ensure a write transaction that panics does not hold open locks. func TestDB_Update_Panic(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() + // Panic during update but recover. func() { defer func() { if r := recover(); r != nil { t.Log("recover: update", r) } }() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } panic("omg") - }) + }); err != nil { + t.Fatal(err) + } }() // Verify we can update again. - err := db.Update(func(tx *bolt.Tx) error { - _, err := tx.CreateBucket([]byte("widgets")) - return err - }) - ok(t, err) + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } // Verify that our change persisted. - err = db.Update(func(tx *bolt.Tx) error { - assert(t, tx.Bucket([]byte("widgets")) != nil, "") + if err := db.Update(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure a database can return an error through a read-only transactional block. func TestDB_View_Error(t *testing.T) { - db := NewTestDB() - defer db.Close() - err := db.View(func(tx *bolt.Tx) error { + db := MustOpenDB() + defer db.MustClose() + + if err := db.View(func(tx *bolt.Tx) error { return errors.New("xxx") - }) - equals(t, errors.New("xxx"), err) + }); err == nil || err.Error() != "xxx" { + t.Fatalf("unexpected error: %s", err) + } } // Ensure a read transaction that panics does not hold open locks. func TestDB_View_Panic(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - return nil - }) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Panic during view transaction but recover. func() { defer func() { if r := recover(); r != nil { t.Log("recover: view", r) } }() - db.View(func(tx *bolt.Tx) error { - assert(t, tx.Bucket([]byte("widgets")) != nil, "") + + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } panic("omg") - }) + }); err != nil { + t.Fatal(err) + } }() // Verify that we can still use read transactions. - db.View(func(tx *bolt.Tx) error { - assert(t, tx.Bucket([]byte("widgets")) != nil, "") + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } return nil - }) -} - -// Ensure that an error is returned when a database write fails. -func TestDB_Commit_WriteFail(t *testing.T) { - t.Skip("pending") // TODO(benbjohnson) + }); err != nil { + t.Fatal(err) + } } // Ensure that DB stats can be returned. func TestDB_Stats(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucket([]byte("widgets")) return err - }) + }); err != nil { + t.Fatal(err) + } + stats := db.Stats() - equals(t, 2, stats.TxStats.PageCount) - equals(t, 0, stats.FreePageN) - equals(t, 2, stats.PendingPageN) + if stats.TxStats.PageCount != 2 { + t.Fatalf("unexpected TxStats.PageCount: %d", stats.TxStats.PageCount) + } else if stats.FreePageN != 0 { + t.Fatalf("unexpected FreePageN != 0: %d", stats.FreePageN) + } else if stats.PendingPageN != 2 { + t.Fatalf("unexpected PendingPageN != 2: %d", stats.PendingPageN) + } } // Ensure that database pages are in expected order and type. func TestDB_Consistency(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucket([]byte("widgets")) return err - }) + }); err != nil { + t.Fatal(err) + } for i := 0; i < 10; i++ { - db.Update(func(tx *bolt.Tx) error { - ok(t, tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar"))) + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } - db.Update(func(tx *bolt.Tx) error { - p, _ := tx.Page(0) - assert(t, p != nil, "") - equals(t, "meta", p.Type) - p, _ = tx.Page(1) - assert(t, p != nil, "") - equals(t, "meta", p.Type) + if err := db.Update(func(tx *bolt.Tx) error { + if p, _ := tx.Page(0); p == nil { + t.Fatal("expected page") + } else if p.Type != "meta" { + t.Fatalf("unexpected page type: %s", p.Type) + } - p, _ = tx.Page(2) - assert(t, p != nil, "") - equals(t, "free", p.Type) + if p, _ := tx.Page(1); p == nil { + t.Fatal("expected page") + } else if p.Type != "meta" { + t.Fatalf("unexpected page type: %s", p.Type) + } - p, _ = tx.Page(3) - assert(t, p != nil, "") - equals(t, "free", p.Type) + if p, _ := tx.Page(2); p == nil { + t.Fatal("expected page") + } else if p.Type != "free" { + t.Fatalf("unexpected page type: %s", p.Type) + } - p, _ = tx.Page(4) - assert(t, p != nil, "") - equals(t, "leaf", p.Type) + if p, _ := tx.Page(3); p == nil { + t.Fatal("expected page") + } else if p.Type != "free" { + t.Fatalf("unexpected page type: %s", p.Type) + } - p, _ = tx.Page(5) - assert(t, p != nil, "") - equals(t, "freelist", p.Type) + if p, _ := tx.Page(4); p == nil { + t.Fatal("expected page") + } else if p.Type != "leaf" { + t.Fatalf("unexpected page type: %s", p.Type) + } - p, _ = tx.Page(6) - assert(t, p == nil, "") + if p, _ := tx.Page(5); p == nil { + t.Fatal("expected page") + } else if p.Type != "freelist" { + t.Fatalf("unexpected page type: %s", p.Type) + } + + if p, _ := tx.Page(6); p != nil { + t.Fatal("unexpected page") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } -// Ensure that DB stats can be substracted from one another. +// Ensure that DB stats can be subtracted from one another. func TestDBStats_Sub(t *testing.T) { var a, b bolt.Stats a.TxStats.PageCount = 3 @@ -507,19 +994,209 @@ func TestDBStats_Sub(t *testing.T) { b.TxStats.PageCount = 10 b.FreePageN = 14 diff := b.Sub(&a) - equals(t, 7, diff.TxStats.PageCount) + if diff.TxStats.PageCount != 7 { + t.Fatalf("unexpected TxStats.PageCount: %d", diff.TxStats.PageCount) + } + // free page stats are copied from the receiver and not subtracted - equals(t, 14, diff.FreePageN) + if diff.FreePageN != 14 { + t.Fatalf("unexpected FreePageN: %d", diff.FreePageN) + } +} + +// Ensure two functions can perform updates in a single batch. +func TestDB_Batch(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } + + // Iterate over multiple updates in separate goroutines. + n := 2 + ch := make(chan error) + for i := 0; i < n; i++ { + go func(i int) { + ch <- db.Batch(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Put(u64tob(uint64(i)), []byte{}) + }) + }(i) + } + + // Check all responses to make sure there's no error. + for i := 0; i < n; i++ { + if err := <-ch; err != nil { + t.Fatal(err) + } + } + + // Ensure data is correct. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 0; i < n; i++ { + if v := b.Get(u64tob(uint64(i))); v == nil { + t.Errorf("key not found: %d", i) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +func TestDB_Batch_Panic(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + var sentinel int + var bork = &sentinel + var problem interface{} + var err error + + // Execute a function inside a batch that panics. + func() { + defer func() { + if p := recover(); p != nil { + problem = p + } + }() + err = db.Batch(func(tx *bolt.Tx) error { + panic(bork) + }) + }() + + // Verify there is no error. + if g, e := err, error(nil); g != e { + t.Fatalf("wrong error: %v != %v", g, e) + } + // Verify the panic was captured. + if g, e := problem, bork; g != e { + t.Fatalf("wrong error: %v != %v", g, e) + } +} + +func TestDB_BatchFull(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + + const size = 3 + // buffered so we never leak goroutines + ch := make(chan error, size) + put := func(i int) { + ch <- db.Batch(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Put(u64tob(uint64(i)), []byte{}) + }) + } + + db.MaxBatchSize = size + // high enough to never trigger here + db.MaxBatchDelay = 1 * time.Hour + + go put(1) + go put(2) + + // Give the batch a chance to exhibit bugs. + time.Sleep(10 * time.Millisecond) + + // not triggered yet + select { + case <-ch: + t.Fatalf("batch triggered too early") + default: + } + + go put(3) + + // Check all responses to make sure there's no error. + for i := 0; i < size; i++ { + if err := <-ch; err != nil { + t.Fatal(err) + } + } + + // Ensure data is correct. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 1; i <= size; i++ { + if v := b.Get(u64tob(uint64(i))); v == nil { + t.Errorf("key not found: %d", i) + } + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +func TestDB_BatchTime(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("widgets")) + return err + }); err != nil { + t.Fatal(err) + } + + const size = 1 + // buffered so we never leak goroutines + ch := make(chan error, size) + put := func(i int) { + ch <- db.Batch(func(tx *bolt.Tx) error { + return tx.Bucket([]byte("widgets")).Put(u64tob(uint64(i)), []byte{}) + }) + } + + db.MaxBatchSize = 1000 + db.MaxBatchDelay = 0 + + go put(1) + + // Batch must trigger by time alone. + + // Check all responses to make sure there's no error. + for i := 0; i < size; i++ { + if err := <-ch; err != nil { + t.Fatal(err) + } + } + + // Ensure data is correct. + if err := db.View(func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("widgets")) + for i := 1; i <= size; i++ { + if v := b.Get(u64tob(uint64(i))); v == nil { + t.Errorf("key not found: %d", i) + } + } + return nil + }); err != nil { + t.Fatal(err) + } } func ExampleDB_Update() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() - // Execute several commands within a write transaction. - err := db.Update(func(tx *bolt.Tx) error { + // Execute several commands within a read-write transaction. + if err := db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucket([]byte("widgets")) if err != nil { return err @@ -528,15 +1205,22 @@ func ExampleDB_Update() { return err } return nil - }) + }); err != nil { + log.Fatal(err) + } - // If our transactional block didn't return an error then our data is saved. - if err == nil { - db.View(func(tx *bolt.Tx) error { - value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) - fmt.Printf("The value of 'foo' is: %s\n", value) - return nil - }) + // Read the value back from a separate read-only transaction. + if err := db.View(func(tx *bolt.Tx) error { + value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) + fmt.Printf("The value of 'foo' is: %s\n", value) + return nil + }); err != nil { + log.Fatal(err) + } + + // Close database to release the file lock. + if err := db.Close(); err != nil { + log.Fatal(err) } // Output: @@ -545,25 +1229,42 @@ func ExampleDB_Update() { func ExampleDB_View() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Insert data into a bucket. - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("people")) - b := tx.Bucket([]byte("people")) - b.Put([]byte("john"), []byte("doe")) - b.Put([]byte("susy"), []byte("que")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("people")) + if err != nil { + return err + } + if err := b.Put([]byte("john"), []byte("doe")); err != nil { + return err + } + if err := b.Put([]byte("susy"), []byte("que")); err != nil { + return err + } return nil - }) + }); err != nil { + log.Fatal(err) + } // Access data from within a read-only transactional block. - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { v := tx.Bucket([]byte("people")).Get([]byte("john")) fmt.Printf("John's last name is %s.\n", v) return nil - }) + }); err != nil { + log.Fatal(err) + } + + // Close database to release the file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // John's last name is doe. @@ -571,31 +1272,56 @@ func ExampleDB_View() { func ExampleDB_Begin_ReadOnly() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() - // Create a bucket. - db.Update(func(tx *bolt.Tx) error { + // Create a bucket using a read-write transaction. + if err := db.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucket([]byte("widgets")) return err - }) + }); err != nil { + log.Fatal(err) + } // Create several keys in a transaction. - tx, _ := db.Begin(true) + tx, err := db.Begin(true) + if err != nil { + log.Fatal(err) + } b := tx.Bucket([]byte("widgets")) - b.Put([]byte("john"), []byte("blue")) - b.Put([]byte("abby"), []byte("red")) - b.Put([]byte("zephyr"), []byte("purple")) - tx.Commit() + if err := b.Put([]byte("john"), []byte("blue")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("abby"), []byte("red")); err != nil { + log.Fatal(err) + } + if err := b.Put([]byte("zephyr"), []byte("purple")); err != nil { + log.Fatal(err) + } + if err := tx.Commit(); err != nil { + log.Fatal(err) + } // Iterate over the values in sorted key order. - tx, _ = db.Begin(false) + tx, err = db.Begin(false) + if err != nil { + log.Fatal(err) + } c := tx.Bucket([]byte("widgets")).Cursor() for k, v := c.First(); k != nil; k, v = c.Next() { fmt.Printf("%s likes %s\n", k, v) } - tx.Rollback() + + if err := tx.Rollback(); err != nil { + log.Fatal(err) + } + + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // abby likes red @@ -603,23 +1329,195 @@ func ExampleDB_Begin_ReadOnly() { // zephyr likes purple } -// TestDB represents a wrapper around a Bolt DB to handle temporary file -// creation and automatic cleanup on close. -type TestDB struct { +func BenchmarkDBBatchAutomatic(b *testing.B) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("bench")) + return err + }); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + start := make(chan struct{}) + var wg sync.WaitGroup + + for round := 0; round < 1000; round++ { + wg.Add(1) + + go func(id uint32) { + defer wg.Done() + <-start + + h := fnv.New32a() + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, id) + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + insert := func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("bench")) + return b.Put(k, []byte("filler")) + } + if err := db.Batch(insert); err != nil { + b.Error(err) + return + } + }(uint32(round)) + } + close(start) + wg.Wait() + } + + b.StopTimer() + validateBatchBench(b, db) +} + +func BenchmarkDBBatchSingle(b *testing.B) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("bench")) + return err + }); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + start := make(chan struct{}) + var wg sync.WaitGroup + + for round := 0; round < 1000; round++ { + wg.Add(1) + go func(id uint32) { + defer wg.Done() + <-start + + h := fnv.New32a() + buf := make([]byte, 4) + binary.LittleEndian.PutUint32(buf, id) + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + insert := func(tx *bolt.Tx) error { + b := tx.Bucket([]byte("bench")) + return b.Put(k, []byte("filler")) + } + if err := db.Update(insert); err != nil { + b.Error(err) + return + } + }(uint32(round)) + } + close(start) + wg.Wait() + } + + b.StopTimer() + validateBatchBench(b, db) +} + +func BenchmarkDBBatchManual10x100(b *testing.B) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("bench")) + return err + }); err != nil { + b.Fatal(err) + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + start := make(chan struct{}) + var wg sync.WaitGroup + + for major := 0; major < 10; major++ { + wg.Add(1) + go func(id uint32) { + defer wg.Done() + <-start + + insert100 := func(tx *bolt.Tx) error { + h := fnv.New32a() + buf := make([]byte, 4) + for minor := uint32(0); minor < 100; minor++ { + binary.LittleEndian.PutUint32(buf, uint32(id*100+minor)) + h.Reset() + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + b := tx.Bucket([]byte("bench")) + if err := b.Put(k, []byte("filler")); err != nil { + return err + } + } + return nil + } + if err := db.Update(insert100); err != nil { + b.Fatal(err) + } + }(uint32(major)) + } + close(start) + wg.Wait() + } + + b.StopTimer() + validateBatchBench(b, db) +} + +func validateBatchBench(b *testing.B, db *DB) { + var rollback = errors.New("sentinel error to cause rollback") + validate := func(tx *bolt.Tx) error { + bucket := tx.Bucket([]byte("bench")) + h := fnv.New32a() + buf := make([]byte, 4) + for id := uint32(0); id < 1000; id++ { + binary.LittleEndian.PutUint32(buf, id) + h.Reset() + _, _ = h.Write(buf[:]) + k := h.Sum(nil) + v := bucket.Get(k) + if v == nil { + b.Errorf("not found id=%d key=%x", id, k) + continue + } + if g, e := v, []byte("filler"); !bytes.Equal(g, e) { + b.Errorf("bad value for id=%d key=%x: %s != %q", id, k, g, e) + } + if err := bucket.Delete(k); err != nil { + return err + } + } + // should be empty now + c := bucket.Cursor() + for k, v := c.First(); k != nil; k, v = c.Next() { + b.Errorf("unexpected key: %x = %q", k, v) + } + return rollback + } + if err := db.Update(validate); err != nil && err != rollback { + b.Error(err) + } +} + +// DB is a test wrapper for bolt.DB. +type DB struct { *bolt.DB } -// NewTestDB returns a new instance of TestDB. -func NewTestDB() *TestDB { +// MustOpenDB returns a new, open DB at a temporary location. +func MustOpenDB() *DB { db, err := bolt.Open(tempfile(), 0666, nil) if err != nil { - panic("cannot open db: " + err.Error()) + panic(err) } - return &TestDB{db} + return &DB{db} } // Close closes the database and deletes the underlying file. -func (db *TestDB) Close() { +func (db *DB) Close() error { // Log statistics. if *statsFlag { db.PrintStats() @@ -630,11 +1528,18 @@ func (db *TestDB) Close() { // Close database and remove file. defer os.Remove(db.Path()) - db.DB.Close() + return db.DB.Close() +} + +// MustClose closes the database and deletes the underlying file. Panic on error. +func (db *DB) MustClose() { + if err := db.Close(); err != nil { + panic(err) + } } // PrintStats prints the database stats -func (db *TestDB) PrintStats() { +func (db *DB) PrintStats() { var stats = db.Stats() fmt.Printf("[db] %-20s %-20s %-20s\n", fmt.Sprintf("pg(%d/%d)", stats.TxStats.PageCount, stats.TxStats.PageAlloc), @@ -649,8 +1554,8 @@ func (db *TestDB) PrintStats() { } // MustCheck runs a consistency check on the database and panics if any errors are found. -func (db *TestDB) MustCheck() { - db.View(func(tx *bolt.Tx) error { +func (db *DB) MustCheck() { + if err := db.Update(func(tx *bolt.Tx) error { // Collect all the errors. var errors []error for err := range tx.Check() { @@ -663,7 +1568,9 @@ func (db *TestDB) MustCheck() { // If errors occurred, copy the DB and print the errors. if len(errors) > 0 { var path = tempfile() - tx.CopyFile(path, 0600) + if err := tx.CopyFile(path, 0600); err != nil { + panic(err) + } // Print errors. fmt.Print("\n\n") @@ -679,31 +1586,46 @@ func (db *TestDB) MustCheck() { } return nil - }) + }); err != nil && err != bolt.ErrDatabaseNotOpen { + panic(err) + } } // CopyTempFile copies a database to a temporary file. -func (db *TestDB) CopyTempFile() { +func (db *DB) CopyTempFile() { path := tempfile() - db.View(func(tx *bolt.Tx) error { return tx.CopyFile(path, 0600) }) + if err := db.View(func(tx *bolt.Tx) error { + return tx.CopyFile(path, 0600) + }); err != nil { + panic(err) + } fmt.Println("db copied to: ", path) } // tempfile returns a temporary file path. func tempfile() string { - f, _ := ioutil.TempFile("", "bolt-") - f.Close() - os.Remove(f.Name()) + f, err := ioutil.TempFile("", "bolt-") + if err != nil { + panic(err) + } + if err := f.Close(); err != nil { + panic(err) + } + if err := os.Remove(f.Name()); err != nil { + panic(err) + } return f.Name() } // mustContainKeys checks that a bucket contains a given set of keys. func mustContainKeys(b *bolt.Bucket, m map[string]string) { found := make(map[string]string) - b.ForEach(func(k, _ []byte) error { + if err := b.ForEach(func(k, _ []byte) error { found[string(k)] = "" return nil - }) + }); err != nil { + panic(err) + } // Check for keys found in bucket that shouldn't be there. var keys []string diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/errors.go b/Godeps/_workspace/src/github.com/boltdb/bolt/errors.go index aa504f13..6883786d 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/errors.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/errors.go @@ -36,6 +36,10 @@ var ( // ErrTxClosed is returned when committing or rolling back a transaction // that has already been committed or rolled back. ErrTxClosed = errors.New("tx closed") + + // ErrDatabaseReadOnly is returned when a mutating transaction is started on a + // read-only database. + ErrDatabaseReadOnly = errors.New("database is in read-only mode") ) // These errors can occur when putting or deleting a value or a bucket. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/freelist.go b/Godeps/_workspace/src/github.com/boltdb/bolt/freelist.go index 1346e82e..0161948f 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/freelist.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/freelist.go @@ -48,15 +48,14 @@ func (f *freelist) pending_count() int { // all returns a list of all free ids and all pending ids in one sorted list. func (f *freelist) all() []pgid { - ids := make([]pgid, len(f.ids)) - copy(ids, f.ids) + m := make(pgids, 0) for _, list := range f.pending { - ids = append(ids, list...) + m = append(m, list...) } - sort.Sort(pgids(ids)) - return ids + sort.Sort(m) + return pgids(f.ids).merge(m) } // allocate returns the starting page id of a contiguous list of pages of a given size. @@ -127,15 +126,17 @@ func (f *freelist) free(txid txid, p *page) { // release moves all page ids for a transaction id (or older) to the freelist. func (f *freelist) release(txid txid) { + m := make(pgids, 0) for tid, ids := range f.pending { if tid <= txid { // Move transaction's pending pages to the available freelist. // Don't remove from the cache since the page is still free. - f.ids = append(f.ids, ids...) + m = append(m, ids...) delete(f.pending, tid) } } - sort.Sort(pgids(f.ids)) + sort.Sort(m) + f.ids = pgids(f.ids).merge(m) } // rollback removes the pages from a given pending tx. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/freelist_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/freelist_test.go index 792ca922..4e9b3a8d 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/freelist_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/freelist_test.go @@ -1,7 +1,9 @@ package bolt import ( + "math/rand" "reflect" + "sort" "testing" "unsafe" ) @@ -115,7 +117,9 @@ func TestFreelist_write(t *testing.T) { f.pending[100] = []pgid{28, 11} f.pending[101] = []pgid{3} p := (*page)(unsafe.Pointer(&buf[0])) - f.write(p) + if err := f.write(p); err != nil { + t.Fatal(err) + } // Read the page back out. f2 := newFreelist() @@ -127,3 +131,28 @@ func TestFreelist_write(t *testing.T) { t.Fatalf("exp=%v; got=%v", exp, f2.ids) } } + +func Benchmark_FreelistRelease10K(b *testing.B) { benchmark_FreelistRelease(b, 10000) } +func Benchmark_FreelistRelease100K(b *testing.B) { benchmark_FreelistRelease(b, 100000) } +func Benchmark_FreelistRelease1000K(b *testing.B) { benchmark_FreelistRelease(b, 1000000) } +func Benchmark_FreelistRelease10000K(b *testing.B) { benchmark_FreelistRelease(b, 10000000) } + +func benchmark_FreelistRelease(b *testing.B, size int) { + ids := randomPgids(size) + pending := randomPgids(len(ids) / 400) + b.ResetTimer() + for i := 0; i < b.N; i++ { + f := &freelist{ids: ids, pending: map[txid][]pgid{1: pending}} + f.release(1) + } +} + +func randomPgids(n int) []pgid { + rand.Seed(42) + pgids := make(pgids, n) + for i := range pgids { + pgids[i] = pgid(rand.Int63()) + } + sort.Sort(pgids) + return pgids +} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/node.go b/Godeps/_workspace/src/github.com/boltdb/bolt/node.go index 05aefb8a..c9fb21c7 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/node.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/node.go @@ -221,11 +221,20 @@ func (n *node) write(p *page) { _assert(elem.pgid != p.id, "write: circular dependency occurred") } + // If the length of key+value is larger than the max allocation size + // then we need to reallocate the byte array pointer. + // + // See: https://github.com/boltdb/bolt/pull/335 + klen, vlen := len(item.key), len(item.value) + if len(b) < klen+vlen { + b = (*[maxAllocSize]byte)(unsafe.Pointer(&b[0]))[:] + } + // Write data for the element to the end of the page. copy(b[0:], item.key) - b = b[len(item.key):] + b = b[klen:] copy(b[0:], item.value) - b = b[len(item.value):] + b = b[vlen:] } // DEBUG ONLY: n.dump() diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/page.go b/Godeps/_workspace/src/github.com/boltdb/bolt/page.go index 58e43c4b..818aa1b1 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/page.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/page.go @@ -3,6 +3,7 @@ package bolt import ( "fmt" "os" + "sort" "unsafe" ) @@ -96,7 +97,7 @@ type branchPageElement struct { // key returns a byte slice of the node key. func (n *branchPageElement) key() []byte { buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) - return buf[n.pos : n.pos+n.ksize] + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize] } // leafPageElement represents a node on a leaf page. @@ -110,13 +111,13 @@ type leafPageElement struct { // key returns a byte slice of the node key. func (n *leafPageElement) key() []byte { buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) - return buf[n.pos : n.pos+n.ksize] + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos]))[:n.ksize] } // value returns a byte slice of the node value. func (n *leafPageElement) value() []byte { buf := (*[maxAllocSize]byte)(unsafe.Pointer(n)) - return buf[n.pos+n.ksize : n.pos+n.ksize+n.vsize] + return (*[maxAllocSize]byte)(unsafe.Pointer(&buf[n.pos+n.ksize]))[:n.vsize] } // PageInfo represents human readable information about a page. @@ -132,3 +133,40 @@ type pgids []pgid func (s pgids) Len() int { return len(s) } func (s pgids) Swap(i, j int) { s[i], s[j] = s[j], s[i] } func (s pgids) Less(i, j int) bool { return s[i] < s[j] } + +// merge returns the sorted union of a and b. +func (a pgids) merge(b pgids) pgids { + // Return the opposite slice if one is nil. + if len(a) == 0 { + return b + } else if len(b) == 0 { + return a + } + + // Create a list to hold all elements from both lists. + merged := make(pgids, 0, len(a)+len(b)) + + // Assign lead to the slice with a lower starting value, follow to the higher value. + lead, follow := a, b + if b[0] < a[0] { + lead, follow = b, a + } + + // Continue while there are elements in the lead. + for len(lead) > 0 { + // Merge largest prefix of lead that is ahead of follow[0]. + n := sort.Search(len(lead), func(i int) bool { return lead[i] > follow[0] }) + merged = append(merged, lead[:n]...) + if n >= len(lead) { + break + } + + // Swap lead and follow. + lead, follow = follow, lead[n:] + } + + // Append what's left in follow. + merged = append(merged, follow...) + + return merged +} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/page_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/page_test.go index 7a4d327f..59f4a30e 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/page_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/page_test.go @@ -1,7 +1,10 @@ package bolt import ( + "reflect" + "sort" "testing" + "testing/quick" ) // Ensure that the page type can be returned in human readable format. @@ -27,3 +30,43 @@ func TestPage_typ(t *testing.T) { func TestPage_dump(t *testing.T) { (&page{id: 256}).hexdump(16) } + +func TestPgids_merge(t *testing.T) { + a := pgids{4, 5, 6, 10, 11, 12, 13, 27} + b := pgids{1, 3, 8, 9, 25, 30} + c := a.merge(b) + if !reflect.DeepEqual(c, pgids{1, 3, 4, 5, 6, 8, 9, 10, 11, 12, 13, 25, 27, 30}) { + t.Errorf("mismatch: %v", c) + } + + a = pgids{4, 5, 6, 10, 11, 12, 13, 27, 35, 36} + b = pgids{8, 9, 25, 30} + c = a.merge(b) + if !reflect.DeepEqual(c, pgids{4, 5, 6, 8, 9, 10, 11, 12, 13, 25, 27, 30, 35, 36}) { + t.Errorf("mismatch: %v", c) + } +} + +func TestPgids_merge_quick(t *testing.T) { + if err := quick.Check(func(a, b pgids) bool { + // Sort incoming lists. + sort.Sort(a) + sort.Sort(b) + + // Merge the two lists together. + got := a.merge(b) + + // The expected value should be the two lists combined and sorted. + exp := append(a, b...) + sort.Sort(exp) + + if !reflect.DeepEqual(exp, got) { + t.Errorf("\nexp=%+v\ngot=%+v\n", exp, got) + return false + } + + return true + }, nil); err != nil { + t.Fatal(err) + } +} diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/simulation_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/simulation_test.go index 7d0e917d..ba9ea631 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/simulation_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/simulation_test.go @@ -10,7 +10,7 @@ import ( "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/boltdb/bolt" ) -func TestSimulate_1op_1p(t *testing.T) { testSimulate(t, 100, 1) } +func TestSimulate_1op_1p(t *testing.T) { testSimulate(t, 1, 1) } func TestSimulate_10op_1p(t *testing.T) { testSimulate(t, 10, 1) } func TestSimulate_100op_1p(t *testing.T) { testSimulate(t, 100, 1) } func TestSimulate_1000op_1p(t *testing.T) { testSimulate(t, 1000, 1) } @@ -42,8 +42,8 @@ func testSimulate(t *testing.T, threadCount, parallelism int) { var versions = make(map[int]*QuickDB) versions[1] = NewQuickDB() - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() var mutex sync.Mutex @@ -89,10 +89,12 @@ func testSimulate(t *testing.T, threadCount, parallelism int) { versions[tx.ID()] = qdb mutex.Unlock() - ok(t, tx.Commit()) + if err := tx.Commit(); err != nil { + t.Fatal(err) + } }() } else { - defer tx.Rollback() + defer func() { _ = tx.Rollback() }() } // Ignore operation if we don't have data yet. diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/tx.go b/Godeps/_workspace/src/github.com/boltdb/bolt/tx.go index c041d738..e74d2cae 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/tx.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/tx.go @@ -29,6 +29,14 @@ type Tx struct { pages map[pgid]*page stats TxStats commitHandlers []func() + + // WriteFlag specifies the flag for write-related methods like WriteTo(). + // Tx opens the database file with the specified flag to copy the data. + // + // By default, the flag is unset, which works well for mostly in-memory + // workloads. For databases that are much larger than available RAM, + // set the flag to syscall.O_DIRECT to avoid trashing the page cache. + WriteFlag int } // init initializes the transaction. @@ -87,18 +95,21 @@ func (tx *Tx) Stats() TxStats { // Bucket retrieves a bucket by name. // Returns nil if the bucket does not exist. +// The bucket instance is only valid for the lifetime of the transaction. func (tx *Tx) Bucket(name []byte) *Bucket { return tx.root.Bucket(name) } // CreateBucket creates a new bucket. // Returns an error if the bucket already exists, if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. func (tx *Tx) CreateBucket(name []byte) (*Bucket, error) { return tx.root.CreateBucket(name) } // CreateBucketIfNotExists creates a new bucket if it doesn't already exist. // Returns an error if the bucket name is blank, or if the bucket name is too long. +// The bucket instance is only valid for the lifetime of the transaction. func (tx *Tx) CreateBucketIfNotExists(name []byte) (*Bucket, error) { return tx.root.CreateBucketIfNotExists(name) } @@ -127,7 +138,8 @@ func (tx *Tx) OnCommit(fn func()) { } // Commit writes all changes to disk and updates the meta page. -// Returns an error if a disk write error occurs. +// Returns an error if a disk write error occurs, or if Commit is +// called on a read-only transaction. func (tx *Tx) Commit() error { _assert(!tx.managed, "managed tx commit not allowed") if tx.db == nil { @@ -156,6 +168,8 @@ func (tx *Tx) Commit() error { // Free the old root bucket. tx.meta.root.root = tx.root.root + opgid := tx.meta.pgid + // Free the freelist and allocate new pages for it. This will overestimate // the size of the freelist but not underestimate the size (which would be bad). tx.db.freelist.free(tx.meta.txid, tx.db.page(tx.meta.freelist)) @@ -170,6 +184,14 @@ func (tx *Tx) Commit() error { } tx.meta.freelist = p.id + // If the high water mark has moved up then attempt to grow the database. + if tx.meta.pgid > opgid { + if err := tx.db.grow(int(tx.meta.pgid+1) * tx.db.pageSize); err != nil { + tx.rollback() + return err + } + } + // Write dirty pages to disk. startTime = time.Now() if err := tx.write(); err != nil { @@ -203,7 +225,8 @@ func (tx *Tx) Commit() error { return nil } -// Rollback closes the transaction and ignores all previous updates. +// Rollback closes the transaction and ignores all previous updates. Read-only +// transactions must be rolled back and not committed. func (tx *Tx) Rollback() error { _assert(!tx.managed, "managed tx rollback not allowed") if tx.db == nil { @@ -234,7 +257,8 @@ func (tx *Tx) close() { var freelistPendingN = tx.db.freelist.pending_count() var freelistAlloc = tx.db.freelist.size() - // Remove writer lock. + // Remove transaction ref & writer lock. + tx.db.rwtx = nil tx.db.rwlock.Unlock() // Merge statistics. @@ -248,41 +272,47 @@ func (tx *Tx) close() { } else { tx.db.removeTx(tx) } + + // Clear all references. tx.db = nil + tx.meta = nil + tx.root = Bucket{tx: tx} + tx.pages = nil } // Copy writes the entire database to a writer. -// A reader transaction is maintained during the copy so it is safe to continue -// using the database while a copy is in progress. -// Copy will write exactly tx.Size() bytes into the writer. +// This function exists for backwards compatibility. Use WriteTo() instead. func (tx *Tx) Copy(w io.Writer) error { - var f *os.File - var err error + _, err := tx.WriteTo(w) + return err +} - // Attempt to open reader directly. - if f, err = os.OpenFile(tx.db.path, os.O_RDONLY|odirect, 0); err != nil { - // Fallback to a regular open if that doesn't work. - if f, err = os.OpenFile(tx.db.path, os.O_RDONLY, 0); err != nil { - return err - } +// WriteTo writes the entire database to a writer. +// If err == nil then exactly tx.Size() bytes will be written into the writer. +func (tx *Tx) WriteTo(w io.Writer) (n int64, err error) { + // Attempt to open reader with WriteFlag + f, err := os.OpenFile(tx.db.path, os.O_RDONLY|tx.WriteFlag, 0) + if err != nil { + return 0, err } + defer func() { _ = f.Close() }() // Copy the meta pages. tx.db.metalock.Lock() - _, err = io.CopyN(w, f, int64(tx.db.pageSize*2)) + n, err = io.CopyN(w, f, int64(tx.db.pageSize*2)) tx.db.metalock.Unlock() if err != nil { - _ = f.Close() - return fmt.Errorf("meta copy: %s", err) + return n, fmt.Errorf("meta copy: %s", err) } // Copy data pages. - if _, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2)); err != nil { - _ = f.Close() - return err + wn, err := io.CopyN(w, f, tx.Size()-int64(tx.db.pageSize*2)) + n += wn + if err != nil { + return n, err } - return f.Close() + return n, f.Close() } // CopyFile copies the entire database to file at the given path. @@ -416,15 +446,39 @@ func (tx *Tx) write() error { // Write pages to disk in order. for _, p := range pages { size := (int(p.overflow) + 1) * tx.db.pageSize - buf := (*[maxAllocSize]byte)(unsafe.Pointer(p))[:size] offset := int64(p.id) * int64(tx.db.pageSize) - if _, err := tx.db.ops.writeAt(buf, offset); err != nil { - return err - } - // Update statistics. - tx.stats.Write++ + // Write out page in "max allocation" sized chunks. + ptr := (*[maxAllocSize]byte)(unsafe.Pointer(p)) + for { + // Limit our write to our max allocation size. + sz := size + if sz > maxAllocSize-1 { + sz = maxAllocSize - 1 + } + + // Write chunk to disk. + buf := ptr[:sz] + if _, err := tx.db.ops.writeAt(buf, offset); err != nil { + return err + } + + // Update statistics. + tx.stats.Write++ + + // Exit inner for loop if we've written all the chunks. + size -= sz + if size == 0 { + break + } + + // Otherwise move offset forward and move pointer to next chunk. + offset += int64(sz) + ptr = (*[maxAllocSize]byte)(unsafe.Pointer(&ptr[sz])) + } } + + // Ignore file sync if flag is set on DB. if !tx.db.NoSync || IgnoreNoSync { if err := fdatasync(tx.db); err != nil { return err @@ -461,7 +515,7 @@ func (tx *Tx) writeMeta() error { } // page returns a reference to the page with a given id. -// If page has been written to then a temporary bufferred page is returned. +// If page has been written to then a temporary buffered page is returned. func (tx *Tx) page(id pgid) *page { // Check the dirty pages first. if tx.pages != nil { diff --git a/Godeps/_workspace/src/github.com/boltdb/bolt/tx_test.go b/Godeps/_workspace/src/github.com/boltdb/bolt/tx_test.go index 61dd03bf..d99de8c9 100644 --- a/Godeps/_workspace/src/github.com/boltdb/bolt/tx_test.go +++ b/Godeps/_workspace/src/github.com/boltdb/bolt/tx_test.go @@ -1,8 +1,10 @@ package bolt_test import ( + "bytes" "errors" "fmt" + "log" "os" "testing" @@ -10,299 +12,519 @@ import ( ) // Ensure that committing a closed transaction returns an error. -func TestTx_Commit_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - tx.CreateBucket([]byte("foo")) - ok(t, tx.Commit()) - equals(t, tx.Commit(), bolt.ErrTxClosed) +func TestTx_Commit_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + if _, err := tx.CreateBucket([]byte("foo")); err != nil { + t.Fatal(err) + } + + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + if err := tx.Commit(); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that rolling back a closed transaction returns an error. -func TestTx_Rollback_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - ok(t, tx.Rollback()) - equals(t, tx.Rollback(), bolt.ErrTxClosed) +func TestTx_Rollback_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + + if err := tx.Rollback(); err != nil { + t.Fatal(err) + } + if err := tx.Rollback(); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that committing a read-only transaction returns an error. -func TestTx_Commit_ReadOnly(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(false) - equals(t, tx.Commit(), bolt.ErrTxNotWritable) +func TestTx_Commit_ErrTxNotWritable(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(false) + if err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != bolt.ErrTxNotWritable { + t.Fatal(err) + } } // Ensure that a transaction can retrieve a cursor on the root bucket. func TestTx_Cursor(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.CreateBucket([]byte("woojits")) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + + if _, err := tx.CreateBucket([]byte("woojits")); err != nil { + t.Fatal(err) + } + c := tx.Cursor() + if k, v := c.First(); !bytes.Equal(k, []byte("widgets")) { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } - k, v := c.First() - equals(t, "widgets", string(k)) - assert(t, v == nil, "") + if k, v := c.Next(); !bytes.Equal(k, []byte("woojits")) { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", v) + } - k, v = c.Next() - equals(t, "woojits", string(k)) - assert(t, v == nil, "") - - k, v = c.Next() - assert(t, k == nil, "") - assert(t, v == nil, "") + if k, v := c.Next(); k != nil { + t.Fatalf("unexpected key: %v", k) + } else if v != nil { + t.Fatalf("unexpected value: %v", k) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that creating a bucket with a read-only transaction returns an error. -func TestTx_CreateBucket_ReadOnly(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.View(func(tx *bolt.Tx) error { - b, err := tx.CreateBucket([]byte("foo")) - assert(t, b == nil, "") - equals(t, bolt.ErrTxNotWritable, err) +func TestTx_CreateBucket_ErrTxNotWritable(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.View(func(tx *bolt.Tx) error { + _, err := tx.CreateBucket([]byte("foo")) + if err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that creating a bucket on a closed transaction returns an error. -func TestTx_CreateBucket_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - tx.Commit() - b, err := tx.CreateBucket([]byte("foo")) - assert(t, b == nil, "") - equals(t, bolt.ErrTxClosed, err) +func TestTx_CreateBucket_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + + if _, err := tx.CreateBucket([]byte("foo")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that a Tx can retrieve a bucket. func TestTx_Bucket(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - b := tx.Bucket([]byte("widgets")) - assert(t, b != nil, "") + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a Tx retrieving a non-existent key returns nil. -func TestTx_Get_Missing(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - value := tx.Bucket([]byte("widgets")).Get([]byte("no_such_key")) - assert(t, value == nil, "") +func TestTx_Get_NotFound(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if b.Get([]byte("no_such_key")) != nil { + t.Fatal("expected nil value") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can be created and retrieved. func TestTx_CreateBucket(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() // Create a bucket. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { b, err := tx.CreateBucket([]byte("widgets")) - assert(t, b != nil, "") - ok(t, err) + if err != nil { + t.Fatal(err) + } else if b == nil { + t.Fatal("expected bucket") + } return nil - }) + }); err != nil { + t.Fatal(err) + } // Read the bucket through a separate transaction. - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - assert(t, b != nil, "") + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can be created if it doesn't already exist. func TestTx_CreateBucketIfNotExists(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucketIfNotExists([]byte("widgets")) - assert(t, b != nil, "") - ok(t, err) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + // Create bucket. + if b, err := tx.CreateBucketIfNotExists([]byte("widgets")); err != nil { + t.Fatal(err) + } else if b == nil { + t.Fatal("expected bucket") + } - b, err = tx.CreateBucketIfNotExists([]byte("widgets")) - assert(t, b != nil, "") - ok(t, err) + // Create bucket again. + if b, err := tx.CreateBucketIfNotExists([]byte("widgets")); err != nil { + t.Fatal(err) + } else if b == nil { + t.Fatal("expected bucket") + } - b, err = tx.CreateBucketIfNotExists([]byte{}) - assert(t, b == nil, "") - equals(t, bolt.ErrBucketNameRequired, err) - - b, err = tx.CreateBucketIfNotExists(nil) - assert(t, b == nil, "") - equals(t, bolt.ErrBucketNameRequired, err) return nil - }) + }); err != nil { + t.Fatal(err) + } // Read the bucket through a separate transaction. - db.View(func(tx *bolt.Tx) error { - b := tx.Bucket([]byte("widgets")) - assert(t, b != nil, "") + if err := db.View(func(tx *bolt.Tx) error { + if tx.Bucket([]byte("widgets")) == nil { + t.Fatal("expected bucket") + } return nil - }) + }); err != nil { + t.Fatal(err) + } +} + +// Ensure transaction returns an error if creating an unnamed bucket. +func TestTx_CreateBucketIfNotExists_ErrBucketNameRequired(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucketIfNotExists([]byte{}); err != bolt.ErrBucketNameRequired { + t.Fatalf("unexpected error: %s", err) + } + + if _, err := tx.CreateBucketIfNotExists(nil); err != bolt.ErrBucketNameRequired { + t.Fatalf("unexpected error: %s", err) + } + + return nil + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket cannot be created twice. -func TestTx_CreateBucket_Exists(t *testing.T) { - db := NewTestDB() - defer db.Close() +func TestTx_CreateBucket_ErrBucketExists(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + // Create a bucket. - db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucket([]byte("widgets")) - assert(t, b != nil, "") - ok(t, err) + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } // Create the same bucket again. - db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucket([]byte("widgets")) - assert(t, b == nil, "") - equals(t, bolt.ErrBucketExists, err) + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket([]byte("widgets")); err != bolt.ErrBucketExists { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket is created with a non-blank name. -func TestTx_CreateBucket_NameRequired(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - b, err := tx.CreateBucket(nil) - assert(t, b == nil, "") - equals(t, bolt.ErrBucketNameRequired, err) +func TestTx_CreateBucket_ErrBucketNameRequired(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if _, err := tx.CreateBucket(nil); err != bolt.ErrBucketNameRequired { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that a bucket can be deleted. func TestTx_DeleteBucket(t *testing.T) { - db := NewTestDB() - defer db.Close() + db := MustOpenDB() + defer db.MustClose() // Create a bucket and add a value. - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } // Delete the bucket and make sure we can't get the value. - db.Update(func(tx *bolt.Tx) error { - ok(t, tx.DeleteBucket([]byte("widgets"))) - assert(t, tx.Bucket([]byte("widgets")) == nil, "") + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + if tx.Bucket([]byte("widgets")) != nil { + t.Fatal("unexpected bucket") + } return nil - }) + }); err != nil { + t.Fatal(err) + } - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { // Create the bucket again and make sure there's not a phantom value. b, err := tx.CreateBucket([]byte("widgets")) - assert(t, b != nil, "") - ok(t, err) - assert(t, tx.Bucket([]byte("widgets")).Get([]byte("foo")) == nil, "") + if err != nil { + t.Fatal(err) + } + if v := b.Get([]byte("foo")); v != nil { + t.Fatalf("unexpected phantom value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that deleting a bucket on a closed transaction returns an error. -func TestTx_DeleteBucket_Closed(t *testing.T) { - db := NewTestDB() - defer db.Close() - tx, _ := db.Begin(true) - tx.Commit() - equals(t, tx.DeleteBucket([]byte("foo")), bolt.ErrTxClosed) +func TestTx_DeleteBucket_ErrTxClosed(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + tx, err := db.Begin(true) + if err != nil { + t.Fatal(err) + } + if err := tx.Commit(); err != nil { + t.Fatal(err) + } + if err := tx.DeleteBucket([]byte("foo")); err != bolt.ErrTxClosed { + t.Fatalf("unexpected error: %s", err) + } } // Ensure that deleting a bucket with a read-only transaction returns an error. func TestTx_DeleteBucket_ReadOnly(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.View(func(tx *bolt.Tx) error { - equals(t, tx.DeleteBucket([]byte("foo")), bolt.ErrTxNotWritable) + db := MustOpenDB() + defer db.MustClose() + if err := db.View(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("foo")); err != bolt.ErrTxNotWritable { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } } // Ensure that nothing happens when deleting a bucket that doesn't exist. func TestTx_DeleteBucket_NotFound(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - equals(t, bolt.ErrBucketNotFound, tx.DeleteBucket([]byte("widgets"))) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + if err := tx.DeleteBucket([]byte("widgets")); err != bolt.ErrBucketNotFound { + t.Fatalf("unexpected error: %s", err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that no error is returned when a tx.ForEach function does not return +// an error. +func TestTx_ForEach_NoError(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + + if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { + return nil + }); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } +} + +// Ensure that an error is returned when a tx.ForEach function returns an error. +func TestTx_ForEach_WithError(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + + marker := errors.New("marker") + if err := tx.ForEach(func(name []byte, b *bolt.Bucket) error { + return marker + }); err != marker { + t.Fatalf("unexpected error: %s", err) + } + return nil + }); err != nil { + t.Fatal(err) + } } // Ensure that Tx commit handlers are called after a transaction successfully commits. func TestTx_OnCommit(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + var x int - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { tx.OnCommit(func() { x += 1 }) tx.OnCommit(func() { x += 2 }) - _, err := tx.CreateBucket([]byte("widgets")) - return err - }) - equals(t, 3, x) + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } + return nil + }); err != nil { + t.Fatal(err) + } else if x != 3 { + t.Fatalf("unexpected x: %d", x) + } } // Ensure that Tx commit handlers are NOT called after a transaction rolls back. func TestTx_OnCommit_Rollback(t *testing.T) { + db := MustOpenDB() + defer db.MustClose() + var x int - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { tx.OnCommit(func() { x += 1 }) tx.OnCommit(func() { x += 2 }) - tx.CreateBucket([]byte("widgets")) + if _, err := tx.CreateBucket([]byte("widgets")); err != nil { + t.Fatal(err) + } return errors.New("rollback this commit") - }) - equals(t, 0, x) + }); err == nil || err.Error() != "rollback this commit" { + t.Fatalf("unexpected error: %s", err) + } else if x != 0 { + t.Fatalf("unexpected x: %d", x) + } } // Ensure that the database can be copied to a file path. func TestTx_CopyFile(t *testing.T) { - db := NewTestDB() - defer db.Close() - var dest = tempfile() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte("bat")) + db := MustOpenDB() + defer db.MustClose() + + path := tempfile() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } - ok(t, db.View(func(tx *bolt.Tx) error { return tx.CopyFile(dest, 0600) })) + if err := db.View(func(tx *bolt.Tx) error { + return tx.CopyFile(path, 0600) + }); err != nil { + t.Fatal(err) + } - db2, err := bolt.Open(dest, 0600, nil) - ok(t, err) - defer db2.Close() + db2, err := bolt.Open(path, 0600, nil) + if err != nil { + t.Fatal(err) + } - db2.View(func(tx *bolt.Tx) error { - equals(t, []byte("bar"), tx.Bucket([]byte("widgets")).Get([]byte("foo"))) - equals(t, []byte("bat"), tx.Bucket([]byte("widgets")).Get([]byte("baz"))) + if err := db2.View(func(tx *bolt.Tx) error { + if v := tx.Bucket([]byte("widgets")).Get([]byte("foo")); !bytes.Equal(v, []byte("bar")) { + t.Fatalf("unexpected value: %v", v) + } + if v := tx.Bucket([]byte("widgets")).Get([]byte("baz")); !bytes.Equal(v, []byte("bat")) { + t.Fatalf("unexpected value: %v", v) + } return nil - }) + }); err != nil { + t.Fatal(err) + } + + if err := db2.Close(); err != nil { + t.Fatal(err) + } } type failWriterError struct{} @@ -328,63 +550,107 @@ func (f *failWriter) Write(p []byte) (n int, err error) { // Ensure that Copy handles write errors right. func TestTx_CopyFile_Error_Meta(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte("bat")) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } - err := db.View(func(tx *bolt.Tx) error { return tx.Copy(&failWriter{}) }) - equals(t, err.Error(), "meta copy: error injected for tests") + if err := db.View(func(tx *bolt.Tx) error { + return tx.Copy(&failWriter{}) + }); err == nil || err.Error() != "meta copy: error injected for tests" { + t.Fatalf("unexpected error: %v", err) + } } // Ensure that Copy handles write errors right. func TestTx_CopyFile_Error_Normal(t *testing.T) { - db := NewTestDB() - defer db.Close() - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - tx.Bucket([]byte("widgets")).Put([]byte("baz"), []byte("bat")) + db := MustOpenDB() + defer db.MustClose() + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + t.Fatal(err) + } + if err := b.Put([]byte("baz"), []byte("bat")); err != nil { + t.Fatal(err) + } return nil - }) + }); err != nil { + t.Fatal(err) + } - err := db.View(func(tx *bolt.Tx) error { return tx.Copy(&failWriter{3 * db.Info().PageSize}) }) - equals(t, err.Error(), "error injected for tests") + if err := db.View(func(tx *bolt.Tx) error { + return tx.Copy(&failWriter{3 * db.Info().PageSize}) + }); err == nil || err.Error() != "error injected for tests" { + t.Fatalf("unexpected error: %v", err) + } } func ExampleTx_Rollback() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Create a bucket. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { _, err := tx.CreateBucket([]byte("widgets")) return err - }) + }); err != nil { + log.Fatal(err) + } // Set a value for a key. - db.Update(func(tx *bolt.Tx) error { + if err := db.Update(func(tx *bolt.Tx) error { return tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) - }) + }); err != nil { + log.Fatal(err) + } // Update the key but rollback the transaction so it never saves. - tx, _ := db.Begin(true) + tx, err := db.Begin(true) + if err != nil { + log.Fatal(err) + } b := tx.Bucket([]byte("widgets")) - b.Put([]byte("foo"), []byte("baz")) - tx.Rollback() + if err := b.Put([]byte("foo"), []byte("baz")); err != nil { + log.Fatal(err) + } + if err := tx.Rollback(); err != nil { + log.Fatal(err) + } // Ensure that our original value is still set. - db.View(func(tx *bolt.Tx) error { + if err := db.View(func(tx *bolt.Tx) error { value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) fmt.Printf("The value for 'foo' is still: %s\n", value) return nil - }) + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } // Output: // The value for 'foo' is still: bar @@ -392,32 +658,58 @@ func ExampleTx_Rollback() { func ExampleTx_CopyFile() { // Open the database. - db, _ := bolt.Open(tempfile(), 0666, nil) + db, err := bolt.Open(tempfile(), 0666, nil) + if err != nil { + log.Fatal(err) + } defer os.Remove(db.Path()) - defer db.Close() // Create a bucket and a key. - db.Update(func(tx *bolt.Tx) error { - tx.CreateBucket([]byte("widgets")) - tx.Bucket([]byte("widgets")).Put([]byte("foo"), []byte("bar")) + if err := db.Update(func(tx *bolt.Tx) error { + b, err := tx.CreateBucket([]byte("widgets")) + if err != nil { + return err + } + if err := b.Put([]byte("foo"), []byte("bar")); err != nil { + return err + } return nil - }) + }); err != nil { + log.Fatal(err) + } // Copy the database to another file. toFile := tempfile() - db.View(func(tx *bolt.Tx) error { return tx.CopyFile(toFile, 0666) }) + if err := db.View(func(tx *bolt.Tx) error { + return tx.CopyFile(toFile, 0666) + }); err != nil { + log.Fatal(err) + } defer os.Remove(toFile) // Open the cloned database. - db2, _ := bolt.Open(toFile, 0666, nil) - defer db2.Close() + db2, err := bolt.Open(toFile, 0666, nil) + if err != nil { + log.Fatal(err) + } // Ensure that the key exists in the copy. - db2.View(func(tx *bolt.Tx) error { + if err := db2.View(func(tx *bolt.Tx) error { value := tx.Bucket([]byte("widgets")).Get([]byte("foo")) fmt.Printf("The value for 'foo' in the clone is: %s\n", value) return nil - }) + }); err != nil { + log.Fatal(err) + } + + // Close database to release file lock. + if err := db.Close(); err != nil { + log.Fatal(err) + } + + if err := db2.Close(); err != nil { + log.Fatal(err) + } // Output: // The value for 'foo' in the clone is: bar diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/README.md b/Godeps/_workspace/src/github.com/gorilla/websocket/README.md index 9ad75a0f..9d71959e 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/README.md +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/README.md @@ -7,6 +7,8 @@ Gorilla WebSocket is a [Go](http://golang.org/) implementation of the * [API Reference](http://godoc.org/github.com/gorilla/websocket) * [Chat example](https://github.com/gorilla/websocket/tree/master/examples/chat) +* [Command example](https://github.com/gorilla/websocket/tree/master/examples/command) +* [Client and server example](https://github.com/gorilla/websocket/tree/master/examples/echo) * [File watch example](https://github.com/gorilla/websocket/tree/master/examples/filewatch) ### Status diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/client.go b/Godeps/_workspace/src/github.com/gorilla/websocket/client.go index c25d24f8..61389060 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/client.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/client.go @@ -5,8 +5,12 @@ package websocket import ( + "bufio" + "bytes" "crypto/tls" "errors" + "io" + "io/ioutil" "net" "net/http" "net/url" @@ -27,50 +31,17 @@ var ErrBadHandshake = errors.New("websocket: bad handshake") // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, // etc. +// +// Deprecated: Use Dialer instead. func NewClient(netConn net.Conn, u *url.URL, requestHeader http.Header, readBufSize, writeBufSize int) (c *Conn, response *http.Response, err error) { - challengeKey, err := generateChallengeKey() - if err != nil { - return nil, nil, err + d := Dialer{ + ReadBufferSize: readBufSize, + WriteBufferSize: writeBufSize, + NetDial: func(net, addr string) (net.Conn, error) { + return netConn, nil + }, } - acceptKey := computeAcceptKey(challengeKey) - - c = newConn(netConn, false, readBufSize, writeBufSize) - p := c.writeBuf[:0] - p = append(p, "GET "...) - p = append(p, u.RequestURI()...) - p = append(p, " HTTP/1.1\r\nHost: "...) - p = append(p, u.Host...) - // "Upgrade" is capitalized for servers that do not use case insensitive - // comparisons on header tokens. - p = append(p, "\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Version: 13\r\nSec-WebSocket-Key: "...) - p = append(p, challengeKey...) - p = append(p, "\r\n"...) - for k, vs := range requestHeader { - for _, v := range vs { - p = append(p, k...) - p = append(p, ": "...) - p = append(p, v...) - p = append(p, "\r\n"...) - } - } - p = append(p, "\r\n"...) - - if _, err := netConn.Write(p); err != nil { - return nil, nil, err - } - - resp, err := http.ReadResponse(c.br, &http.Request{Method: "GET", URL: u}) - if err != nil { - return nil, nil, err - } - if resp.StatusCode != 101 || - !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || - !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || - resp.Header.Get("Sec-Websocket-Accept") != acceptKey { - return nil, resp, ErrBadHandshake - } - c.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") - return c, resp, nil + return d.Dial(u.String(), requestHeader) } // A Dialer contains options for connecting to WebSocket server. @@ -79,6 +50,12 @@ type Dialer struct { // NetDial is nil, net.Dial is used. NetDial func(network, addr string) (net.Conn, error) + // Proxy specifies a function to return a proxy for a given + // Request. If the function returns a non-nil error, the + // request is aborted with the provided error. + // If Proxy is nil or returns a nil *URL, no proxy is used. + Proxy func(*http.Request) (*url.URL, error) + // TLSClientConfig specifies the TLS configuration to use with tls.Client. // If nil, the default configuration is used. TLSClientConfig *tls.Config @@ -96,17 +73,15 @@ type Dialer struct { var errMalformedURL = errors.New("malformed ws or wss URL") -// parseURL parses the URL. The url.Parse function is not used here because -// url.Parse mangles the path. +// parseURL parses the URL. +// +// This function is a replacement for the standard library url.Parse function. +// In Go 1.4 and earlier, url.Parse loses information from the path. func parseURL(s string) (*url.URL, error) { // From the RFC: // // ws-URI = "ws:" "//" host [ ":" port ] path [ "?" query ] // wss-URI = "wss:" "//" host [ ":" port ] path [ "?" query ] - // - // We don't use the net/url parser here because the dialer interface does - // not provide a way for applications to work around percent deocding in - // the net/url parser. var u url.URL switch { @@ -120,11 +95,24 @@ func parseURL(s string) (*url.URL, error) { return nil, errMalformedURL } - u.Host = s - u.Opaque = "/" + if i := strings.Index(s, "?"); i >= 0 { + u.RawQuery = s[i+1:] + s = s[:i] + } + if i := strings.Index(s, "/"); i >= 0 { - u.Host = s[:i] u.Opaque = s[i:] + s = s[:i] + } else { + u.Opaque = "/" + } + + u.Host = s + + if strings.Contains(u.Host, "@") { + // Don't bother parsing user information because user information is + // not allowed in websocket URIs. + return nil, errMalformedURL } return &u, nil @@ -136,9 +124,12 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { if i := strings.LastIndex(u.Host, ":"); i > strings.LastIndex(u.Host, "]") { hostNoPort = hostNoPort[:i] } else { - if u.Scheme == "wss" { + switch u.Scheme { + case "wss": hostPort += ":443" - } else { + case "https": + hostPort += ":443" + default: hostPort += ":80" } } @@ -146,7 +137,9 @@ func hostPortNoPort(u *url.URL) (hostPort, hostNoPort string) { } // DefaultDialer is a dialer with all fields set to the default zero values. -var DefaultDialer *Dialer +var DefaultDialer = &Dialer{ + Proxy: http.ProxyFromEnvironment, +} // Dial creates a new client connection. Use requestHeader to specify the // origin (Origin), subprotocols (Sec-WebSocket-Protocol) and cookies (Cookie). @@ -155,17 +148,94 @@ var DefaultDialer *Dialer // // If the WebSocket handshake fails, ErrBadHandshake is returned along with a // non-nil *http.Response so that callers can handle redirects, authentication, -// etc. +// etcetera. The response body may not contain the entire response and does not +// need to be closed by the application. func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Response, error) { + + if d == nil { + d = &Dialer{ + Proxy: http.ProxyFromEnvironment, + } + } + + challengeKey, err := generateChallengeKey() + if err != nil { + return nil, nil, err + } + u, err := parseURL(urlStr) if err != nil { return nil, nil, err } + switch u.Scheme { + case "ws": + u.Scheme = "http" + case "wss": + u.Scheme = "https" + default: + return nil, nil, errMalformedURL + } + + if u.User != nil { + // User name and password are not allowed in websocket URIs. + return nil, nil, errMalformedURL + } + + req := &http.Request{ + Method: "GET", + URL: u, + Proto: "HTTP/1.1", + ProtoMajor: 1, + ProtoMinor: 1, + Header: make(http.Header), + Host: u.Host, + } + + // Set the request headers using the capitalization for names and values in + // RFC examples. Although the capitalization shouldn't matter, there are + // servers that depend on it. The Header.Set method is not used because the + // method canonicalizes the header names. + req.Header["Upgrade"] = []string{"websocket"} + req.Header["Connection"] = []string{"Upgrade"} + req.Header["Sec-WebSocket-Key"] = []string{challengeKey} + req.Header["Sec-WebSocket-Version"] = []string{"13"} + if len(d.Subprotocols) > 0 { + req.Header["Sec-WebSocket-Protocol"] = []string{strings.Join(d.Subprotocols, ", ")} + } + for k, vs := range requestHeader { + switch { + case k == "Host": + if len(vs) > 0 { + req.Host = vs[0] + } + case k == "Upgrade" || + k == "Connection" || + k == "Sec-Websocket-Key" || + k == "Sec-Websocket-Version" || + (k == "Sec-Websocket-Protocol" && len(d.Subprotocols) > 0): + return nil, nil, errors.New("websocket: duplicate header not allowed: " + k) + default: + req.Header[k] = vs + } + } + hostPort, hostNoPort := hostPortNoPort(u) - if d == nil { - d = &Dialer{} + var proxyURL *url.URL + // Check wether the proxy method has been configured + if d.Proxy != nil { + proxyURL, err = d.Proxy(req) + } + if err != nil { + return nil, nil, err + } + + var targetHostPort string + if proxyURL != nil { + targetHostPort, _ = hostPortNoPort(proxyURL) + } else { + targetHostPort = hostPort } var deadline time.Time @@ -179,7 +249,7 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re netDial = netDialer.Dial } - netConn, err := netDial("tcp", hostPort) + netConn, err := netDial("tcp", targetHostPort) if err != nil { return nil, nil, err } @@ -194,7 +264,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re return nil, nil, err } - if u.Scheme == "wss" { + if proxyURL != nil { + connectReq := &http.Request{ + Method: "CONNECT", + URL: &url.URL{Opaque: hostPort}, + Host: hostPort, + Header: make(http.Header), + } + + connectReq.Write(netConn) + + // Read response. + // Okay to use and discard buffered reader here, because + // TLS server will not speak until spoken to. + br := bufio.NewReader(netConn) + resp, err := http.ReadResponse(br, connectReq) + if err != nil { + return nil, nil, err + } + if resp.StatusCode != 200 { + f := strings.SplitN(resp.Status, " ", 2) + return nil, nil, errors.New(f[1]) + } + } + + if u.Scheme == "https" { cfg := d.TLSClientConfig if cfg == nil { cfg = &tls.Config{ServerName: hostNoPort} @@ -215,19 +309,31 @@ func (d *Dialer) Dial(urlStr string, requestHeader http.Header) (*Conn, *http.Re } } - if len(d.Subprotocols) > 0 { - h := http.Header{} - for k, v := range requestHeader { - h[k] = v - } - h.Set("Sec-Websocket-Protocol", strings.Join(d.Subprotocols, ", ")) - requestHeader = h + conn := newConn(netConn, false, d.ReadBufferSize, d.WriteBufferSize) + + if err := req.Write(netConn); err != nil { + return nil, nil, err } - conn, resp, err := NewClient(netConn, u, requestHeader, d.ReadBufferSize, d.WriteBufferSize) + resp, err := http.ReadResponse(conn.br, req) if err != nil { - return nil, resp, err + return nil, nil, err } + if resp.StatusCode != 101 || + !strings.EqualFold(resp.Header.Get("Upgrade"), "websocket") || + !strings.EqualFold(resp.Header.Get("Connection"), "upgrade") || + resp.Header.Get("Sec-Websocket-Accept") != computeAcceptKey(challengeKey) { + // Before closing the network connection on return from this + // function, slurp up some of the response to aid application + // debugging. + buf := make([]byte, 1024) + n, _ := io.ReadFull(resp.Body, buf) + resp.Body = ioutil.NopCloser(bytes.NewReader(buf[:n])) + return nil, resp, ErrBadHandshake + } + + resp.Body = ioutil.NopCloser(bytes.NewReader([]byte{})) + conn.subprotocol = resp.Header.Get("Sec-Websocket-Protocol") netConn.SetDeadline(time.Time{}) netConn = nil // to avoid close in defer. diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/client_server_test.go b/Godeps/_workspace/src/github.com/gorilla/websocket/client_server_test.go index 8c608f68..c67550e9 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/client_server_test.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/client_server_test.go @@ -8,11 +8,13 @@ import ( "crypto/tls" "crypto/x509" "io" + "io/ioutil" "net" "net/http" "net/http/httptest" "net/url" "reflect" + "strings" "testing" "time" ) @@ -34,29 +36,42 @@ var cstDialer = Dialer{ type cstHandler struct{ *testing.T } -type Server struct { +type cstServer struct { *httptest.Server URL string } -func newServer(t *testing.T) *Server { - var s Server +const ( + cstPath = "/a/b" + cstRawQuery = "x=y" + cstRequestURI = cstPath + "?" + cstRawQuery +) + +func newServer(t *testing.T) *cstServer { + var s cstServer s.Server = httptest.NewServer(cstHandler{t}) - s.URL = "ws" + s.Server.URL[len("http"):] + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) return &s } -func newTLSServer(t *testing.T) *Server { - var s Server +func newTLSServer(t *testing.T) *cstServer { + var s cstServer s.Server = httptest.NewTLSServer(cstHandler{t}) - s.URL = "ws" + s.Server.URL[len("http"):] + s.Server.URL += cstRequestURI + s.URL = makeWsProto(s.Server.URL) return &s } func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { - t.Logf("method %s not allowed", r.Method) - http.Error(w, "method not allowed", 405) + if r.URL.Path != cstPath { + t.Logf("path=%v, want %v", r.URL.Path, cstPath) + http.Error(w, "bad path", 400) + return + } + if r.URL.RawQuery != cstRawQuery { + t.Logf("query=%v, want %v", r.URL.RawQuery, cstRawQuery) + http.Error(w, "bad path", 400) return } subprotos := Subprotocols(r) @@ -97,6 +112,10 @@ func (t cstHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } } +func makeWsProto(s string) string { + return "ws" + strings.TrimPrefix(s, "http") +} + func sendRecv(t *testing.T, ws *Conn) { const message = "Hello World!" if err := ws.SetWriteDeadline(time.Now().Add(time.Second)); err != nil { @@ -117,6 +136,45 @@ func sendRecv(t *testing.T, ws *Conn) { } } +func TestProxyDial(t *testing.T) { + + s := newServer(t) + defer s.Close() + + surl, _ := url.Parse(s.URL) + + cstDialer.Proxy = http.ProxyURL(surl) + + connect := false + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + if r.Method == "CONNECT" { + connect = true + w.WriteHeader(200) + return + } + + if !connect { + t.Log("connect not recieved") + http.Error(w, "connect not recieved", 405) + return + } + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, nil) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + sendRecv(t, ws) + + cstDialer.Proxy = http.ProxyFromEnvironment +} + func TestDial(t *testing.T) { s := newServer(t) defer s.Close() @@ -148,7 +206,7 @@ func TestDialTLS(t *testing.T) { d := cstDialer d.NetDial = func(network, addr string) (net.Conn, error) { return net.Dial(network, u.Host) } d.TLSClientConfig = &tls.Config{RootCAs: certs} - ws, _, err := d.Dial("wss://example.com/", nil) + ws, _, err := d.Dial("wss://example.com"+cstRequestURI, nil) if err != nil { t.Fatalf("Dial: %v", err) } @@ -157,6 +215,7 @@ func TestDialTLS(t *testing.T) { } func xTestDialTLSBadCert(t *testing.T) { + // This test is deactivated because of noisy logging from the net/http package. s := newTLSServer(t) defer s.Close() @@ -222,6 +281,45 @@ func TestDialBadOrigin(t *testing.T) { } } +func TestDialBadHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + for _, k := range []string{"Upgrade", + "Connection", + "Sec-Websocket-Key", + "Sec-Websocket-Version", + "Sec-Websocket-Protocol"} { + h := http.Header{} + h.Set(k, "bad") + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Origin": {"bad"}}) + if err == nil { + ws.Close() + t.Errorf("Dial with header %s returned nil", k) + } + } +} + +func TestBadMethod(t *testing.T) { + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ws, err := cstUpgrader.Upgrade(w, r, nil) + if err == nil { + t.Errorf("handshake succeeded, expect fail") + ws.Close() + } + })) + defer s.Close() + + resp, err := http.PostForm(s.URL, url.Values{}) + if err != nil { + t.Fatalf("PostForm returned error %v", err) + } + resp.Body.Close() + if resp.StatusCode != http.StatusMethodNotAllowed { + t.Errorf("Status = %d, want %d", resp.StatusCode, http.StatusMethodNotAllowed) + } +} + func TestHandshake(t *testing.T) { s := newServer(t) defer s.Close() @@ -247,3 +345,66 @@ func TestHandshake(t *testing.T) { } sendRecv(t, ws) } + +func TestRespOnBadHandshake(t *testing.T) { + const expectedStatus = http.StatusGone + const expectedBody = "This is the response body." + + s := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.WriteHeader(expectedStatus) + io.WriteString(w, expectedBody) + })) + defer s.Close() + + ws, resp, err := cstDialer.Dial(makeWsProto(s.URL), nil) + if err == nil { + ws.Close() + t.Fatalf("Dial: nil") + } + + if resp == nil { + t.Fatalf("resp=nil, err=%v", err) + } + + if resp.StatusCode != expectedStatus { + t.Errorf("resp.StatusCode=%d, want %d", resp.StatusCode, expectedStatus) + } + + p, err := ioutil.ReadAll(resp.Body) + if err != nil { + t.Fatalf("ReadFull(resp.Body) returned error %v", err) + } + + if string(p) != expectedBody { + t.Errorf("resp.Body=%s, want %s", p, expectedBody) + } +} + +// TestHostHeader confirms that the host header provided in the call to Dial is +// sent to the server. +func TestHostHeader(t *testing.T) { + s := newServer(t) + defer s.Close() + + specifiedHost := make(chan string, 1) + origHandler := s.Server.Config.Handler + + // Capture the request Host header. + s.Server.Config.Handler = http.HandlerFunc( + func(w http.ResponseWriter, r *http.Request) { + specifiedHost <- r.Host + origHandler.ServeHTTP(w, r) + }) + + ws, _, err := cstDialer.Dial(s.URL, http.Header{"Host": {"testhost"}}) + if err != nil { + t.Fatalf("Dial: %v", err) + } + defer ws.Close() + + if gotHost := <-specifiedHost; gotHost != "testhost" { + t.Fatalf("gotHost = %q, want \"testhost\"", gotHost) + } + + sendRecv(t, ws) +} diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/client_test.go b/Godeps/_workspace/src/github.com/gorilla/websocket/client_test.go index d2f2ebd7..7d2b0844 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/client_test.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/client_test.go @@ -11,15 +11,19 @@ import ( ) var parseURLTests = []struct { - s string - u *url.URL + s string + u *url.URL + rui string }{ - {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}}, - {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}}, - {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}}, - {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}}, - {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}}, - {"ss://example.com/a/b", nil}, + {"ws://example.com/", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, + {"ws://example.com", &url.URL{Scheme: "ws", Host: "example.com", Opaque: "/"}, "/"}, + {"ws://example.com:7777/", &url.URL{Scheme: "ws", Host: "example.com:7777", Opaque: "/"}, "/"}, + {"wss://example.com/", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/"}, "/"}, + {"wss://example.com/a/b", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b"}, "/a/b"}, + {"ss://example.com/a/b", nil, ""}, + {"ws://webmaster@example.com/", nil, ""}, + {"wss://example.com/a/b?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/a/b", RawQuery: "x=y"}, "/a/b?x=y"}, + {"wss://example.com?x=y", &url.URL{Scheme: "wss", Host: "example.com", Opaque: "/", RawQuery: "x=y"}, "/?x=y"}, } func TestParseURL(t *testing.T) { @@ -29,14 +33,19 @@ func TestParseURL(t *testing.T) { t.Errorf("parseURL(%q) returned error %v", tt.s, err) continue } - if tt.u == nil && err == nil { - t.Errorf("parseURL(%q) did not return error", tt.s) + if tt.u == nil { + if err == nil { + t.Errorf("parseURL(%q) did not return error", tt.s) + } continue } if !reflect.DeepEqual(u, tt.u) { - t.Errorf("parseURL(%q) returned %v, want %v", tt.s, u, tt.u) + t.Errorf("parseURL(%q) = %v, want %v", tt.s, u, tt.u) continue } + if u.RequestURI() != tt.rui { + t.Errorf("parseURL(%q).RequestURI() = %v, want %v", tt.s, u.RequestURI(), tt.rui) + } } } diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/conn.go b/Godeps/_workspace/src/github.com/gorilla/websocket/conn.go index e719f1ce..e8b6b3e0 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/conn.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/conn.go @@ -88,19 +88,23 @@ func (e *netError) Error() string { return e.msg } func (e *netError) Temporary() bool { return e.temporary } func (e *netError) Timeout() bool { return e.timeout } -// closeError represents close frame. -type closeError struct { - code int - text string +// CloseError represents close frame. +type CloseError struct { + + // Code is defined in RFC 6455, section 11.7. + Code int + + // Text is the optional text payload. + Text string } -func (e *closeError) Error() string { - return "websocket: close " + strconv.Itoa(e.code) + " " + e.text +func (e *CloseError) Error() string { + return "websocket: close " + strconv.Itoa(e.Code) + " " + e.Text } var ( - errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true} - errUnexpectedEOF = &closeError{code: CloseAbnormalClosure, text: io.ErrUnexpectedEOF.Error()} + errWriteTimeout = &netError{msg: "websocket: write timeout", timeout: true, temporary: true} + errUnexpectedEOF = &CloseError{Code: CloseAbnormalClosure, Text: io.ErrUnexpectedEOF.Error()} errBadWriteOpCode = errors.New("websocket: bad write message type") errWriteClosed = errors.New("websocket: write closed") errInvalidControlFrame = errors.New("websocket: invalid control frame") @@ -296,7 +300,7 @@ func (c *Conn) WriteControl(messageType int, data []byte, deadline time.Time) er if n != 0 && n != len(buf) { c.conn.Close() } - return err + return hideTempErr(err) } // NextWriter returns a writer for the next message to send. The writer's @@ -673,12 +677,7 @@ func (c *Conn) advanceFrame() (int, error) { closeCode = int(binary.BigEndian.Uint16(payload)) closeText = string(payload[2:]) } - switch closeCode { - case CloseNormalClosure, CloseGoingAway: - return noFrame, io.EOF - default: - return noFrame, &closeError{code: closeCode, text: closeText} - } + return noFrame, &CloseError{Code: closeCode, Text: closeText} } return frameType, nil @@ -790,20 +789,27 @@ func (c *Conn) SetReadLimit(limit int64) { } // SetPingHandler sets the handler for ping messages received from the peer. -// The default ping handler sends a pong to the peer. -func (c *Conn) SetPingHandler(h func(string) error) { +// The appData argument to h is the PING frame application data. The default +// ping handler sends a pong to the peer. +func (c *Conn) SetPingHandler(h func(appData string) error) { if h == nil { h = func(message string) error { - c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) - return nil + err := c.WriteControl(PongMessage, []byte(message), time.Now().Add(writeWait)) + if err == ErrCloseSent { + return nil + } else if e, ok := err.(net.Error); ok && e.Temporary() { + return nil + } + return err } } c.handlePing = h } // SetPongHandler sets the handler for pong messages received from the peer. -// The default pong handler does nothing. -func (c *Conn) SetPongHandler(h func(string) error) { +// The appData argument to h is the PONG frame application data. The default +// pong handler does nothing. +func (c *Conn) SetPongHandler(h func(appData string) error) { if h == nil { h = func(string) error { return nil } } diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/conn_test.go b/Godeps/_workspace/src/github.com/gorilla/websocket/conn_test.go index 1f1197e7..02f2d4b5 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/conn_test.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/conn_test.go @@ -5,11 +5,13 @@ package websocket import ( + "bufio" "bytes" "fmt" "io" "io/ioutil" "net" + "reflect" "testing" "testing/iotest" "time" @@ -146,13 +148,15 @@ func TestControl(t *testing.T) { func TestCloseBeforeFinalFrame(t *testing.T) { const bufSize = 512 + expectedErr := &CloseError{Code: CloseNormalClosure, Text: "hello"} + var b1, b2 bytes.Buffer wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, 1024, bufSize) rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, 1024, 1024) w, _ := wc.NextWriter(BinaryMessage) w.Write(make([]byte, bufSize+bufSize/2)) - wc.WriteControl(CloseMessage, FormatCloseMessage(CloseNormalClosure, ""), time.Now().Add(10*time.Second)) + wc.WriteControl(CloseMessage, FormatCloseMessage(expectedErr.Code, expectedErr.Text), time.Now().Add(10*time.Second)) w.Close() op, r, err := rc.NextReader() @@ -160,12 +164,12 @@ func TestCloseBeforeFinalFrame(t *testing.T) { t.Fatalf("NextReader() returned %d, %v", op, err) } _, err = io.Copy(ioutil.Discard, r) - if err != errUnexpectedEOF { - t.Fatalf("io.Copy() returned %v, want %v", err, errUnexpectedEOF) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("io.Copy() returned %v, want %v", err, expectedErr) } _, _, err = rc.NextReader() - if err != io.EOF { - t.Fatalf("NextReader() returned %v, want %v", err, io.EOF) + if !reflect.DeepEqual(err, expectedErr) { + t.Fatalf("NextReader() returned %v, want %v", err, expectedErr) } } @@ -236,3 +240,33 @@ func TestUnderlyingConn(t *testing.T) { t.Fatalf("Underlying conn is not what it should be.") } } + +func TestBufioReadBytes(t *testing.T) { + + // Test calling bufio.ReadBytes for value longer than read buffer size. + + m := make([]byte, 512) + m[len(m)-1] = '\n' + + var b1, b2 bytes.Buffer + wc := newConn(fakeNetConn{Reader: nil, Writer: &b1}, false, len(m)+64, len(m)+64) + rc := newConn(fakeNetConn{Reader: &b1, Writer: &b2}, true, len(m)-64, len(m)-64) + + w, _ := wc.NextWriter(BinaryMessage) + w.Write(m) + w.Close() + + op, r, err := rc.NextReader() + if op != BinaryMessage || err != nil { + t.Fatalf("NextReader() returned %d, %v", op, err) + } + + br := bufio.NewReader(r) + p, err := br.ReadBytes('\n') + if err != nil { + t.Fatalf("ReadBytes() returned %v", err) + } + if len(p) != len(m) { + t.Fatalf("read returnd %d bytes, want %d bytes", len(p), len(m)) + } +} diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/doc.go b/Godeps/_workspace/src/github.com/gorilla/websocket/doc.go index 0d2bd912..72286279 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/doc.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/doc.go @@ -24,7 +24,7 @@ // ... Use conn to send and receive messages. // } // -// Call the connection WriteMessage and ReadMessages methods to send and +// Call the connection's WriteMessage and ReadMessage methods to send and // receive messages as a slice of bytes. This snippet of code shows how to echo // messages using these methods: // @@ -97,10 +97,13 @@ // // Concurrency // -// Connections do not support concurrent calls to the write methods -// (NextWriter, SetWriteDeadline, WriteMessage) or concurrent calls to the read -// methods methods (NextReader, SetReadDeadline, ReadMessage). Connections do -// support a concurrent reader and writer. +// Connections support one concurrent reader and one concurrent writer. +// +// Applications are responsible for ensuring that no more than one goroutine +// calls the write methods (NextWriter, SetWriteDeadline, WriteMessage, +// WriteJSON) concurrently and that no more than one goroutine calls the read +// methods (NextReader, SetReadDeadline, ReadMessage, ReadJSON, SetPongHandler, +// SetPingHandler) concurrently. // // The Close and WriteControl methods can be called concurrently with all other // methods. diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/README.md b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/README.md index 08fc3e65..5df3cf1a 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/README.md +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/README.md @@ -17,3 +17,4 @@ using the following commands. $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/chat` $ go run *.go +To use the chat example, open http://localhost:8080/ in your browser. diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/conn.go b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/conn.go index cde45c8d..2a872e37 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/conn.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/chat/conn.go @@ -88,12 +88,8 @@ func (c *connection) writePump() { } } -// serverWs handles websocket requests from the peer. +// serveWs handles websocket requests from the peer. func serveWs(w http.ResponseWriter, r *http.Request) { - if r.Method != "GET" { - http.Error(w, "Method not allowed", 405) - return - } ws, err := upgrader.Upgrade(w, r, nil) if err != nil { log.Println(err) diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/README.md b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/README.md new file mode 100644 index 00000000..c30d3979 --- /dev/null +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/README.md @@ -0,0 +1,19 @@ +# Command example + +This example connects a websocket connection to stdin and stdout of a command. +Received messages are written to stdin followed by a `\n`. Each line read from +from standard out is sent as a message to the client. + + $ go get github.com/gorilla/websocket + $ cd `go list -f '{{.Dir}}' github.com/gorilla/websocket/examples/command` + $ go run main.go + # Open http://localhost:8080/ . + +Try the following commands. + + # Echo sent messages to the output area. + $ go run main.go cat + + # Run a shell.Try sending "ls" and "cat main.go". + $ go run main.go sh + diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/home.html b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/home.html new file mode 100644 index 00000000..72fd02b2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/home.html @@ -0,0 +1,96 @@ + + + +Command Example + + + + + +
+
+ + +
+ + diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/main.go b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/main.go new file mode 100644 index 00000000..a6199df4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/command/main.go @@ -0,0 +1,188 @@ +// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package main + +import ( + "bufio" + "flag" + "io" + "log" + "net/http" + "os" + "os/exec" + "text/template" + "time" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/gorilla/websocket" +) + +var ( + addr = flag.String("addr", "127.0.0.1:8080", "http service address") + cmdPath string + homeTempl = template.Must(template.ParseFiles("home.html")) +) + +const ( + // Time allowed to write a message to the peer. + writeWait = 10 * time.Second + + // Maximum message size allowed from peer. + maxMessageSize = 8192 + + // Time allowed to read the next pong message from the peer. + pongWait = 60 * time.Second + + // Send pings to peer with this period. Must be less than pongWait. + pingPeriod = (pongWait * 9) / 10 +) + +func pumpStdin(ws *websocket.Conn, w io.Writer) { + defer ws.Close() + ws.SetReadLimit(maxMessageSize) + ws.SetReadDeadline(time.Now().Add(pongWait)) + ws.SetPongHandler(func(string) error { ws.SetReadDeadline(time.Now().Add(pongWait)); return nil }) + for { + _, message, err := ws.ReadMessage() + if err != nil { + break + } + message = append(message, '\n') + if _, err := w.Write(message); err != nil { + break + } + } +} + +func pumpStdout(ws *websocket.Conn, r io.Reader, done chan struct{}) { + defer func() { + ws.Close() + close(done) + }() + s := bufio.NewScanner(r) + for s.Scan() { + ws.SetWriteDeadline(time.Now().Add(writeWait)) + if err := ws.WriteMessage(websocket.TextMessage, s.Bytes()); err != nil { + break + } + } + if s.Err() != nil { + log.Println("scan:", s.Err()) + } +} + +func ping(ws *websocket.Conn, done chan struct{}) { + ticker := time.NewTicker(pingPeriod) + defer ticker.Stop() + for { + select { + case <-ticker.C: + if err := ws.WriteControl(websocket.PingMessage, []byte{}, time.Now().Add(writeWait)); err != nil { + log.Println("ping:", err) + } + case <-done: + return + } + } +} + +func internalError(ws *websocket.Conn, msg string, err error) { + log.Println(msg, err) + ws.WriteMessage(websocket.TextMessage, []byte("Internal server error.")) +} + +var upgrader = websocket.Upgrader{} + +func serveWs(w http.ResponseWriter, r *http.Request) { + ws, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Println("upgrade:", err) + return + } + + defer ws.Close() + + outr, outw, err := os.Pipe() + if err != nil { + internalError(ws, "stdout:", err) + return + } + defer outr.Close() + defer outw.Close() + + inr, inw, err := os.Pipe() + if err != nil { + internalError(ws, "stdin:", err) + return + } + defer inr.Close() + defer inw.Close() + + proc, err := os.StartProcess(cmdPath, flag.Args(), &os.ProcAttr{ + Files: []*os.File{inr, outw, outw}, + }) + if err != nil { + internalError(ws, "start:", err) + return + } + + inr.Close() + outw.Close() + + stdoutDone := make(chan struct{}) + go pumpStdout(ws, outr, stdoutDone) + go ping(ws, stdoutDone) + + pumpStdin(ws, inw) + + // Some commands will exit when stdin is closed. + inw.Close() + + // Other commands need a bonk on the head. + if err := proc.Signal(os.Interrupt); err != nil { + log.Println("inter:", err) + } + + select { + case <-stdoutDone: + case <-time.After(time.Second): + // A bigger bonk on the head. + if err := proc.Signal(os.Kill); err != nil { + log.Println("term:", err) + } + <-stdoutDone + } + + if _, err := proc.Wait(); err != nil { + log.Println("wait:", err) + } +} + +func serveHome(w http.ResponseWriter, r *http.Request) { + if r.URL.Path != "/" { + http.Error(w, "Not found", 404) + return + } + if r.Method != "GET" { + http.Error(w, "Method not allowed", 405) + return + } + w.Header().Set("Content-Type", "text/html; charset=utf-8") + homeTempl.Execute(w, r.Host) +} + +func main() { + flag.Parse() + if len(flag.Args()) < 1 { + log.Fatal("must specify at least one argument") + } + var err error + cmdPath, err = exec.LookPath(flag.Args()[0]) + if err != nil { + log.Fatal(err) + } + http.HandleFunc("/", serveHome) + http.HandleFunc("/ws", serveWs) + log.Fatal(http.ListenAndServe(*addr, nil)) +} diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/README.md b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/README.md new file mode 100644 index 00000000..6ad79ed7 --- /dev/null +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/README.md @@ -0,0 +1,17 @@ +# Client and server example + +This example shows a simple client and server. + +The server echoes messages sent to it. The client sends a message every second +and prints all messages received. + +To run the example, start the server: + + $ go run server.go + +Next, start the client: + + $ go run client.go + +The server includes a simple web client. To use the client, open +http://127.0.0.1:8080 in the browser and follow the instructions on the page. diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/client.go b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/client.go new file mode 100644 index 00000000..4ed3ed6e --- /dev/null +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/client.go @@ -0,0 +1,81 @@ +// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package main + +import ( + "flag" + "log" + "net/url" + "os" + "os/signal" + "time" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/gorilla/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +func main() { + flag.Parse() + log.SetFlags(0) + + interrupt := make(chan os.Signal, 1) + signal.Notify(interrupt, os.Interrupt) + + u := url.URL{Scheme: "ws", Host: *addr, Path: "/echo"} + log.Printf("connecting to %s", u.String()) + + c, _, err := websocket.DefaultDialer.Dial(u.String(), nil) + if err != nil { + log.Fatal("dial:", err) + } + defer c.Close() + + done := make(chan struct{}) + + go func() { + defer c.Close() + defer close(done) + for { + _, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + return + } + log.Printf("recv: %s", message) + } + }() + + ticker := time.NewTicker(time.Second) + defer ticker.Stop() + + for { + select { + case t := <-ticker.C: + err := c.WriteMessage(websocket.TextMessage, []byte(t.String())) + if err != nil { + log.Println("write:", err) + return + } + case <-interrupt: + log.Println("interrupt") + // To cleanly close a connection, a client should send a close + // frame and wait for the server to close the connection. + err := c.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseNormalClosure, "")) + if err != nil { + log.Println("write close:", err) + return + } + select { + case <-done: + case <-time.After(time.Second): + } + c.Close() + return + } + } +} diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/server.go b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/server.go new file mode 100644 index 00000000..663521a1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/examples/echo/server.go @@ -0,0 +1,132 @@ +// Copyright 2015 The Gorilla WebSocket Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +// +build ignore + +package main + +import ( + "flag" + "html/template" + "log" + "net/http" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/gorilla/websocket" +) + +var addr = flag.String("addr", "localhost:8080", "http service address") + +var upgrader = websocket.Upgrader{} // use default options + +func echo(w http.ResponseWriter, r *http.Request) { + c, err := upgrader.Upgrade(w, r, nil) + if err != nil { + log.Print("upgrade:", err) + return + } + defer c.Close() + for { + mt, message, err := c.ReadMessage() + if err != nil { + log.Println("read:", err) + break + } + log.Printf("recv: %s", message) + err = c.WriteMessage(mt, message) + if err != nil { + log.Println("write:", err) + break + } + } +} + +func home(w http.ResponseWriter, r *http.Request) { + homeTemplate.Execute(w, "ws://"+r.Host+"/echo") +} + +func main() { + flag.Parse() + log.SetFlags(0) + http.HandleFunc("/echo", echo) + http.HandleFunc("/", home) + log.Fatal(http.ListenAndServe(*addr, nil)) +} + +var homeTemplate = template.Must(template.New("").Parse(` + + + + + + + +
+

Click "Open" to create a connection to the server, +"Send" to send a message to the server and "Close" to close the connection. +You can change the message and send multiple times. +

+

+ + +

+ +

+
+
+
+ + +`)) diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/json.go b/Godeps/_workspace/src/github.com/gorilla/websocket/json.go index 18e62f22..4f0e3687 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/json.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/json.go @@ -48,9 +48,7 @@ func (c *Conn) ReadJSON(v interface{}) error { } err = json.NewDecoder(r).Decode(v) if err == io.EOF { - // Decode returns io.EOF when the message is empty or all whitespace. - // Convert to io.ErrUnexpectedEOF so that application can distinguish - // between an error reading the JSON value and the connection closing. + // One value is expected in the message. err = io.ErrUnexpectedEOF } return err diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/json_test.go b/Godeps/_workspace/src/github.com/gorilla/websocket/json_test.go index 1b7a5ec8..61100e48 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/json_test.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/json_test.go @@ -38,7 +38,7 @@ func TestJSON(t *testing.T) { } } -func TestPartialJsonRead(t *testing.T) { +func TestPartialJSONRead(t *testing.T) { var buf bytes.Buffer c := fakeNetConn{&buf, &buf} wc := newConn(c, true, 1024, 1024) @@ -87,7 +87,7 @@ func TestPartialJsonRead(t *testing.T) { } err = rc.ReadJSON(&v) - if err != io.EOF { + if _, ok := err.(*CloseError); !ok { t.Error("final", err) } } diff --git a/Godeps/_workspace/src/github.com/gorilla/websocket/server.go b/Godeps/_workspace/src/github.com/gorilla/websocket/server.go index e56a0049..3a9805f0 100644 --- a/Godeps/_workspace/src/github.com/gorilla/websocket/server.go +++ b/Godeps/_workspace/src/github.com/gorilla/websocket/server.go @@ -93,6 +93,9 @@ func (u *Upgrader) selectSubprotocol(r *http.Request, responseHeader http.Header // request. Use the responseHeader to specify cookies (Set-Cookie) and the // application negotiated subprotocol (Sec-Websocket-Protocol). func (u *Upgrader) Upgrade(w http.ResponseWriter, r *http.Request, responseHeader http.Header) (*Conn, error) { + if r.Method != "GET" { + return u.returnError(w, r, http.StatusMethodNotAllowed, "websocket: method not GET") + } if values := r.Header["Sec-Websocket-Version"]; len(values) == 0 || values[0] != "13" { return u.returnError(w, r, http.StatusBadRequest, "websocket: version != 13") } diff --git a/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir.go b/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir.go index 051f1116..6944957d 100644 --- a/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir.go +++ b/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir.go @@ -7,20 +7,49 @@ import ( "os/exec" "path/filepath" "runtime" + "strconv" "strings" + "sync" ) +// DisableCache will disable caching of the home directory. Caching is enabled +// by default. +var DisableCache bool + +var homedirCache string +var cacheLock sync.RWMutex + // Dir returns the home directory for the executing user. // // This uses an OS-specific method for discovering the home directory. // An error is returned if a home directory cannot be detected. func Dir() (string, error) { - if runtime.GOOS == "windows" { - return dirWindows() + if !DisableCache { + cacheLock.RLock() + cached := homedirCache + cacheLock.RUnlock() + if cached != "" { + return cached, nil + } } - // Unix-like system, so just assume Unix - return dirUnix() + cacheLock.Lock() + defer cacheLock.Unlock() + + var result string + var err error + if runtime.GOOS == "windows" { + result, err = dirWindows() + } else { + // Unix-like system, so just assume Unix + result, err = dirUnix() + } + + if err != nil { + return "", err + } + homedirCache = result + return result, nil } // Expand expands the path to include the home directory if the path @@ -53,9 +82,28 @@ func dirUnix() (string, error) { return home, nil } - // If that fails, try the shell + // If that fails, try getent var stdout bytes.Buffer - cmd := exec.Command("sh", "-c", "eval echo ~$USER") + cmd := exec.Command("getent", "passwd", strconv.Itoa(os.Getuid())) + cmd.Stdout = &stdout + if err := cmd.Run(); err != nil { + // If "getent" is missing, ignore it + if err != exec.ErrNotFound { + return "", err + } + } else { + if passwd := strings.TrimSpace(stdout.String()); passwd != "" { + // username:password:uid:gid:gecos:home:shell + passwdParts := strings.SplitN(passwd, ":", 7) + if len(passwdParts) > 5 { + return passwdParts[5], nil + } + } + } + + // If all else fails, try the shell + stdout.Reset() + cmd = exec.Command("sh", "-c", "cd && pwd") cmd.Stdout = &stdout if err := cmd.Run(); err != nil { return "", err diff --git a/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir_test.go b/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir_test.go index ddc24ee0..c34dbc7f 100644 --- a/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir_test.go +++ b/Godeps/_workspace/src/github.com/mitchellh/go-homedir/homedir_test.go @@ -17,6 +17,18 @@ func patchEnv(key, value string) func() { return deferFunc } +func BenchmarkDir(b *testing.B) { + // We do this for any "warmups" + for i := 0; i < 10; i++ { + Dir() + } + + b.ResetTimer() + for i := 0; i < b.N; i++ { + Dir() + } +} + func TestDir(t *testing.T) { u, err := user.Current() if err != nil { @@ -86,6 +98,8 @@ func TestExpand(t *testing.T) { } } + DisableCache = true + defer func() { DisableCache = false }() defer patchEnv("HOME", "/custom/path/")() expected := "/custom/path/foo/bar" actual, err := Expand("~/foo/bar") diff --git a/Godeps/_workspace/src/github.com/spf13/cast/cast.go b/Godeps/_workspace/src/github.com/spf13/cast/cast.go index 1dde519f..0bc8d48c 100644 --- a/Godeps/_workspace/src/github.com/spf13/cast/cast.go +++ b/Godeps/_workspace/src/github.com/spf13/cast/cast.go @@ -42,6 +42,11 @@ func ToStringMapString(i interface{}) map[string]string { return v } +func ToStringMapStringSlice(i interface{}) map[string][]string { + v, _ := ToStringMapStringSliceE(i) + return v +} + func ToStringMapBool(i interface{}) map[string]bool { v, _ := ToStringMapBoolE(i) return v diff --git a/Godeps/_workspace/src/github.com/spf13/cast/cast_test.go b/Godeps/_workspace/src/github.com/spf13/cast/cast_test.go index 107d6038..40362ba7 100644 --- a/Godeps/_workspace/src/github.com/spf13/cast/cast_test.go +++ b/Godeps/_workspace/src/github.com/spf13/cast/cast_test.go @@ -6,9 +6,9 @@ package cast import ( - "testing" - "html/template" + "testing" + "time" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/stretchr/testify/assert" ) @@ -37,8 +37,11 @@ func TestToString(t *testing.T) { assert.Equal(t, ToString(8.12), "8.12") assert.Equal(t, ToString([]byte("one time")), "one time") assert.Equal(t, ToString(template.HTML("one time")), "one time") + assert.Equal(t, ToString(template.URL("http://somehost.foo")), "http://somehost.foo") assert.Equal(t, ToString(foo), "one more time") assert.Equal(t, ToString(nil), "") + assert.Equal(t, ToString(true), "true") + assert.Equal(t, ToString(false), "false") } type foo struct { @@ -73,8 +76,43 @@ func TestErrorToString(t *testing.T) { func TestMaps(t *testing.T) { var taxonomies = map[interface{}]interface{}{"tag": "tags", "group": "groups"} var stringMapBool = map[interface{}]interface{}{"v1": true, "v2": false} + + // ToStringMapString inputs/outputs + var stringMapString = map[string]string{"key 1": "value 1", "key 2": "value 2", "key 3": "value 3"} + var stringMapInterface = map[string]interface{}{"key 1": "value 1", "key 2": "value 2", "key 3": "value 3"} + var interfaceMapString = map[interface{}]string{"key 1": "value 1", "key 2": "value 2", "key 3": "value 3"} + var interfaceMapInterface = map[interface{}]interface{}{"key 1": "value 1", "key 2": "value 2", "key 3": "value 3"} + + // ToStringMapStringSlice inputs/outputs + var stringMapStringSlice = map[string][]string{"key 1": []string{"value 1", "value 2", "value 3"}, "key 2": []string{"value 1", "value 2", "value 3"}, "key 3": []string{"value 1", "value 2", "value 3"}} + var stringMapInterfaceSlice = map[string][]interface{}{"key 1": []interface{}{"value 1", "value 2", "value 3"}, "key 2": []interface{}{"value 1", "value 2", "value 3"}, "key 3": []interface{}{"value 1", "value 2", "value 3"}} + var stringMapStringSingleSliceFieldsResult = map[string][]string{"key 1": []string{"value", "1"}, "key 2": []string{"value", "2"}, "key 3": []string{"value", "3"}} + var interfaceMapStringSlice = map[interface{}][]string{"key 1": []string{"value 1", "value 2", "value 3"}, "key 2": []string{"value 1", "value 2", "value 3"}, "key 3": []string{"value 1", "value 2", "value 3"}} + var interfaceMapInterfaceSlice = map[interface{}][]interface{}{"key 1": []interface{}{"value 1", "value 2", "value 3"}, "key 2": []interface{}{"value 1", "value 2", "value 3"}, "key 3": []interface{}{"value 1", "value 2", "value 3"}} + + var stringMapStringSliceMultiple = map[string][]string{"key 1": []string{"value 1", "value 2", "value 3"}, "key 2": []string{"value 1", "value 2", "value 3"}, "key 3": []string{"value 1", "value 2", "value 3"}} + var stringMapStringSliceSingle = map[string][]string{"key 1": []string{"value 1"}, "key 2": []string{"value 2"}, "key 3": []string{"value 3"}} + assert.Equal(t, ToStringMap(taxonomies), map[string]interface{}{"tag": "tags", "group": "groups"}) assert.Equal(t, ToStringMapBool(stringMapBool), map[string]bool{"v1": true, "v2": false}) + + // ToStringMapString tests + assert.Equal(t, ToStringMapString(stringMapString), stringMapString) + assert.Equal(t, ToStringMapString(stringMapInterface), stringMapString) + assert.Equal(t, ToStringMapString(interfaceMapString), stringMapString) + assert.Equal(t, ToStringMapString(interfaceMapInterface), stringMapString) + + // ToStringMapStringSlice tests + assert.Equal(t, ToStringMapStringSlice(stringMapStringSlice), stringMapStringSlice) + assert.Equal(t, ToStringMapStringSlice(stringMapInterfaceSlice), stringMapStringSlice) + assert.Equal(t, ToStringMapStringSlice(stringMapStringSliceMultiple), stringMapStringSlice) + assert.Equal(t, ToStringMapStringSlice(stringMapStringSliceMultiple), stringMapStringSlice) + assert.Equal(t, ToStringMapStringSlice(stringMapString), stringMapStringSliceSingle) + assert.Equal(t, ToStringMapStringSlice(stringMapInterface), stringMapStringSliceSingle) + assert.Equal(t, ToStringMapStringSlice(interfaceMapStringSlice), stringMapStringSlice) + assert.Equal(t, ToStringMapStringSlice(interfaceMapInterfaceSlice), stringMapStringSlice) + assert.Equal(t, ToStringMapStringSlice(interfaceMapString), stringMapStringSingleSliceFieldsResult) + assert.Equal(t, ToStringMapStringSlice(interfaceMapInterface), stringMapStringSingleSliceFieldsResult) } func TestSlices(t *testing.T) { @@ -115,3 +153,12 @@ func TestIndirectPointers(t *testing.T) { assert.Equal(t, ToInt(y), 13) assert.Equal(t, ToInt(z), 13) } + +func TestToDuration(t *testing.T) { + a := time.Second * 5 + ai := int64(a) + b := time.Second * 5 + bf := float64(b) + assert.Equal(t, ToDuration(ai), a) + assert.Equal(t, ToDuration(bf), b) +} diff --git a/Godeps/_workspace/src/github.com/spf13/cast/caste.go b/Godeps/_workspace/src/github.com/spf13/cast/caste.go index ec782596..cf35bfb2 100644 --- a/Godeps/_workspace/src/github.com/spf13/cast/caste.go +++ b/Godeps/_workspace/src/github.com/spf13/cast/caste.go @@ -6,7 +6,6 @@ package cast import ( - "errors" "fmt" "html/template" "reflect" @@ -17,6 +16,7 @@ import ( jww "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/jwalterweatherman" ) +// ToTimeE casts an empty interface to time.Time. func ToTimeE(i interface{}) (tim time.Time, err error) { i = indirect(i) jww.DEBUG.Println("ToTimeE called on type:", reflect.TypeOf(i)) @@ -35,6 +35,7 @@ func ToTimeE(i interface{}) (tim time.Time, err error) { } } +// ToDurationE casts an empty interface to time.Duration. func ToDurationE(i interface{}) (d time.Duration, err error) { i = indirect(i) jww.DEBUG.Println("ToDurationE called on type:", reflect.TypeOf(i)) @@ -42,6 +43,12 @@ func ToDurationE(i interface{}) (d time.Duration, err error) { switch s := i.(type) { case time.Duration: return s, nil + case int64: + d = time.Duration(s) + return + case float64: + d = time.Duration(s) + return case string: d, err = time.ParseDuration(s) return @@ -51,6 +58,7 @@ func ToDurationE(i interface{}) (d time.Duration, err error) { } } +// ToBoolE casts an empty interface to a bool. func ToBoolE(i interface{}) (bool, error) { i = indirect(i) jww.DEBUG.Println("ToBoolE called on type:", reflect.TypeOf(i)) @@ -72,6 +80,7 @@ func ToBoolE(i interface{}) (bool, error) { } } +// ToFloat64E casts an empty interface to a float64. func ToFloat64E(i interface{}) (float64, error) { i = indirect(i) jww.DEBUG.Println("ToFloat64E called on type:", reflect.TypeOf(i)) @@ -95,14 +104,14 @@ func ToFloat64E(i interface{}) (float64, error) { v, err := strconv.ParseFloat(s, 64) if err == nil { return float64(v), nil - } else { - return 0.0, fmt.Errorf("Unable to Cast %#v to float", i) } + return 0.0, fmt.Errorf("Unable to Cast %#v to float", i) default: return 0.0, fmt.Errorf("Unable to Cast %#v to float", i) } } +// ToIntE casts an empty interface to an int. func ToIntE(i interface{}) (int, error) { i = indirect(i) jww.DEBUG.Println("ToIntE called on type:", reflect.TypeOf(i)) @@ -122,17 +131,15 @@ func ToIntE(i interface{}) (int, error) { v, err := strconv.ParseInt(s, 0, 0) if err == nil { return int(v), nil - } else { - return 0, fmt.Errorf("Unable to Cast %#v to int", i) } + return 0, fmt.Errorf("Unable to Cast %#v to int", i) case float64: return int(s), nil case bool: if bool(s) { return 1, nil - } else { - return 0, nil } + return 0, nil case nil: return 0, nil default: @@ -179,6 +186,7 @@ func indirectToStringerOrError(a interface{}) interface{} { return v.Interface() } +// ToStringE casts an empty interface to a string. func ToStringE(i interface{}) (string, error) { i = indirectToStringerOrError(i) jww.DEBUG.Println("ToStringE called on type:", reflect.TypeOf(i)) @@ -186,6 +194,8 @@ func ToStringE(i interface{}) (string, error) { switch s := i.(type) { case string: return s, nil + case bool: + return strconv.FormatBool(s), nil case float64: return strconv.FormatFloat(i.(float64), 'f', -1, 64), nil case int: @@ -194,6 +204,8 @@ func ToStringE(i interface{}) (string, error) { return string(s), nil case template.HTML: return string(s), nil + case template.URL: + return string(s), nil case nil: return "", nil case fmt.Stringer: @@ -205,30 +217,92 @@ func ToStringE(i interface{}) (string, error) { } } +// ToStringMapStringE casts an empty interface to a map[string]string. func ToStringMapStringE(i interface{}) (map[string]string, error) { jww.DEBUG.Println("ToStringMapStringE called on type:", reflect.TypeOf(i)) var m = map[string]string{} switch v := i.(type) { - case map[interface{}]interface{}: - for k, val := range v { - m[ToString(k)] = ToString(val) - } - return m, nil + case map[string]string: + return v, nil case map[string]interface{}: for k, val := range v { m[ToString(k)] = ToString(val) } return m, nil - case map[string]string: - return v, nil + case map[interface{}]string: + for k, val := range v { + m[ToString(k)] = ToString(val) + } + return m, nil + case map[interface{}]interface{}: + for k, val := range v { + m[ToString(k)] = ToString(val) + } + return m, nil default: return m, fmt.Errorf("Unable to Cast %#v to map[string]string", i) } - return m, fmt.Errorf("Unable to Cast %#v to map[string]string", i) } +// ToStringMapStringSliceE casts an empty interface to a map[string][]string. +func ToStringMapStringSliceE(i interface{}) (map[string][]string, error) { + jww.DEBUG.Println("ToStringMapStringSliceE called on type:", reflect.TypeOf(i)) + + var m = map[string][]string{} + + switch v := i.(type) { + case map[string][]string: + return v, nil + case map[string][]interface{}: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[string]string: + for k, val := range v { + m[ToString(k)] = []string{val} + } + case map[string]interface{}: + for k, val := range v { + m[ToString(k)] = []string{ToString(val)} + } + return m, nil + case map[interface{}][]string: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[interface{}]string: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[interface{}][]interface{}: + for k, val := range v { + m[ToString(k)] = ToStringSlice(val) + } + return m, nil + case map[interface{}]interface{}: + for k, val := range v { + key, err := ToStringE(k) + if err != nil { + return m, fmt.Errorf("Unable to Cast %#v to map[string][]string", i) + } + value, err := ToStringSliceE(val) + if err != nil { + return m, fmt.Errorf("Unable to Cast %#v to map[string][]string", i) + } + m[key] = value + } + default: + return m, fmt.Errorf("Unable to Cast %#v to map[string][]string", i) + } + return m, nil +} + +// ToStringMapBoolE casts an empty interface to a map[string]bool. func ToStringMapBoolE(i interface{}) (map[string]bool, error) { jww.DEBUG.Println("ToStringMapBoolE called on type:", reflect.TypeOf(i)) @@ -250,9 +324,9 @@ func ToStringMapBoolE(i interface{}) (map[string]bool, error) { default: return m, fmt.Errorf("Unable to Cast %#v to map[string]bool", i) } - return m, fmt.Errorf("Unable to Cast %#v to map[string]bool", i) } +// ToStringMapE casts an empty interface to a map[string]interface{}. func ToStringMapE(i interface{}) (map[string]interface{}, error) { jww.DEBUG.Println("ToStringMapE called on type:", reflect.TypeOf(i)) @@ -269,10 +343,9 @@ func ToStringMapE(i interface{}) (map[string]interface{}, error) { default: return m, fmt.Errorf("Unable to Cast %#v to map[string]interface{}", i) } - - return m, fmt.Errorf("Unable to Cast %#v to map[string]interface{}", i) } +// ToSliceE casts an empty interface to a []interface{}. func ToSliceE(i interface{}) ([]interface{}, error) { jww.DEBUG.Println("ToSliceE called on type:", reflect.TypeOf(i)) @@ -292,10 +365,9 @@ func ToSliceE(i interface{}) ([]interface{}, error) { default: return s, fmt.Errorf("Unable to Cast %#v of type %v to []interface{}", i, reflect.TypeOf(i)) } - - return s, fmt.Errorf("Unable to Cast %#v to []interface{}", i) } +// ToStringSliceE casts an empty interface to a []string. func ToStringSliceE(i interface{}) ([]string, error) { jww.DEBUG.Println("ToStringSliceE called on type:", reflect.TypeOf(i)) @@ -311,13 +383,18 @@ func ToStringSliceE(i interface{}) ([]string, error) { return v, nil case string: return strings.Fields(v), nil + case interface{}: + str, err := ToStringE(v) + if err != nil { + return a, fmt.Errorf("Unable to Cast %#v to []string", i) + } + return []string{str}, nil default: return a, fmt.Errorf("Unable to Cast %#v to []string", i) } - - return a, fmt.Errorf("Unable to Cast %#v to []string", i) } +// ToIntSliceE casts an empty interface to a []int. func ToIntSliceE(i interface{}) ([]int, error) { jww.DEBUG.Println("ToIntSliceE called on type:", reflect.TypeOf(i)) @@ -346,10 +423,9 @@ func ToIntSliceE(i interface{}) ([]int, error) { default: return []int{}, fmt.Errorf("Unable to Cast %#v to []int", i) } - - return []int{}, fmt.Errorf("Unable to Cast %#v to []int", i) } +// StringToDate casts an empty interface to a time.Time. func StringToDate(s string) (time.Time, error) { return parseDateWith(s, []string{ time.RFC3339, @@ -365,6 +441,8 @@ func StringToDate(s string) (time.Time, error) { "02 Jan 06 15:04 MST", "2006-01-02", "02 Jan 2006", + "2006-01-02 15:04:05 -07:00", + "2006-01-02 15:04:05 -0700", }) } @@ -374,5 +452,5 @@ func parseDateWith(s string, dates []string) (d time.Time, e error) { return } } - return d, errors.New(fmt.Sprintf("Unable to parse date: %s", s)) + return d, fmt.Errorf("Unable to parse date: %s", s) } diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/.mailmap b/Godeps/_workspace/src/github.com/spf13/cobra/.mailmap new file mode 100644 index 00000000..94ec5306 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/.mailmap @@ -0,0 +1,3 @@ +Steve Francia +Bjørn Erik Pedersen +Fabiano Franz diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/.travis.yml b/Godeps/_workspace/src/github.com/spf13/cobra/.travis.yml index dc43afd6..7a6cb7fd 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/.travis.yml +++ b/Godeps/_workspace/src/github.com/spf13/cobra/.travis.yml @@ -1,8 +1,9 @@ language: go go: - - 1.3 + - 1.3.3 - 1.4.2 + - 1.5.1 - tip script: - - go test ./... + - go test -v ./... - go build diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/README.md b/Godeps/_workspace/src/github.com/spf13/cobra/README.md index 710b66eb..b5cbb6b4 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/README.md +++ b/Godeps/_workspace/src/github.com/spf13/cobra/README.md @@ -1,122 +1,363 @@ -# Cobra +![cobra logo](https://cloud.githubusercontent.com/assets/173412/10886352/ad566232-814f-11e5-9cd0-aa101788c117.png) -A Commander for modern go CLI interactions +Cobra is both a library for creating powerful modern CLI applications as well as a program to generate applications and command files. -[![Build Status](https://travis-ci.org/spf13/cobra.svg)](https://travis-ci.org/spf13/cobra) +Many of the most widely used Go projects are built using Cobra including: -## Overview +* [Kubernetes](http://kubernetes.io/) +* [Hugo](http://gohugo.io) +* [rkt](https://github.com/coreos/rkt) +* [etcd](https://github.com/coreos/etcd) +* [Docker (distribution)](https://github.com/docker/distribution) +* [OpenShift](https://www.openshift.com/) +* [Delve](https://github.com/derekparker/delve) +* [GopherJS](http://www.gopherjs.org/) +* [CockroachDB](http://www.cockroachlabs.com/) +* [Bleve](http://www.blevesearch.com/) +* [ProjectAtomic (enterprise)](http://www.projectatomic.io/) +* [Parse (CLI)](https://parse.com/) +* [Nanobox](https://github.com/nanobox-io/nanobox)/[Nanopack](https://github.com/nanopack) -Cobra is a commander providing a simple interface to create powerful modern CLI -interfaces similar to git & go tools. In addition to providing an interface, Cobra -simultaneously provides a controller to organize your application code. -Inspired by go, go-Commander, gh and subcommand, Cobra improves on these by -providing **fully posix compliant flags** (including short & long versions), -**nesting commands**, and the ability to **define your own help and usage** for any or -all commands. +[![Build Status](https://travis-ci.org/spf13/cobra.svg "Travis CI status")](https://travis-ci.org/spf13/cobra) +[![CircleCI status](https://circleci.com/gh/spf13/cobra.png?circle-token=:circle-token "CircleCI status")](https://circleci.com/gh/spf13/cobra) + +![cobra](https://cloud.githubusercontent.com/assets/173412/10911369/84832a8e-8212-11e5-9f82-cc96660a4794.gif) + +# Overview + +Cobra is a library providing a simple interface to create powerful modern CLI +interfaces similar to git & go tools. + +Cobra is also an application that will generate your application scaffolding to rapidly +develop a Cobra-based application. + +Cobra provides: +* Easy subcommand-based CLIs: `app server`, `app fetch`, etc. +* Fully POSIX-compliant flags (including short & long versions) +* Nested subcommands +* Global, local and cascading flags +* Easy generation of applications & commands with `cobra create appname` & `cobra add cmdname` +* Intelligent suggestions (`app srver`... did you mean `app server`?) +* Automatic help generation for commands and flags +* Automatic detailed help for `app help [command]` +* Automatic help flag recognition of `-h`, `--help`, etc. +* Automatically generated bash autocomplete for your application +* Automatically generated man pages for your application +* Command aliases so you can change things without breaking them +* The flexibilty to define your own help, usage, etc. +* Optional tight integration with [viper](http://github.com/spf13/viper) for 12-factor apps Cobra has an exceptionally clean interface and simple design without needless constructors or initialization methods. -Applications built with Cobra commands are designed to be as user friendly as +Applications built with Cobra commands are designed to be as user-friendly as possible. Flags can be placed before or after the command (as long as a confusing space isn’t provided). Both short and long flags can be used. A -command need not even be fully typed. The shortest unambiguous string will -suffice. Help is automatically generated and available for the application or -for a specific command using either the help command or the --help flag. +command need not even be fully typed. Help is automatically generated and +available for the application or for a specific command using either the help +command or the `--help` flag. -## Concepts +# Concepts -Cobra is built on a structure of commands & flags. +Cobra is built on a structure of commands, arguments & flags. -**Commands** represent actions and **Flags** are modifiers for those actions. +**Commands** represent actions, **Args** are things and **Flags** are modifiers for those actions. -In the following example 'server' is a command and 'port' is a flag. +The best applications will read like sentences when used. Users will know how +to use the application because they will natively understand how to use it. - hugo server --port=1313 +The pattern to follow is +`APPNAME VERB NOUN --ADJECTIVE.` + or +`APPNAME COMMAND ARG --FLAG` -### Commands +A few good real world examples may better illustrate this point. + +In the following example, 'server' is a command, and 'port' is a flag: + + > hugo server --port=1313 + +In this command we are telling Git to clone the url bare. + + > git clone URL --bare + +## Commands Command is the central point of the application. Each interaction that the application supports will be contained in a Command. A command can have children commands and optionally run an action. -In the example above 'server' is the command +In the example above, 'server' is the command. A Command has the following structure: - type Command struct { - Use string // The one-line usage message. - Short string // The short description shown in the 'help' output. - Long string // The long message shown in the 'help ' output. - Run func(cmd *Command, args []string) // Run runs the command. - } +```go +type Command struct { + Use string // The one-line usage message. + Short string // The short description shown in the 'help' output. + Long string // The long message shown in the 'help ' output. + Run func(cmd *Command, args []string) // Run runs the command. +} +``` -### Flags +## Flags -A Flag is a way to modify the behavior of an command. Cobra supports -fully posix compliant flags as well as the go flag package. +A Flag is a way to modify the behavior of a command. Cobra supports +fully POSIX-compliant flags as well as the Go [flag package](https://golang.org/pkg/flag/). A Cobra command can define flags that persist through to children commands and flags that are only available to that command. -In the example above 'port' is the flag. +In the example above, 'port' is the flag. Flag functionality is provided by the [pflag -libary](https://github.com/ogier/pflag), a fork of the flag standard library -which maintains the same interface while adding posix compliance. +library](https://github.com/ogier/pflag), a fork of the flag standard library +which maintains the same interface while adding POSIX compliance. ## Usage Cobra works by creating a set of commands and then organizing them into a tree. The tree defines the structure of the application. -Once each command is defined with it's corresponding flags, then the +Once each command is defined with its corresponding flags, then the tree is assigned to the commander which is finally executed. -### Installing -Using Cobra is easy. First use go get to install the latest version -of the library. +# Installing +Using Cobra is easy. First, use `go get` to install the latest version +of the library. This command will install the `cobra` generator executible +along with the library: - $ go get github.com/spf13/cobra + > go get -v github.com/spf13/cobra/cobra -Next include cobra in your application. +Next, include Cobra in your application: - import "github.com/spf13/cobra" +```go +import "github.com/spf13/cobra" +``` + +# Getting Started + +While you are welcome to provide your own organization, typically a Cobra based +application will follow the following organizational structure. + +``` + ▾ appName/ + ▾ cmd/ + add.go + your.go + commands.go + here.go + main.go +``` + +In a Cobra app, typically the main.go file is very bare. It serves, one purpose, to initialize Cobra. + +```go +package main + +import "{pathToYourApp}/cmd" + +func main() { + if err := cmd.RootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} +``` + +## Using the Cobra Generator + +Cobra provides its own program that will create your application and add any +commands you want. It's the easiest way to incorporate Cobra into your application. + +### cobra init + +The `cobra init [yourApp]` command will create your initial application code +for you. It is a very powerful application that will populate your program with +the right structure so you can immediately enjoy all the benefits of Cobra. It +will also automatically apply the license you specify to your application. + +Cobra init is pretty smart. You can provide it a full path, or simply a path +similar to what is expected in the import. + +``` +cobra init github.com/spf13/newAppName +``` + +### cobra add + +Once an application is initialized Cobra can create additional commands for you. +Let's say you created an app and you wanted the following commands for it: + +* app serve +* app config +* app config create + +In your project directory (where your main.go file is) you would run the following: + +``` +cobra add serve +cobra add config +cobra add create -p 'configCmd' +``` + +Once you have run these three commands you would have an app structure that would look like: + +``` + ▾ app/ + ▾ cmd/ + serve.go + config.go + create.go + main.go +``` + +at this point you can run `go run main.go` and it would run your app. `go run +main.go serve`, `go run main.go config`, `go run main.go config create` along +with `go run main.go help serve`, etc would all work. + +Obviously you haven't added your own code to these yet, the commands are ready +for you to give them their tasks. Have fun. + +### Configuring the cobra generator + +The cobra generator will be easier to use if you provide a simple configuration +file which will help you eliminate providing a bunch of repeated information in +flags over and over. + +an example ~/.cobra.yaml file: + +```yaml +author: Steve Francia +license: MIT +``` + +## Manually implementing Cobra + +To manually implement cobra you need to create a bare main.go file and a RootCmd file. +You will optionally provide additional commands as you see fit. ### Create the root command The root command represents your binary itself. + +#### Manually create rootCmd + Cobra doesn't require any special constructors. Simply create your commands. - var HugoCmd = &cobra.Command{ - Use: "hugo", - Short: "Hugo is a very fast static site generator", - Long: `A Fast and Flexible Static Site Generator built with +Ideally you place this in app/cmd/root.go: + +```go +var RootCmd = &cobra.Command{ + Use: "hugo", + Short: "Hugo is a very fast static site generator", + Long: `A Fast and Flexible Static Site Generator built with love by spf13 and friends in Go. Complete documentation is available at http://hugo.spf13.com`, - Run: func(cmd *cobra.Command, args []string) { - // Do Stuff Here - }, - } + Run: func(cmd *cobra.Command, args []string) { + // Do Stuff Here + }, +} +``` + +You will additionally define flags and handle configuration in your init() function. + +for example cmd/root.go: + +```go +func init() { + cobra.OnInitialize(initConfig) + RootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.cobra.yaml)") + RootCmd.PersistentFlags().StringVarP(&projectBase, "projectbase", "b", "", "base project directory eg. github.com/spf13/") + RootCmd.PersistentFlags().StringP("author", "a", "YOUR NAME", "Author name for copyright attribution") + RootCmd.PersistentFlags().StringVarP(&userLicense, "license", "l", "", "Name of license for the project (can provide `licensetext` in config)") + RootCmd.PersistentFlags().Bool("viper", true, "Use Viper for configuration") + viper.BindPFlag("author", RootCmd.PersistentFlags().Lookup("author")) + viper.BindPFlag("projectbase", RootCmd.PersistentFlags().Lookup("projectbase")) + viper.BindPFlag("useViper", RootCmd.PersistentFlags().Lookup("viper")) + viper.SetDefault("author", "NAME HERE ") + viper.SetDefault("license", "apache") +} +``` + +### Create your main.go + +With the root command you need to have your main function execute it. +Execute should be run on the root for clarity, though it can be called on any command. + +In a Cobra app, typically the main.go file is very bare. It serves, one purpose, to initialize Cobra. + +```go +package main + +import "{pathToYourApp}/cmd" + +func main() { + if err := cmd.RootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} +``` + ### Create additional commands -Additional commands can be defined. +Additional commands can be defined and typically are each given their own file +inside of the cmd/ directory. - var versionCmd = &cobra.Command{ - Use: "version", - Short: "Print the version number of Hugo", - Long: `All software has versions. This is Hugo's`, - Run: func(cmd *cobra.Command, args []string) { - fmt.Println("Hugo Static Site Generator v0.9 -- HEAD") - }, - } +If you wanted to create a version command you would create cmd/version.go and +populate it with the following: + +```go +package cmd + +import ( + "github.com/spf13/cobra" +) + +func init() { + RootCmd.AddCommand(versionCmd) +} + +var versionCmd = &cobra.Command{ + Use: "version", + Short: "Print the version number of Hugo", + Long: `All software has versions. This is Hugo's`, + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("Hugo Static Site Generator v0.9 -- HEAD") + }, +} +``` ### Attach command to its parent -In this example we are attaching it to the root, but commands can be attached at any level. - HugoCmd.AddCommand(versionCmd) + +If you notice in the above example we attach the command to its parent. In +this case the parent is the rootCmd. In this example we are attaching it to the +root, but commands can be attached at any level. + +```go +RootCmd.AddCommand(versionCmd) +``` + +### Remove a command from its parent + +Removing a command is not a common action in simple programs, but it allows 3rd +parties to customize an existing command tree. + +In this example, we remove the existing `VersionCmd` command of an existing +root command, and we replace it with our own version: + +```go +mainlib.RootCmd.RemoveCommand(mainlib.VersionCmd) +mainlib.RootCmd.AddCommand(versionCmd) +``` + +## Working with Flags + +Flags provide modifiers to control how the action command operates. ### Assign flags to a command @@ -124,43 +365,35 @@ Since the flags are defined and used in different locations, we need to define a variable outside with the correct scope to assign the flag to work with. - var Verbose bool - var Source string +```go +var Verbose bool +var Source string +``` There are two different approaches to assign a flag. -#### Persistent Flags +### Persistent Flags A flag can be 'persistent' meaning that this flag will be available to the command it's assigned to as well as every command under that command. For -global flags assign a flag as a persistent flag on the root. +global flags, assign a flag as a persistent flag on the root. - HugoCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose output") +```go +RootCmd.PersistentFlags().BoolVarP(&Verbose, "verbose", "v", false, "verbose output") +``` -#### Local Flags +### Local Flags A flag can also be assigned locally which will only apply to that specific command. - HugoCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") +```go +RootCmd.Flags().StringVarP(&Source, "source", "s", "", "Source directory to read from") +``` -### Remove a command from its parent - -Removing a command is not a common action in simple programs but it allows 3rd parties to customize an existing command tree. - -In this example, we remove the existing `VersionCmd` command of an existing root command, and we replace it by our own version. - - mainlib.RootCmd.RemoveCommand(mainlib.VersionCmd) - mainlib.RootCmd.AddCommand(versionCmd) - -### Once all commands and flags are defined, Execute the commands - -Execute should be run on the root for clarity, though it can be called on any command. - - HugoCmd.Execute() ## Example -In the example below we have defined three commands. Two are at the top level +In the example below, we have defined three commands. Two are at the top level and one (cmdTimes) is a child of one of the top commands. In this case the root is not executable meaning that a subcommand is required. This is accomplished by not providing a 'Run' for the 'rootCmd'. @@ -169,196 +402,254 @@ We have only defined one flag for a single command. More documentation about flags is available at https://github.com/spf13/pflag - import( - "github.com/spf13/cobra" - "fmt" - "strings" - ) +```go +package main - func main() { +import ( + "fmt" + "strings" - var echoTimes int + "github.com/spf13/cobra" +) - var cmdPrint = &cobra.Command{ - Use: "print [string to print]", - Short: "Print anything to the screen", - Long: `print is for printing anything back to the screen. +func main() { + + var echoTimes int + + var cmdPrint = &cobra.Command{ + Use: "print [string to print]", + Short: "Print anything to the screen", + Long: `print is for printing anything back to the screen. For many years people have printed back to the screen. `, - Run: func(cmd *cobra.Command, args []string) { - fmt.Println("Print: " + strings.Join(args, " ")) - }, - } + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("Print: " + strings.Join(args, " ")) + }, + } - var cmdEcho = &cobra.Command{ - Use: "echo [string to echo]", - Short: "Echo anything to the screen", - Long: `echo is for echoing anything back. + var cmdEcho = &cobra.Command{ + Use: "echo [string to echo]", + Short: "Echo anything to the screen", + Long: `echo is for echoing anything back. Echo works a lot like print, except it has a child command. `, - Run: func(cmd *cobra.Command, args []string) { - fmt.Println("Print: " + strings.Join(args, " ")) - }, - } + Run: func(cmd *cobra.Command, args []string) { + fmt.Println("Print: " + strings.Join(args, " ")) + }, + } - var cmdTimes = &cobra.Command{ - Use: "times [# times] [string to echo]", - Short: "Echo anything to the screen more times", - Long: `echo things multiple times back to the user by providing + var cmdTimes = &cobra.Command{ + Use: "times [# times] [string to echo]", + Short: "Echo anything to the screen more times", + Long: `echo things multiple times back to the user by providing a count and a string.`, - Run: func(cmd *cobra.Command, args []string) { - for i:=0; i < echoTimes; i++ { - fmt.Println("Echo: " + strings.Join(args, " ")) - } - }, - } + Run: func(cmd *cobra.Command, args []string) { + for i := 0; i < echoTimes; i++ { + fmt.Println("Echo: " + strings.Join(args, " ")) + } + }, + } - cmdTimes.Flags().IntVarP(&echoTimes, "times", "t", 1, "times to echo the input") + cmdTimes.Flags().IntVarP(&echoTimes, "times", "t", 1, "times to echo the input") - var rootCmd = &cobra.Command{Use: "app"} - rootCmd.AddCommand(cmdPrint, cmdEcho) - cmdEcho.AddCommand(cmdTimes) - rootCmd.Execute() - } + var rootCmd = &cobra.Command{Use: "app"} + rootCmd.AddCommand(cmdPrint, cmdEcho) + cmdEcho.AddCommand(cmdTimes) + rootCmd.Execute() +} +``` -For a more complete example of a larger application, please checkout [Hugo](http://hugo.spf13.com) +For a more complete example of a larger application, please checkout [Hugo](http://gohugo.io/). ## The Help Command Cobra automatically adds a help command to your application when you have subcommands. -This will be called when a user runs 'app help'. Additionally help will also -support all other commands as input. Say for instance you have a command called -'create' without any additional configuration cobra will work when 'app help +This will be called when a user runs 'app help'. Additionally, help will also +support all other commands as input. Say, for instance, you have a command called +'create' without any additional configuration; Cobra will work when 'app help create' is called. Every command will automatically have the '--help' flag added. ### Example -The following output is automatically generated by cobra. Nothing beyond the +The following output is automatically generated by Cobra. Nothing beyond the command and flag definitions are needed. > hugo help - A Fast and Flexible Static Site Generator built with - love by spf13 and friends in Go. + hugo is the main command, used to build your Hugo site. - Complete documentation is available at http://hugo.spf13.com + Hugo is a Fast and Flexible Static Site Generator + built with love by spf13 and friends in Go. + + Complete documentation is available at http://gohugo.io/. Usage: hugo [flags] hugo [command] Available Commands: - server :: Hugo runs it's own a webserver to render the files - version :: Print the version number of Hugo - check :: Check content in the source directory - benchmark :: Benchmark hugo by building a site a number of times - help [command] :: Help about any command + server Hugo runs its own webserver to render the files + version Print the version number of Hugo + config Print the site configuration + check Check content in the source directory + benchmark Benchmark hugo by building a site a number of times. + convert Convert your content to different formats + new Create new content for your site + list Listing out various types of content + undraft Undraft changes the content's draft status from 'True' to 'False' + genautocomplete Generate shell autocompletion script for Hugo + gendoc Generate Markdown documentation for the Hugo CLI. + genman Generate man page for Hugo + import Import your site from others. - Available Flags: - -b, --base-url="": hostname (and path) to the root eg. http://spf13.com/ - -D, --build-drafts=false: include content marked as draft + Flags: + -b, --baseURL="": hostname (and path) to the root, e.g. http://spf13.com/ + -D, --buildDrafts[=false]: include content marked as draft + -F, --buildFuture[=false]: include content with publishdate in the future + --cacheDir="": filesystem path to cache directory. Defaults: $TMPDIR/hugo_cache/ + --canonifyURLs[=false]: if true, all relative URLs will be canonicalized using baseURL --config="": config file (default is path/config.yaml|json|toml) -d, --destination="": filesystem path to write files to + --disableRSS[=false]: Do not build RSS files + --disableSitemap[=false]: Do not build Sitemap file + --editor="": edit new content with this editor, if provided + --ignoreCache[=false]: Ignores the cache directory for reading but still writes to it + --log[=false]: Enable Logging + --logFile="": Log File path (if set, logging enabled automatically) + --noTimes[=false]: Don't sync modification time of files + --pluralizeListTitles[=true]: Pluralize titles in lists using inflect + --preserveTaxonomyNames[=false]: Preserve taxonomy names as written ("Gérard Depardieu" vs "gerard-depardieu") -s, --source="": filesystem path to read files relative from - --stepAnalysis=false: display memory and timing of different steps of the program - --uglyurls=false: if true, use /filename.html instead of /filename/ - -v, --verbose=false: verbose output - -w, --watch=false: watch filesystem for changes and recreate as needed - - Use "hugo help [command]" for more information about that command. + --stepAnalysis[=false]: display memory and timing of different steps of the program + -t, --theme="": theme to use (located in /themes/THEMENAME/) + --uglyURLs[=false]: if true, use /filename.html instead of /filename/ + -v, --verbose[=false]: verbose output + --verboseLog[=false]: verbose logging + -w, --watch[=false]: watch filesystem for changes and recreate as needed + Use "hugo [command] --help" for more information about a command. Help is just a command like any other. There is no special logic or behavior -around it. In fact you can provide your own if you want. +around it. In fact, you can provide your own if you want. ### Defining your own help You can provide your own Help command or you own template for the default command to use. -The default help command is +The default help command is - func (c *Command) initHelp() { - if c.helpCommand == nil { - c.helpCommand = &Command{ - Use: "help [command]", - Short: "Help about any command", - Long: `Help provides help for any command in the application. +```go +func (c *Command) initHelp() { + if c.helpCommand == nil { + c.helpCommand = &Command{ + Use: "help [command]", + Short: "Help about any command", + Long: `Help provides help for any command in the application. Simply type ` + c.Name() + ` help [path to command] for full details.`, - Run: c.HelpFunc(), - } - } - c.AddCommand(c.helpCommand) - } + Run: c.HelpFunc(), + } + } + c.AddCommand(c.helpCommand) +} +``` -You can provide your own command, function or template through the following methods. +You can provide your own command, function or template through the following methods: - command.SetHelpCommand(cmd *Command) +```go +command.SetHelpCommand(cmd *Command) - command.SetHelpFunc(f func(*Command, []string)) +command.SetHelpFunc(f func(*Command, []string)) - command.SetHelpTemplate(s string) +command.SetHelpTemplate(s string) +``` The latter two will also apply to any children commands. ## Usage -When the user provides an invalid flag or invalid command Cobra responds by -showing the user the 'usage' +When the user provides an invalid flag or invalid command, Cobra responds by +showing the user the 'usage'. ### Example You may recognize this from the help above. That's because the default help -embeds the usage as part of it's output. +embeds the usage as part of its output. Usage: hugo [flags] hugo [command] Available Commands: - server Hugo runs it's own a webserver to render the files + server Hugo runs its own webserver to render the files version Print the version number of Hugo + config Print the site configuration check Check content in the source directory - benchmark Benchmark hugo by building a site a number of times - help [command] Help about any command + benchmark Benchmark hugo by building a site a number of times. + convert Convert your content to different formats + new Create new content for your site + list Listing out various types of content + undraft Undraft changes the content's draft status from 'True' to 'False' + genautocomplete Generate shell autocompletion script for Hugo + gendoc Generate Markdown documentation for the Hugo CLI. + genman Generate man page for Hugo + import Import your site from others. - Available Flags: - -b, --base-url="": hostname (and path) to the root eg. http://spf13.com/ - -D, --build-drafts=false: include content marked as draft + Flags: + -b, --baseURL="": hostname (and path) to the root, e.g. http://spf13.com/ + -D, --buildDrafts[=false]: include content marked as draft + -F, --buildFuture[=false]: include content with publishdate in the future + --cacheDir="": filesystem path to cache directory. Defaults: $TMPDIR/hugo_cache/ + --canonifyURLs[=false]: if true, all relative URLs will be canonicalized using baseURL --config="": config file (default is path/config.yaml|json|toml) -d, --destination="": filesystem path to write files to + --disableRSS[=false]: Do not build RSS files + --disableSitemap[=false]: Do not build Sitemap file + --editor="": edit new content with this editor, if provided + --ignoreCache[=false]: Ignores the cache directory for reading but still writes to it + --log[=false]: Enable Logging + --logFile="": Log File path (if set, logging enabled automatically) + --noTimes[=false]: Don't sync modification time of files + --pluralizeListTitles[=true]: Pluralize titles in lists using inflect + --preserveTaxonomyNames[=false]: Preserve taxonomy names as written ("Gérard Depardieu" vs "gerard-depardieu") -s, --source="": filesystem path to read files relative from - --stepAnalysis=false: display memory and timing of different steps of the program - --uglyurls=false: if true, use /filename.html instead of /filename/ - -v, --verbose=false: verbose output - -w, --watch=false: watch filesystem for changes and recreate as needed + --stepAnalysis[=false]: display memory and timing of different steps of the program + -t, --theme="": theme to use (located in /themes/THEMENAME/) + --uglyURLs[=false]: if true, use /filename.html instead of /filename/ + -v, --verbose[=false]: verbose output + --verboseLog[=false]: verbose logging + -w, --watch[=false]: watch filesystem for changes and recreate as needed ### Defining your own usage -You can provide your own usage function or template for cobra to use. +You can provide your own usage function or template for Cobra to use. -The default usage function is +The default usage function is: - return func(c *Command) error { - err := tmpl(c.Out(), c.UsageTemplate(), c) - return err - } +```go +return func(c *Command) error { + err := tmpl(c.Out(), c.UsageTemplate(), c) + return err +} +``` -Like help the function and template are over ridable through public methods. +Like help, the function and template are overridable through public methods: - command.SetUsageFunc(f func(*Command) error) +```go +command.SetUsageFunc(f func(*Command) error) - command.SetUsageTemplate(s string) +command.SetUsageTemplate(s string) +``` ## PreRun or PostRun Hooks -It is possible to run functions before or after the main `Run` function of your command. The `PersistentPreRun` and `PreRun` functions will be executed before `Run`. `PersistendPostRun` and `PostRun` will be executed after `Run`. The `Persistent*Run` functions will be inherrited by children if they do not declare their own. These function are run in the following order: +It is possible to run functions before or after the main `Run` function of your command. The `PersistentPreRun` and `PreRun` functions will be executed before `Run`. `PersistentPostRun` and `PostRun` will be executed after `Run`. The `Persistent*Run` functions will be inherrited by children if they do not declare their own. These function are run in the following order: - `PersistentPreRun` - `PreRun` - `Run` - `PostRun` -- `PersistenPostRun` +- `PersistentPostRun` -And example of two commands which use all of these features is below. When the subcommand in executed it will run the root command's `PersistentPreRun` but not the root command's `PersistentPostRun` +An example of two commands which use all of these features is below. When the subcommand is executed, it will run the root command's `PersistentPreRun` but not the root command's `PersistentPostRun`: ```go package main @@ -393,7 +684,7 @@ func main() { var subCmd = &cobra.Command{ Use: "sub [no options!]", - Short: "My sub command", + Short: "My subcommand", PreRun: func(cmd *cobra.Command, args []string) { fmt.Printf("Inside subCmd PreRun with args: %v\n", args) }, @@ -418,22 +709,110 @@ func main() { } ``` -## Generating markdown formatted documentation for your command -Cobra can generate a markdown formatted document based on the subcommands, flags, etc. A simple example of how to do this for your command can be found in [Markdown Docs](md_docs.md) +## Alternative Error Handling + +Cobra also has functions where the return signature is an error. This allows for errors to bubble up to the top, providing a way to handle the errors in one location. The current list of functions that return an error is: + +* PersistentPreRunE +* PreRunE +* RunE +* PostRunE +* PersistentPostRunE + +**Example Usage using RunE:** + +```go +package main + +import ( + "errors" + "log" + + "github.com/spf13/cobra" +) + +func main() { + var rootCmd = &cobra.Command{ + Use: "hugo", + Short: "Hugo is a very fast static site generator", + Long: `A Fast and Flexible Static Site Generator built with + love by spf13 and friends in Go. + Complete documentation is available at http://hugo.spf13.com`, + RunE: func(cmd *cobra.Command, args []string) error { + // Do Stuff Here + return errors.New("some random error") + }, + } + + if err := rootCmd.Execute(); err != nil { + log.Fatal(err) + } +} +``` + +## Suggestions when "unknown command" happens + +Cobra will print automatic suggestions when "unknown command" errors happen. This allows Cobra to behave similarly to the `git` command when a typo happens. For example: + +``` +$ hugo srever +Error: unknown command "srever" for "hugo" + +Did you mean this? + server + +Run 'hugo --help' for usage. +``` + +Suggestions are automatic based on every subcommand registered and use an implementation of [Levenshtein distance](http://en.wikipedia.org/wiki/Levenshtein_distance). Every registered command that matches a minimum distance of 2 (ignoring case) will be displayed as a suggestion. + +If you need to disable suggestions or tweak the string distance in your command, use: + +```go +command.DisableSuggestions = true +``` + +or + +```go +command.SuggestionsMinimumDistance = 1 +``` + +You can also explicitly set names for which a given command will be suggested using the `SuggestFor` attribute. This allows suggestions for strings that are not close in terms of string distance, but makes sense in your set of commands and for some which you don't want aliases. Example: + +``` +$ kubectl remove +Error: unknown command "remove" for "kubectl" + +Did you mean this? + delete + +Run 'kubectl help' for usage. +``` + +## Generating Markdown-formatted documentation for your command + +Cobra can generate a Markdown-formatted document based on the subcommands, flags, etc. A simple example of how to do this for your command can be found in [Markdown Docs](doc/md_docs.md). + +## Generating man pages for your command + +Cobra can generate a man page based on the subcommands, flags, etc. A simple example of how to do this for your command can be found in [Man Docs](doc/man_docs.md). ## Generating bash completions for your command -Cobra can generate a bash completions file. If you add more information to your command these completions can be amazingly powerful and flexible. Read more about [Bash Completions](bash_completions.md) +Cobra can generate a bash-completion file. If you add more information to your command, these completions can be amazingly powerful and flexible. Read more about it in [Bash Completions](bash_completions.md). ## Debugging -Cobra provides a ‘DebugFlags’ method on a command which when called will print -out everything Cobra knows about the flags for each command +Cobra provides a ‘DebugFlags’ method on a command which, when called, will print +out everything Cobra knows about the flags for each command. ### Example - command.DebugFlags() +```go +command.DebugFlags() +``` ## Release Notes * **0.9.0** June 17, 2014 @@ -459,6 +838,12 @@ out everything Cobra knows about the flags for each command * **0.1.0** Sept 3, 2013 * Implement first draft +## Extensions + +Libraries for extending Cobra: + +* [cmdns](https://github.com/gosuri/cmdns): Enables name spacing a command's immediate children. It provides an alternative way to structure subcommands, similar to `heroku apps:create` and `ovrclk clusters:launch`. + ## ToDo * Launch proper documentation site @@ -474,7 +859,9 @@ out everything Cobra knows about the flags for each command Names in no particular order: -* [spf13](https://github.com/spf13) +* [spf13](https://github.com/spf13), +[eparis](https://github.com/eparis), +[bep](https://github.com/bep), and many more! ## License @@ -482,4 +869,3 @@ Cobra is released under the Apache 2.0 license. See [LICENSE.txt](https://github [![Bitdeli Badge](https://d2weczhvl823v0.cloudfront.net/spf13/cobra/trend.png)](https://bitdeli.com/free "Bitdeli Badge") - diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.go b/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.go index 3084b914..360bd98a 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.go +++ b/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.go @@ -1,8 +1,8 @@ package cobra import ( - "bytes" "fmt" + "io" "os" "sort" "strings" @@ -13,19 +13,30 @@ import ( const ( BashCompFilenameExt = "cobra_annotation_bash_completion_filename_extentions" BashCompOneRequiredFlag = "cobra_annotation_bash_completion_one_required_flag" + BashCompSubdirsInDir = "cobra_annotation_bash_completion_subdirs_in_dir" ) -func preamble(out *bytes.Buffer) { - fmt.Fprintf(out, `#!/bin/bash - - +func preamble(out io.Writer, name string) error { + _, err := fmt.Fprintf(out, "# bash completion for %-36s -*- shell-script -*-\n", name) + if err != nil { + return err + } + _, err = fmt.Fprintf(out, ` __debug() { if [[ -n ${BASH_COMP_DEBUG_FILE} ]]; then - echo "$*" >> ${BASH_COMP_DEBUG_FILE} + echo "$*" >> "${BASH_COMP_DEBUG_FILE}" fi } +# Homebrew on Macs have version 1.3 of bash-completion which doesn't include +# _init_completion. This is a very minimal version of that function. +__my_init_completion() +{ + COMPREPLY=() + _get_comp_words_by_ref cur prev words cword +} + __index_of_word() { local w word=$1 @@ -52,7 +63,9 @@ __handle_reply() __debug "${FUNCNAME}" case $cur in -*) - compopt -o nospace + if [[ $(type -t compopt) = "builtin" ]]; then + compopt -o nospace + fi local allflags if [ ${#must_have_one_flag[@]} -ne 0 ]; then allflags=("${must_have_one_flag[@]}") @@ -60,7 +73,9 @@ __handle_reply() allflags=("${flags[*]} ${two_word_flags[*]}") fi COMPREPLY=( $(compgen -W "${allflags[*]}" -- "$cur") ) - [[ $COMPREPLY == *= ]] || compopt +o nospace + if [[ $(type -t compopt) = "builtin" ]]; then + [[ $COMPREPLY == *= ]] || compopt +o nospace + fi return 0; ;; esac @@ -91,6 +106,21 @@ __handle_reply() if [[ ${#COMPREPLY[@]} -eq 0 ]]; then declare -F __custom_func >/dev/null && __custom_func fi + + __ltrim_colon_completions "$cur" +} + +# The arguments should be in the form "ext1|ext2|extn" +__handle_filename_extension_flag() +{ + local ext="$1" + _filedir "@(${ext})" +} + +__handle_subdirs_in_dir_flag() +{ + local dir="$1" + pushd "${dir}" >/dev/null 2>&1 && _filedir -d && popd >/dev/null 2>&1 } __handle_flag() @@ -99,8 +129,10 @@ __handle_flag() # if a command required a flag, and we found it, unset must_have_one_flag() local flagname=${words[c]} + local flagvalue # if the word contained an = if [[ ${words[c]} == *"="* ]]; then + flagvalue=${flagname#*=} # take in as flagvalue after the = flagname=${flagname%%=*} # strip everything after the = flagname="${flagname}=" # but put the = back fi @@ -109,6 +141,15 @@ __handle_flag() must_have_one_flag=() fi + # keep flag value with flagname as flaghash + if [ ${flagvalue} ] ; then + flaghash[${flagname}]=${flagvalue} + elif [ ${words[ $((c+1)) ]} ] ; then + flaghash[${flagname}]=${words[ $((c+1)) ]} + else + flaghash[${flagname}]="true" # pad "true" for bool flag + fi + # skip the argument to a two word flag if __contains_word "${words[c]}" "${two_word_flags[@]}"; then c=$((c+1)) @@ -118,7 +159,6 @@ __handle_flag() fi fi - # skip the flag itself c=$((c+1)) } @@ -141,9 +181,9 @@ __handle_command() local next_command if [[ -n ${last_command} ]]; then - next_command="_${last_command}_${words[c]}" + next_command="_${last_command}_${words[c]//:/__}" else - next_command="_${words[c]}" + next_command="_${words[c]//:/__}" fi c=$((c+1)) __debug "${FUNCNAME}: looking for ${next_command}" @@ -154,11 +194,11 @@ __handle_word() { if [[ $c -ge $cword ]]; then __handle_reply - return + return fi __debug "${FUNCNAME}: c is $c words[c] is ${words[c]}" if [[ "${words[c]}" == -* ]]; then - __handle_flag + __handle_flag elif __contains_word "${words[c]}" "${commands[@]}"; then __handle_command else @@ -168,15 +208,24 @@ __handle_word() } `) + return err } -func postscript(out *bytes.Buffer, name string) { - fmt.Fprintf(out, "__start_%s()\n", name) - fmt.Fprintf(out, `{ - local cur prev words cword split - _init_completion -s || return +func postscript(w io.Writer, name string) error { + name = strings.Replace(name, ":", "__", -1) + _, err := fmt.Fprintf(w, "__start_%s()\n", name) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, `{ + local cur prev words cword + declare -A flaghash + if declare -F _init_completion >/dev/null 2>&1; then + _init_completion -s || return + else + __my_init_completion || return + fi - local completions_func local c=0 local flags=() local two_word_flags=() @@ -192,35 +241,77 @@ func postscript(out *bytes.Buffer, name string) { } `, name) - fmt.Fprintf(out, "complete -F __start_%s %s\n", name, name) - fmt.Fprintf(out, "# ex: ts=4 sw=4 et filetype=sh\n") + if err != nil { + return err + } + _, err = fmt.Fprintf(w, `if [[ $(type -t compopt) = "builtin" ]]; then + complete -o default -F __start_%s %s +else + complete -o default -o nospace -F __start_%s %s +fi + +`, name, name, name, name) + if err != nil { + return err + } + _, err = fmt.Fprintf(w, "# ex: ts=4 sw=4 et filetype=sh\n") + return err } -func writeCommands(cmd *Command, out *bytes.Buffer) { - fmt.Fprintf(out, " commands=()\n") +func writeCommands(cmd *Command, w io.Writer) error { + if _, err := fmt.Fprintf(w, " commands=()\n"); err != nil { + return err + } for _, c := range cmd.Commands() { - if len(c.Deprecated) > 0 { + if !c.IsAvailableCommand() || c == cmd.helpCommand { continue } - fmt.Fprintf(out, " commands+=(%q)\n", c.Name()) + if _, err := fmt.Fprintf(w, " commands+=(%q)\n", c.Name()); err != nil { + return err + } } - fmt.Fprintf(out, "\n") + _, err := fmt.Fprintf(w, "\n") + return err } -func writeFlagHandler(name string, annotations map[string][]string, out *bytes.Buffer) { +func writeFlagHandler(name string, annotations map[string][]string, w io.Writer) error { for key, value := range annotations { switch key { case BashCompFilenameExt: - fmt.Fprintf(out, " flags_with_completion+=(%q)\n", name) + _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name) + if err != nil { + return err + } - ext := strings.Join(value, "|") - ext = "_filedir '@(" + ext + ")'" - fmt.Fprintf(out, " flags_completion+=(%q)\n", ext) + if len(value) > 0 { + ext := "__handle_filename_extension_flag " + strings.Join(value, "|") + _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) + } else { + ext := "_filedir" + _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) + } + if err != nil { + return err + } + case BashCompSubdirsInDir: + _, err := fmt.Fprintf(w, " flags_with_completion+=(%q)\n", name) + + if len(value) == 1 { + ext := "__handle_subdirs_in_dir_flag " + value[0] + _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) + } else { + ext := "_filedir -d" + _, err = fmt.Fprintf(w, " flags_completion+=(%q)\n", ext) + } + if err != nil { + return err + } } } + return nil } -func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) { +func writeShortFlag(flag *pflag.Flag, w io.Writer) error { b := (flag.Value.Type() == "bool") name := flag.Shorthand format := " " @@ -228,11 +319,13 @@ func writeShortFlag(flag *pflag.Flag, out *bytes.Buffer) { format += "two_word_" } format += "flags+=(\"-%s\")\n" - fmt.Fprintf(out, format, name) - writeFlagHandler("-"+name, flag.Annotations, out) + if _, err := fmt.Fprintf(w, format, name); err != nil { + return err + } + return writeFlagHandler("-"+name, flag.Annotations, w) } -func writeFlag(flag *pflag.Flag, out *bytes.Buffer) { +func writeFlag(flag *pflag.Flag, w io.Writer) error { b := (flag.Value.Type() == "bool") name := flag.Name format := " flags+=(\"--%s" @@ -240,32 +333,66 @@ func writeFlag(flag *pflag.Flag, out *bytes.Buffer) { format += "=" } format += "\")\n" - fmt.Fprintf(out, format, name) - writeFlagHandler("--"+name, flag.Annotations, out) + if _, err := fmt.Fprintf(w, format, name); err != nil { + return err + } + return writeFlagHandler("--"+name, flag.Annotations, w) } -func writeFlags(cmd *Command, out *bytes.Buffer) { - fmt.Fprintf(out, ` flags=() +func writeFlags(cmd *Command, w io.Writer) error { + _, err := fmt.Fprintf(w, ` flags=() two_word_flags=() flags_with_completion=() flags_completion=() `) + if err != nil { + return err + } + var visitErr error cmd.NonInheritedFlags().VisitAll(func(flag *pflag.Flag) { - writeFlag(flag, out) + if err := writeFlag(flag, w); err != nil { + visitErr = err + return + } if len(flag.Shorthand) > 0 { - writeShortFlag(flag, out) + if err := writeShortFlag(flag, w); err != nil { + visitErr = err + return + } } }) + if visitErr != nil { + return visitErr + } + cmd.InheritedFlags().VisitAll(func(flag *pflag.Flag) { + if err := writeFlag(flag, w); err != nil { + visitErr = err + return + } + if len(flag.Shorthand) > 0 { + if err := writeShortFlag(flag, w); err != nil { + visitErr = err + return + } + } + }) + if visitErr != nil { + return visitErr + } - fmt.Fprintf(out, "\n") + _, err = fmt.Fprintf(w, "\n") + return err } -func writeRequiredFlag(cmd *Command, out *bytes.Buffer) { - fmt.Fprintf(out, " must_have_one_flag=()\n") +func writeRequiredFlag(cmd *Command, w io.Writer) error { + if _, err := fmt.Fprintf(w, " must_have_one_flag=()\n"); err != nil { + return err + } flags := cmd.NonInheritedFlags() + var visitErr error flags.VisitAll(func(flag *pflag.Flag) { - for key, _ := range flag.Annotations { + for key := range flag.Annotations { switch key { case BashCompOneRequiredFlag: format := " must_have_one_flag+=(\"--%s" @@ -274,78 +401,126 @@ func writeRequiredFlag(cmd *Command, out *bytes.Buffer) { format += "=" } format += "\")\n" - fmt.Fprintf(out, format, flag.Name) + if _, err := fmt.Fprintf(w, format, flag.Name); err != nil { + visitErr = err + return + } if len(flag.Shorthand) > 0 { - fmt.Fprintf(out, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand) + if _, err := fmt.Fprintf(w, " must_have_one_flag+=(\"-%s\")\n", flag.Shorthand); err != nil { + visitErr = err + return + } } } } }) + return visitErr } -func writeRequiredNoun(cmd *Command, out *bytes.Buffer) { - fmt.Fprintf(out, " must_have_one_noun=()\n") +func writeRequiredNoun(cmd *Command, w io.Writer) error { + if _, err := fmt.Fprintf(w, " must_have_one_noun=()\n"); err != nil { + return err + } sort.Sort(sort.StringSlice(cmd.ValidArgs)) for _, value := range cmd.ValidArgs { - fmt.Fprintf(out, " must_have_one_noun+=(%q)\n", value) + if _, err := fmt.Fprintf(w, " must_have_one_noun+=(%q)\n", value); err != nil { + return err + } } + return nil } -func gen(cmd *Command, out *bytes.Buffer) { +func gen(cmd *Command, w io.Writer) error { for _, c := range cmd.Commands() { - if len(c.Deprecated) > 0 { + if !c.IsAvailableCommand() || c == cmd.helpCommand { continue } - gen(c, out) + if err := gen(c, w); err != nil { + return err + } } commandName := cmd.CommandPath() commandName = strings.Replace(commandName, " ", "_", -1) - fmt.Fprintf(out, "_%s()\n{\n", commandName) - fmt.Fprintf(out, " last_command=%q\n", commandName) - writeCommands(cmd, out) - writeFlags(cmd, out) - writeRequiredFlag(cmd, out) - writeRequiredNoun(cmd, out) - fmt.Fprintf(out, "}\n\n") + commandName = strings.Replace(commandName, ":", "__", -1) + if _, err := fmt.Fprintf(w, "_%s()\n{\n", commandName); err != nil { + return err + } + if _, err := fmt.Fprintf(w, " last_command=%q\n", commandName); err != nil { + return err + } + if err := writeCommands(cmd, w); err != nil { + return err + } + if err := writeFlags(cmd, w); err != nil { + return err + } + if err := writeRequiredFlag(cmd, w); err != nil { + return err + } + if err := writeRequiredNoun(cmd, w); err != nil { + return err + } + if _, err := fmt.Fprintf(w, "}\n\n"); err != nil { + return err + } + return nil } -func (cmd *Command) GenBashCompletion(out *bytes.Buffer) { - preamble(out) - if len(cmd.BashCompletionFunction) > 0 { - fmt.Fprintf(out, "%s\n", cmd.BashCompletionFunction) +func (cmd *Command) GenBashCompletion(w io.Writer) error { + if err := preamble(w, cmd.Name()); err != nil { + return err } - gen(cmd, out) - postscript(out, cmd.Name()) + if len(cmd.BashCompletionFunction) > 0 { + if _, err := fmt.Fprintf(w, "%s\n", cmd.BashCompletionFunction); err != nil { + return err + } + } + if err := gen(cmd, w); err != nil { + return err + } + return postscript(w, cmd.Name()) } func (cmd *Command) GenBashCompletionFile(filename string) error { - out := new(bytes.Buffer) - - cmd.GenBashCompletion(out) - outFile, err := os.Create(filename) if err != nil { return err } defer outFile.Close() - _, err = outFile.Write(out.Bytes()) - if err != nil { - return err - } - return nil + return cmd.GenBashCompletion(outFile) } -func (cmd *Command) MarkFlagRequired(name string) { - flag := cmd.Flags().Lookup(name) - if flag == nil { - return - } - if flag.Annotations == nil { - flag.Annotations = make(map[string][]string) - } - annotation := make([]string, 1) - annotation[0] = "true" - flag.Annotations[BashCompOneRequiredFlag] = annotation +// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag, if it exists. +func (cmd *Command) MarkFlagRequired(name string) error { + return MarkFlagRequired(cmd.Flags(), name) +} + +// MarkPersistentFlagRequired adds the BashCompOneRequiredFlag annotation to the named persistent flag, if it exists. +func (cmd *Command) MarkPersistentFlagRequired(name string) error { + return MarkFlagRequired(cmd.PersistentFlags(), name) +} + +// MarkFlagRequired adds the BashCompOneRequiredFlag annotation to the named flag in the flag set, if it exists. +func MarkFlagRequired(flags *pflag.FlagSet, name string) error { + return flags.SetAnnotation(name, BashCompOneRequiredFlag, []string{"true"}) +} + +// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag, if it exists. +// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. +func (cmd *Command) MarkFlagFilename(name string, extensions ...string) error { + return MarkFlagFilename(cmd.Flags(), name, extensions...) +} + +// MarkPersistentFlagFilename adds the BashCompFilenameExt annotation to the named persistent flag, if it exists. +// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. +func (cmd *Command) MarkPersistentFlagFilename(name string, extensions ...string) error { + return MarkFlagFilename(cmd.PersistentFlags(), name, extensions...) +} + +// MarkFlagFilename adds the BashCompFilenameExt annotation to the named flag in the flag set, if it exists. +// Generated bash autocompletion will select filenames for the flag, limiting to named extensions if provided. +func MarkFlagFilename(flags *pflag.FlagSet, name string, extensions ...string) error { + return flags.SetAnnotation(name, BashCompFilenameExt, extensions) } diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.md b/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.md index e1a5d56d..204704ef 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.md +++ b/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions.md @@ -118,20 +118,23 @@ and you'll get something like -c --container= -p --pod= ``` -# Specify valid filename extentions for flags that take a filename +# Specify valid filename extensions for flags that take a filename In this example we use --filename= and expect to get a json or yaml file as the argument. To make this easier we annotate the --filename flag with valid filename extensions. ```go - annotations := make([]string, 3) - annotations[0] = "json" - annotations[1] = "yaml" - annotations[2] = "yml" - + annotations := []string{"json", "yaml", "yml"} annotation := make(map[string][]string) annotation[cobra.BashCompFilenameExt] = annotations - flag := &pflag.Flag{"filename", "f", usage, value, value.String(), false, annotation} + flag := &pflag.Flag{ + Name: "filename", + Shorthand: "f", + Usage: usage, + Value: value, + DefValue: value.String(), + Annotations: annotation, + } cmd.Flags().AddFlag(flag) ``` diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions_test.go index 4b7d06c6..86f3d010 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions_test.go +++ b/Godeps/_workspace/src/github.com/spf13/cobra/bash_completions_test.go @@ -34,7 +34,7 @@ COMPREPLY=( "hello" ) func TestBashCompletions(t *testing.T) { c := initializeWithRootCmd() cmdEcho.AddCommand(cmdTimes) - c.AddCommand(cmdEcho, cmdPrint, cmdDeprecated) + c.AddCommand(cmdEcho, cmdPrint, cmdDeprecated, cmdColon) // custom completion function c.BashCompletionFunction = bash_completion_func @@ -42,23 +42,30 @@ func TestBashCompletions(t *testing.T) { // required flag c.MarkFlagRequired("introot") - // valid nounds + // valid nouns validArgs := []string{"pods", "nodes", "services", "replicationControllers"} c.ValidArgs = validArgs - // filename extentions - annotations := make([]string, 3) - annotations[0] = "json" - annotations[1] = "yaml" - annotations[2] = "yml" - - annotation := make(map[string][]string) - annotation[BashCompFilenameExt] = annotations - + // filename var flagval string c.Flags().StringVar(&flagval, "filename", "", "Enter a filename") - flag := c.Flags().Lookup("filename") - flag.Annotations = annotation + c.MarkFlagFilename("filename", "json", "yaml", "yml") + + // persistent filename + var flagvalPersistent string + c.PersistentFlags().StringVar(&flagvalPersistent, "persistent-filename", "", "Enter a filename") + c.MarkPersistentFlagFilename("persistent-filename") + c.MarkPersistentFlagRequired("persistent-filename") + + // filename extensions + var flagvalExt string + c.Flags().StringVar(&flagvalExt, "filename-ext", "", "Enter a filename (extension limited)") + c.MarkFlagFilename("filename-ext") + + // subdirectories in a given directory + var flagvalTheme string + c.Flags().StringVar(&flagvalTheme, "theme", "", "theme to use (located in /themes/THEMENAME/)") + c.Flags().SetAnnotation("theme", BashCompSubdirsInDir, []string{"themes"}) out := new(bytes.Buffer) c.GenBashCompletion(out) @@ -68,15 +75,21 @@ func TestBashCompletions(t *testing.T) { check(t, str, "_cobra-test_echo") check(t, str, "_cobra-test_echo_times") check(t, str, "_cobra-test_print") + check(t, str, "_cobra-test_cmd__colon") // check for required flags check(t, str, `must_have_one_flag+=("--introot=")`) + check(t, str, `must_have_one_flag+=("--persistent-filename=")`) // check for custom completion function check(t, str, `COMPREPLY=( "hello" )`) // check for required nouns check(t, str, `must_have_one_noun+=("pods")`) - // check for filename extention flags - check(t, str, `flags_completion+=("_filedir '@(json|yaml|yml)'")`) + // check for filename extension flags + check(t, str, `flags_completion+=("_filedir")`) + // check for filename extension flags + check(t, str, `flags_completion+=("__handle_filename_extension_flag json|yaml|yml")`) + // check for subdirs_in_dir flags + check(t, str, `flags_completion+=("__handle_subdirs_in_dir_flag themes")`) checkOmit(t, str, cmdDeprecated.Name()) } diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra.go index 78b92b0a..7c8da2b4 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/cobra.go +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra.go @@ -23,21 +23,36 @@ import ( "strconv" "strings" "text/template" + "unicode" ) +var templateFuncs template.FuncMap = template.FuncMap{ + "trim": strings.TrimSpace, + "trimRightSpace": trimRightSpace, + "rpad": rpad, + "gt": Gt, + "eq": Eq, +} + var initializers []func() // automatic prefix matching can be a dangerous thing to automatically enable in CLI tools. // Set this to true to enable it var EnablePrefixMatching bool = false -// enables an information splash screen on Windows if the CLI is started from explorer.exe. -var EnableWindowsMouseTrap bool = true +//AddTemplateFunc adds a template function that's available to Usage and Help +//template generation. +func AddTemplateFunc(name string, tmplFunc interface{}) { + templateFuncs[name] = tmplFunc +} -var MousetrapHelpText string = `This is a command line tool - -You need to open cmd.exe and run it from there. -` +//AddTemplateFuncs adds multiple template functions availalble to Usage and +//Help template generation. +func AddTemplateFuncs(tmplFuncs template.FuncMap) { + for k, v := range tmplFuncs { + templateFuncs[k] = v + } +} //OnInitialize takes a series of func() arguments and appends them to a slice of func(). func OnInitialize(y ...func()) { @@ -92,6 +107,10 @@ func Eq(a interface{}, b interface{}) bool { return false } +func trimRightSpace(s string) string { + return strings.TrimRightFunc(s, unicode.IsSpace) +} + //rpad adds padding to the right of a string func rpad(s string, padding int) string { template := fmt.Sprintf("%%-%ds", padding) @@ -101,12 +120,43 @@ func rpad(s string, padding int) string { // tmpl executes the given template text on data, writing the result to w. func tmpl(w io.Writer, text string, data interface{}) error { t := template.New("top") - t.Funcs(template.FuncMap{ - "trim": strings.TrimSpace, - "rpad": rpad, - "gt": Gt, - "eq": Eq, - }) + t.Funcs(templateFuncs) template.Must(t.Parse(text)) return t.Execute(w, data) } + +// ld compares two strings and returns the levenshtein distance between them +func ld(s, t string, ignoreCase bool) int { + if ignoreCase { + s = strings.ToLower(s) + t = strings.ToLower(t) + } + d := make([][]int, len(s)+1) + for i := range d { + d[i] = make([]int, len(t)+1) + } + for i := range d { + d[i][0] = i + } + for j := range d[0] { + d[0][j] = j + } + for j := 1; j <= len(t); j++ { + for i := 1; i <= len(s); i++ { + if s[i-1] == t[j-1] { + d[i][j] = d[i-1][j-1] + } else { + min := d[i-1][j] + if d[i][j-1] < min { + min = d[i][j-1] + } + if d[i-1][j-1] < min { + min = d[i-1][j-1] + } + d[i][j] = min + 1 + } + } + + } + return d[len(s)][len(t)] +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/add.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/add.go new file mode 100644 index 00000000..4afad6e3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/add.go @@ -0,0 +1,128 @@ +// Copyright © 2015 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "fmt" + "path/filepath" + "strings" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/viper" +) + +func init() { + RootCmd.AddCommand(addCmd) +} + +var pName string + +// initialize Command +var addCmd = &cobra.Command{ + Use: "add [command name]", + Aliases: []string{"command"}, + Short: "Add a command to a Cobra Application", + Long: `Add (cobra add) will create a new command, with a license and +the appropriate structure for a Cobra-based CLI application, +and register it to its parent (default RootCmd). + +If you want your command to be public, pass in the command name +with an initial uppercase letter. + +Example: cobra add server -> resulting in a new cmd/server.go + `, + + Run: func(cmd *cobra.Command, args []string) { + if len(args) != 1 { + er("add needs a name for the command") + } + guessProjectPath() + createCmdFile(args[0]) + }, +} + +func init() { + addCmd.Flags().StringVarP(&pName, "parent", "p", "RootCmd", "name of parent command for this command") +} + +func parentName() string { + if !strings.HasSuffix(strings.ToLower(pName), "cmd") { + return pName + "Cmd" + } + + return pName +} + +func createCmdFile(cmdName string) { + lic := getLicense() + + template := `{{ comment .copyright }} +{{ comment .license }} + +package cmd + +import ( + "fmt" + + "github.com/spf13/cobra" +) + +// {{.cmdName}}Cmd represents the {{.cmdName}} command +var {{ .cmdName }}Cmd = &cobra.Command{ + Use: "{{ .cmdName }}", + Short: "A brief description of your command", + Long: ` + "`" + `A longer description that spans multiple lines and likely contains examples +and usage of using your command. For example: + +Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.` + "`" + `, + Run: func(cmd *cobra.Command, args []string) { + // TODO: Work your own magic here + fmt.Println("{{ .cmdName }} called") + }, +} + +func init() { + {{ .parentName }}.AddCommand({{ .cmdName }}Cmd) + + // Here you will define your flags and configuration settings. + + // Cobra supports Persistent Flags which will work for this command + // and all subcommands, e.g.: + // {{.cmdName}}Cmd.PersistentFlags().String("foo", "", "A help for foo") + + // Cobra supports local flags which will only run when this command + // is called directly, e.g.: + // {{.cmdName}}Cmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") + +} +` + + var data map[string]interface{} + data = make(map[string]interface{}) + + data["copyright"] = copyrightLine() + data["license"] = lic.Header + data["appName"] = projectName() + data["viper"] = viper.GetBool("useViper") + data["parentName"] = parentName() + data["cmdName"] = cmdName + + err := writeTemplateToFile(filepath.Join(ProjectPath(), guessCmdDir()), cmdName+".go", template, data) + if err != nil { + er(err) + } + fmt.Println(cmdName, "created at", filepath.Join(ProjectPath(), guessCmdDir(), cmdName+".go")) +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/helpers.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/helpers.go new file mode 100644 index 00000000..7afd8ef3 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/helpers.go @@ -0,0 +1,347 @@ +// Copyright © 2015 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "bytes" + "fmt" + "io" + "os" + "path/filepath" + "strings" + "text/template" + "time" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/viper" +) + +// var BaseDir = "" +// var AppName = "" +// var CommandDir = "" + +var funcMap template.FuncMap +var projectPath = "" +var inputPath = "" +var projectBase = "" + +// for testing only +var testWd = "" + +var cmdDirs = []string{"cmd", "cmds", "command", "commands"} + +func init() { + funcMap = template.FuncMap{ + "comment": commentifyString, + } +} + +func er(msg interface{}) { + fmt.Println("Error:", msg) + os.Exit(-1) +} + +// Check if a file or directory exists. +func exists(path string) (bool, error) { + _, err := os.Stat(path) + if err == nil { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func ProjectPath() string { + if projectPath == "" { + guessProjectPath() + } + + return projectPath +} + +// wrapper of the os package so we can test better +func getWd() (string, error) { + if testWd == "" { + return os.Getwd() + } + return testWd, nil +} + +func guessCmdDir() string { + guessProjectPath() + if b, _ := isEmpty(projectPath); b { + return "cmd" + } + + files, _ := filepath.Glob(projectPath + string(os.PathSeparator) + "c*") + for _, f := range files { + for _, c := range cmdDirs { + if f == c { + return c + } + } + } + + return "cmd" +} + +func guessImportPath() string { + guessProjectPath() + + if !strings.HasPrefix(projectPath, getSrcPath()) { + er("Cobra only supports project within $GOPATH") + } + + return filepath.ToSlash(filepath.Clean(strings.TrimPrefix(projectPath, getSrcPath()))) +} + +func getSrcPath() string { + return filepath.Join(os.Getenv("GOPATH"), "src") + string(os.PathSeparator) +} + +func projectName() string { + return filepath.Base(ProjectPath()) +} + +func guessProjectPath() { + // if no path is provided... assume CWD. + if inputPath == "" { + x, err := getWd() + if err != nil { + er(err) + } + + // inspect CWD + base := filepath.Base(x) + + // if we are in the cmd directory.. back up + for _, c := range cmdDirs { + if base == c { + projectPath = filepath.Dir(x) + return + } + } + + if projectPath == "" { + projectPath = filepath.Clean(x) + return + } + } + + srcPath := getSrcPath() + // if provided, inspect for logical locations + if strings.ContainsRune(inputPath, os.PathSeparator) { + if filepath.IsAbs(inputPath) || filepath.HasPrefix(inputPath, string(os.PathSeparator)) { + // if Absolute, use it + projectPath = filepath.Clean(inputPath) + return + } + // If not absolute but contains slashes, + // assuming it means create it from $GOPATH + count := strings.Count(inputPath, string(os.PathSeparator)) + + switch count { + // If only one directory deep, assume "github.com" + case 1: + projectPath = filepath.Join(srcPath, "github.com", inputPath) + return + case 2: + projectPath = filepath.Join(srcPath, inputPath) + return + default: + er("Unknown directory") + } + } else { + // hardest case.. just a word. + if projectBase == "" { + x, err := getWd() + if err == nil { + projectPath = filepath.Join(x, inputPath) + return + } + er(err) + } else { + projectPath = filepath.Join(srcPath, projectBase, inputPath) + return + } + } +} + +// isEmpty checks if a given path is empty. +func isEmpty(path string) (bool, error) { + if b, _ := exists(path); !b { + return false, fmt.Errorf("%q path does not exist", path) + } + fi, err := os.Stat(path) + if err != nil { + return false, err + } + if fi.IsDir() { + f, err := os.Open(path) + // FIX: Resource leak - f.close() should be called here by defer or is missed + // if the err != nil branch is taken. + defer f.Close() + if err != nil { + return false, err + } + list, err := f.Readdir(-1) + // f.Close() - see bug fix above + return len(list) == 0, nil + } + return fi.Size() == 0, nil +} + +// isDir checks if a given path is a directory. +func isDir(path string) (bool, error) { + fi, err := os.Stat(path) + if err != nil { + return false, err + } + return fi.IsDir(), nil +} + +// dirExists checks if a path exists and is a directory. +func dirExists(path string) (bool, error) { + fi, err := os.Stat(path) + if err == nil && fi.IsDir() { + return true, nil + } + if os.IsNotExist(err) { + return false, nil + } + return false, err +} + +func writeTemplateToFile(path string, file string, template string, data interface{}) error { + filename := filepath.Join(path, file) + + r, err := templateToReader(template, data) + + if err != nil { + return err + } + + err = safeWriteToDisk(filename, r) + + if err != nil { + return err + } + return nil +} + +func writeStringToFile(path, file, text string) error { + filename := filepath.Join(path, file) + + r := strings.NewReader(text) + err := safeWriteToDisk(filename, r) + + if err != nil { + return err + } + return nil +} + +func templateToReader(tpl string, data interface{}) (io.Reader, error) { + tmpl := template.New("") + tmpl.Funcs(funcMap) + tmpl, err := tmpl.Parse(tpl) + + if err != nil { + return nil, err + } + buf := new(bytes.Buffer) + err = tmpl.Execute(buf, data) + + return buf, err +} + +// Same as WriteToDisk but checks to see if file/directory already exists. +func safeWriteToDisk(inpath string, r io.Reader) (err error) { + dir, _ := filepath.Split(inpath) + ospath := filepath.FromSlash(dir) + + if ospath != "" { + err = os.MkdirAll(ospath, 0777) // rwx, rw, r + if err != nil { + return + } + } + + ex, err := exists(inpath) + if err != nil { + return + } + if ex { + return fmt.Errorf("%v already exists", inpath) + } + + file, err := os.Create(inpath) + if err != nil { + return + } + defer file.Close() + + _, err = io.Copy(file, r) + return +} + +func getLicense() License { + l := whichLicense() + if l != "" { + if x, ok := Licenses[l]; ok { + return x + } + } + + return Licenses["apache"] +} + +func whichLicense() string { + // if explicitly flagged, use that + if userLicense != "" { + return matchLicense(userLicense) + } + + // if already present in the project, use that + // TODO: Inspect project for existing license + + // default to viper's setting + + return matchLicense(viper.GetString("license")) +} + +func copyrightLine() string { + author := viper.GetString("author") + year := time.Now().Format("2006") + + return "Copyright © " + year + " " + author +} + +func commentifyString(in string) string { + var newlines []string + lines := strings.Split(in, "\n") + for _, x := range lines { + if !strings.HasPrefix(x, "//") { + if x != "" { + newlines = append(newlines, "// "+x) + } else { + newlines = append(newlines, "//") + } + } else { + newlines = append(newlines, x) + } + } + return strings.Join(newlines, "\n") +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/helpers_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/helpers_test.go new file mode 100644 index 00000000..bd0f7595 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/helpers_test.go @@ -0,0 +1,40 @@ +package cmd + +import ( + "fmt" + "os" + "path/filepath" + "testing" +) + +var _ = fmt.Println +var _ = os.Stderr + +func checkGuess(t *testing.T, wd, input, expected string) { + testWd = wd + inputPath = input + guessProjectPath() + + if projectPath != expected { + t.Errorf("Unexpected Project Path. \n Got: %q\nExpected: %q\n", projectPath, expected) + } + + reset() +} + +func reset() { + testWd = "" + inputPath = "" + projectPath = "" +} + +func TestProjectPath(t *testing.T) { + checkGuess(t, "", filepath.Join("github.com", "spf13", "hugo"), filepath.Join(getSrcPath(), "github.com", "spf13", "hugo")) + checkGuess(t, "", filepath.Join("spf13", "hugo"), filepath.Join(getSrcPath(), "github.com", "spf13", "hugo")) + checkGuess(t, "", filepath.Join("/", "bar", "foo"), filepath.Join("/", "bar", "foo")) + checkGuess(t, "/bar/foo", "baz", filepath.Join("/", "bar", "foo", "baz")) + checkGuess(t, "/bar/foo/cmd", "", filepath.Join("/", "bar", "foo")) + checkGuess(t, "/bar/foo/command", "", filepath.Join("/", "bar", "foo")) + checkGuess(t, "/bar/foo/commands", "", filepath.Join("/", "bar", "foo")) + checkGuess(t, "github.com/spf13/hugo/../hugo", "", filepath.Join("github.com", "spf13", "hugo")) +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/init.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/init.go new file mode 100644 index 00000000..c514dfed --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/init.go @@ -0,0 +1,226 @@ +// Copyright © 2015 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "fmt" + "os" + "strings" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/viper" +) + +func init() { + RootCmd.AddCommand(initCmd) +} + +// initialize Command +var initCmd = &cobra.Command{ + Use: "init [name]", + Aliases: []string{"initialize", "initialise", "create"}, + Short: "Initialize a Cobra Application", + Long: `Initialize (cobra init) will create a new application, with a license +and the appropriate structure for a Cobra-based CLI application. + + * If a name is provided, it will be created in the current directory; + * If no name is provided, the current directory will be assumed; + * If a relative path is provided, it will be created inside $GOPATH + (e.g. github.com/spf13/hugo); + * If an absolute path is provided, it will be created; + * If the directory already exists but is empty, it will be used. + +Init will not use an existing directory with contents.`, + + Run: func(cmd *cobra.Command, args []string) { + switch len(args) { + case 0: + inputPath = "" + + case 1: + inputPath = args[0] + + default: + er("init doesn't support more than 1 parameter") + } + guessProjectPath() + initializePath(projectPath) + }, +} + +func initializePath(path string) { + b, err := exists(path) + if err != nil { + er(err) + } + + if !b { // If path doesn't yet exist, create it + err := os.MkdirAll(path, os.ModePerm) + if err != nil { + er(err) + } + } else { // If path exists and is not empty don't use it + empty, err := exists(path) + if err != nil { + er(err) + } + if !empty { + er("Cobra will not create a new project in a non empty directory") + } + } + // We have a directory and it's empty.. Time to initialize it. + + createLicenseFile() + createMainFile() + createRootCmdFile() +} + +func createLicenseFile() { + lic := getLicense() + + template := lic.Text + + var data map[string]interface{} + data = make(map[string]interface{}) + + // Try to remove the email address, if any + data["copyright"] = strings.Split(copyrightLine(), " <")[0] + + err := writeTemplateToFile(ProjectPath(), "LICENSE", template, data) + _ = err + // if err != nil { + // er(err) + // } +} + +func createMainFile() { + lic := getLicense() + + template := `{{ comment .copyright }} +{{ comment .license }} + +package main + +import "{{ .importpath }}" + +func main() { + cmd.Execute() +} +` + var data map[string]interface{} + data = make(map[string]interface{}) + + data["copyright"] = copyrightLine() + data["license"] = lic.Header + data["importpath"] = guessImportPath() + "/" + guessCmdDir() + + err := writeTemplateToFile(ProjectPath(), "main.go", template, data) + _ = err + // if err != nil { + // er(err) + // } +} + +func createRootCmdFile() { + lic := getLicense() + + template := `{{ comment .copyright }} +{{ comment .license }} + +package cmd + +import ( + "fmt" + "os" + + "github.com/spf13/cobra" +{{ if .viper }} "github.com/spf13/viper" +{{ end }}) +{{if .viper}} +var cfgFile string +{{ end }} +// This represents the base command when called without any subcommands +var RootCmd = &cobra.Command{ + Use: "{{ .appName }}", + Short: "A brief description of your application", + Long: ` + "`" + `A longer description that spans multiple lines and likely contains +examples and usage of using your application. For example: + +Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.` + "`" + `, +// Uncomment the following line if your bare application +// has an action associated with it: +// Run: func(cmd *cobra.Command, args []string) { }, +} + +// Execute adds all child commands to the root command sets flags appropriately. +// This is called by main.main(). It only needs to happen once to the rootCmd. +func Execute() { + if err := RootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +func init() { +{{ if .viper }} cobra.OnInitialize(initConfig) + +{{ end }} // Here you will define your flags and configuration settings. + // Cobra supports Persistent Flags, which, if defined here, + // will be global for your application. +{{ if .viper }} + RootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)") +{{ else }} + // RootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.{{ .appName }}.yaml)") +{{ end }} // Cobra also supports local flags, which will only run + // when this action is called directly. + RootCmd.Flags().BoolP("toggle", "t", false, "Help message for toggle") +} +{{ if .viper }} +// initConfig reads in config file and ENV variables if set. +func initConfig() { + if cfgFile != "" { // enable ability to specify config file via flag + viper.SetConfigFile(cfgFile) + } + + viper.SetConfigName(".{{ .appName }}") // name of config file (without extension) + viper.AddConfigPath("$HOME") // adding home directory as first search path + viper.AutomaticEnv() // read in environment variables that match + + // If a config file is found, read it in. + if err := viper.ReadInConfig(); err == nil { + fmt.Println("Using config file:", viper.ConfigFileUsed()) + } +} +{{ end }}` + + var data map[string]interface{} + data = make(map[string]interface{}) + + data["copyright"] = copyrightLine() + data["license"] = lic.Header + data["appName"] = projectName() + data["viper"] = viper.GetBool("useViper") + + err := writeTemplateToFile(ProjectPath()+string(os.PathSeparator)+guessCmdDir(), "root.go", template, data) + if err != nil { + er(err) + } + + fmt.Println("Your Cobra application is ready at") + fmt.Println(ProjectPath()) + fmt.Println("Give it a try by going there and running `go run main.go`") + fmt.Println("Add commands to it by running `cobra add [cmdname]`") +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/licenses.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/licenses.go new file mode 100644 index 00000000..5ad9c96e --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/licenses.go @@ -0,0 +1,1133 @@ +// Copyright © 2015 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +// Parts inspired by https://github.com/ryanuber/go-license + +package cmd + +import "strings" + +//Licenses contains all possible licenses a user can chose from +var Licenses map[string]License + +//License represents a software license agreement, containing the Name of +// the license, its possible matches (on the command line as given to cobra) +// the header to be used with each file on the file's creating, and the text +// of the license +type License struct { + Name string // The type of license in use + PossibleMatches []string // Similar names to guess + Text string // License text data + Header string // License header for source files +} + +// given a license name (in), try to match the license indicated +func matchLicense(in string) string { + for key, lic := range Licenses { + for _, match := range lic.PossibleMatches { + if strings.EqualFold(in, match) { + return key + } + } + } + return "" +} + +func init() { + Licenses = make(map[string]License) + + Licenses["apache"] = License{ + Name: "Apache 2.0", + PossibleMatches: []string{"apache", "apache20", "apache 2.0", "apache2.0", "apache-2.0"}, + Header: ` +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License.`, + Text: ` + Apache License + Version 2.0, January 2004 + http://www.apache.org/licenses/ + + TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION + + 1. Definitions. + + "License" shall mean the terms and conditions for use, reproduction, + and distribution as defined by Sections 1 through 9 of this document. + + "Licensor" shall mean the copyright owner or entity authorized by + the copyright owner that is granting the License. + + "Legal Entity" shall mean the union of the acting entity and all + other entities that control, are controlled by, or are under common + control with that entity. For the purposes of this definition, + "control" means (i) the power, direct or indirect, to cause the + direction or management of such entity, whether by contract or + otherwise, or (ii) ownership of fifty percent (50%) or more of the + outstanding shares, or (iii) beneficial ownership of such entity. + + "You" (or "Your") shall mean an individual or Legal Entity + exercising permissions granted by this License. + + "Source" form shall mean the preferred form for making modifications, + including but not limited to software source code, documentation + source, and configuration files. + + "Object" form shall mean any form resulting from mechanical + transformation or translation of a Source form, including but + not limited to compiled object code, generated documentation, + and conversions to other media types. + + "Work" shall mean the work of authorship, whether in Source or + Object form, made available under the License, as indicated by a + copyright notice that is included in or attached to the work + (an example is provided in the Appendix below). + + "Derivative Works" shall mean any work, whether in Source or Object + form, that is based on (or derived from) the Work and for which the + editorial revisions, annotations, elaborations, or other modifications + represent, as a whole, an original work of authorship. For the purposes + of this License, Derivative Works shall not include works that remain + separable from, or merely link (or bind by name) to the interfaces of, + the Work and Derivative Works thereof. + + "Contribution" shall mean any work of authorship, including + the original version of the Work and any modifications or additions + to that Work or Derivative Works thereof, that is intentionally + submitted to Licensor for inclusion in the Work by the copyright owner + or by an individual or Legal Entity authorized to submit on behalf of + the copyright owner. For the purposes of this definition, "submitted" + means any form of electronic, verbal, or written communication sent + to the Licensor or its representatives, including but not limited to + communication on electronic mailing lists, source code control systems, + and issue tracking systems that are managed by, or on behalf of, the + Licensor for the purpose of discussing and improving the Work, but + excluding communication that is conspicuously marked or otherwise + designated in writing by the copyright owner as "Not a Contribution." + + "Contributor" shall mean Licensor and any individual or Legal Entity + on behalf of whom a Contribution has been received by Licensor and + subsequently incorporated within the Work. + + 2. Grant of Copyright License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + copyright license to reproduce, prepare Derivative Works of, + publicly display, publicly perform, sublicense, and distribute the + Work and such Derivative Works in Source or Object form. + + 3. Grant of Patent License. Subject to the terms and conditions of + this License, each Contributor hereby grants to You a perpetual, + worldwide, non-exclusive, no-charge, royalty-free, irrevocable + (except as stated in this section) patent license to make, have made, + use, offer to sell, sell, import, and otherwise transfer the Work, + where such license applies only to those patent claims licensable + by such Contributor that are necessarily infringed by their + Contribution(s) alone or by combination of their Contribution(s) + with the Work to which such Contribution(s) was submitted. If You + institute patent litigation against any entity (including a + cross-claim or counterclaim in a lawsuit) alleging that the Work + or a Contribution incorporated within the Work constitutes direct + or contributory patent infringement, then any patent licenses + granted to You under this License for that Work shall terminate + as of the date such litigation is filed. + + 4. Redistribution. You may reproduce and distribute copies of the + Work or Derivative Works thereof in any medium, with or without + modifications, and in Source or Object form, provided that You + meet the following conditions: + + (a) You must give any other recipients of the Work or + Derivative Works a copy of this License; and + + (b) You must cause any modified files to carry prominent notices + stating that You changed the files; and + + (c) You must retain, in the Source form of any Derivative Works + that You distribute, all copyright, patent, trademark, and + attribution notices from the Source form of the Work, + excluding those notices that do not pertain to any part of + the Derivative Works; and + + (d) If the Work includes a "NOTICE" text file as part of its + distribution, then any Derivative Works that You distribute must + include a readable copy of the attribution notices contained + within such NOTICE file, excluding those notices that do not + pertain to any part of the Derivative Works, in at least one + of the following places: within a NOTICE text file distributed + as part of the Derivative Works; within the Source form or + documentation, if provided along with the Derivative Works; or, + within a display generated by the Derivative Works, if and + wherever such third-party notices normally appear. The contents + of the NOTICE file are for informational purposes only and + do not modify the License. You may add Your own attribution + notices within Derivative Works that You distribute, alongside + or as an addendum to the NOTICE text from the Work, provided + that such additional attribution notices cannot be construed + as modifying the License. + + You may add Your own copyright statement to Your modifications and + may provide additional or different license terms and conditions + for use, reproduction, or distribution of Your modifications, or + for any such Derivative Works as a whole, provided Your use, + reproduction, and distribution of the Work otherwise complies with + the conditions stated in this License. + + 5. Submission of Contributions. Unless You explicitly state otherwise, + any Contribution intentionally submitted for inclusion in the Work + by You to the Licensor shall be under the terms and conditions of + this License, without any additional terms or conditions. + Notwithstanding the above, nothing herein shall supersede or modify + the terms of any separate license agreement you may have executed + with Licensor regarding such Contributions. + + 6. Trademarks. This License does not grant permission to use the trade + names, trademarks, service marks, or product names of the Licensor, + except as required for reasonable and customary use in describing the + origin of the Work and reproducing the content of the NOTICE file. + + 7. Disclaimer of Warranty. Unless required by applicable law or + agreed to in writing, Licensor provides the Work (and each + Contributor provides its Contributions) on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or + implied, including, without limitation, any warranties or conditions + of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A + PARTICULAR PURPOSE. You are solely responsible for determining the + appropriateness of using or redistributing the Work and assume any + risks associated with Your exercise of permissions under this License. + + 8. Limitation of Liability. In no event and under no legal theory, + whether in tort (including negligence), contract, or otherwise, + unless required by applicable law (such as deliberate and grossly + negligent acts) or agreed to in writing, shall any Contributor be + liable to You for damages, including any direct, indirect, special, + incidental, or consequential damages of any character arising as a + result of this License or out of the use or inability to use the + Work (including but not limited to damages for loss of goodwill, + work stoppage, computer failure or malfunction, or any and all + other commercial damages or losses), even if such Contributor + has been advised of the possibility of such damages. + + 9. Accepting Warranty or Additional Liability. While redistributing + the Work or Derivative Works thereof, You may choose to offer, + and charge a fee for, acceptance of support, warranty, indemnity, + or other liability obligations and/or rights consistent with this + License. However, in accepting such obligations, You may act only + on Your own behalf and on Your sole responsibility, not on behalf + of any other Contributor, and only if You agree to indemnify, + defend, and hold each Contributor harmless for any liability + incurred by, or claims asserted against, such Contributor by reason + of your accepting any such warranty or additional liability. + + END OF TERMS AND CONDITIONS + + APPENDIX: How to apply the Apache License to your work. + + To apply the Apache License to your work, attach the following + boilerplate notice, with the fields enclosed by brackets "[]" + replaced with your own identifying information. (Don't include + the brackets!) The text should be enclosed in the appropriate + comment syntax for the file format. We also recommend that a + file or class name and description of purpose be included on the + same "printed page" as the copyright notice for easier + identification within third-party archives. + + Copyright [yyyy] [name of copyright owner] + + Licensed under the Apache License, Version 2.0 (the "License"); + you may not use this file except in compliance with the License. + You may obtain a copy of the License at + + http://www.apache.org/licenses/LICENSE-2.0 + + Unless required by applicable law or agreed to in writing, software + distributed under the License is distributed on an "AS IS" BASIS, + WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + See the License for the specific language governing permissions and + limitations under the License. +`, + } + + Licenses["mit"] = License{ + Name: "Mit", + PossibleMatches: []string{"mit"}, + Header: ` +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE.`, + Text: `The MIT License (MIT) + +{{ .copyright }} + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in +all copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN +THE SOFTWARE. +`, + } + + Licenses["bsd"] = License{ + Name: "NewBSD", + PossibleMatches: []string{"bsd", "newbsd", "3 clause bsd"}, + Header: ` +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE.`, + Text: `{{ .copyright }} +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +3. Neither the name of the copyright holder nor the names of its contributors + may be used to endorse or promote products derived from this software + without specific prior written permission. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +`, + } + + Licenses["freebsd"] = License{ + Name: "Simplified BSD License", + PossibleMatches: []string{"freebsd", "simpbsd", "simple bsd", "2 clause bsd"}, + Header: ` +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE +ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE +LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR +CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF +SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS +INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN +CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) +ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE +POSSIBILITY OF SUCH DAMAGE.`, + Text: `{{ .copyright }} +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, + this list of conditions and the following disclaimer. + +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +`, + } + + Licenses["gpl3"] = License{ + Name: "GNU General Public License 3.0", + PossibleMatches: []string{"gpl3", "gpl", "gnu gpl3", "gnu gpl"}, + Header: `{{ .copyright }} + + This file is part of {{ .appName }}. + + {{ .appName }} is free software: you can redistribute it and/or modify + it under the terms of the GNU Lesser General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + {{ .appName }} is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU Lesser General Public License for more details. + + You should have received a copy of the GNU Lesser General Public License + along with {{ .appName }}. If not, see . + `, + Text: ` GNU GENERAL PUBLIC LICENSE + Version 3, 29 June 2007 + + Copyright (C) 2007 Free Software Foundation, Inc. + Everyone is permitted to copy and distribute verbatim copies + of this license document, but changing it is not allowed. + + Preamble + + The GNU General Public License is a free, copyleft license for +software and other kinds of works. + + The licenses for most software and other practical works are designed +to take away your freedom to share and change the works. By contrast, +the GNU General Public License is intended to guarantee your freedom to +share and change all versions of a program--to make sure it remains free +software for all its users. We, the Free Software Foundation, use the +GNU General Public License for most of our software; it applies also to +any other work released this way by its authors. You can apply it to +your programs, too. + + When we speak of free software, we are referring to freedom, not +price. Our General Public Licenses are designed to make sure that you +have the freedom to distribute copies of free software (and charge for +them if you wish), that you receive source code or can get it if you +want it, that you can change the software or use pieces of it in new +free programs, and that you know you can do these things. + + To protect your rights, we need to prevent others from denying you +these rights or asking you to surrender the rights. Therefore, you have +certain responsibilities if you distribute copies of the software, or if +you modify it: responsibilities to respect the freedom of others. + + For example, if you distribute copies of such a program, whether +gratis or for a fee, you must pass on to the recipients the same +freedoms that you received. You must make sure that they, too, receive +or can get the source code. And you must show them these terms so they +know their rights. + + Developers that use the GNU GPL protect your rights with two steps: +(1) assert copyright on the software, and (2) offer you this License +giving you legal permission to copy, distribute and/or modify it. + + For the developers' and authors' protection, the GPL clearly explains +that there is no warranty for this free software. For both users' and +authors' sake, the GPL requires that modified versions be marked as +changed, so that their problems will not be attributed erroneously to +authors of previous versions. + + Some devices are designed to deny users access to install or run +modified versions of the software inside them, although the manufacturer +can do so. This is fundamentally incompatible with the aim of +protecting users' freedom to change the software. The systematic +pattern of such abuse occurs in the area of products for individuals to +use, which is precisely where it is most unacceptable. Therefore, we +have designed this version of the GPL to prohibit the practice for those +products. If such problems arise substantially in other domains, we +stand ready to extend this provision to those domains in future versions +of the GPL, as needed to protect the freedom of users. + + Finally, every program is threatened constantly by software patents. +States should not allow patents to restrict development and use of +software on general-purpose computers, but in those that do, we wish to +avoid the special danger that patents applied to a free program could +make it effectively proprietary. To prevent this, the GPL assures that +patents cannot be used to render the program non-free. + + The precise terms and conditions for copying, distribution and +modification follow. + + TERMS AND CONDITIONS + + 0. Definitions. + + "This License" refers to version 3 of the GNU General Public License. + + "Copyright" also means copyright-like laws that apply to other kinds of +works, such as semiconductor masks. + + "The Program" refers to any copyrightable work licensed under this +License. Each licensee is addressed as "you". "Licensees" and +"recipients" may be individuals or organizations. + + To "modify" a work means to copy from or adapt all or part of the work +in a fashion requiring copyright permission, other than the making of an +exact copy. The resulting work is called a "modified version" of the +earlier work or a work "based on" the earlier work. + + A "covered work" means either the unmodified Program or a work based +on the Program. + + To "propagate" a work means to do anything with it that, without +permission, would make you directly or secondarily liable for +infringement under applicable copyright law, except executing it on a +computer or modifying a private copy. Propagation includes copying, +distribution (with or without modification), making available to the +public, and in some countries other activities as well. + + To "convey" a work means any kind of propagation that enables other +parties to make or receive copies. Mere interaction with a user through +a computer network, with no transfer of a copy, is not conveying. + + An interactive user interface displays "Appropriate Legal Notices" +to the extent that it includes a convenient and prominently visible +feature that (1) displays an appropriate copyright notice, and (2) +tells the user that there is no warranty for the work (except to the +extent that warranties are provided), that licensees may convey the +work under this License, and how to view a copy of this License. If +the interface presents a list of user commands or options, such as a +menu, a prominent item in the list meets this criterion. + + 1. Source Code. + + The "source code" for a work means the preferred form of the work +for making modifications to it. "Object code" means any non-source +form of a work. + + A "Standard Interface" means an interface that either is an official +standard defined by a recognized standards body, or, in the case of +interfaces specified for a particular programming language, one that +is widely used among developers working in that language. + + The "System Libraries" of an executable work include anything, other +than the work as a whole, that (a) is included in the normal form of +packaging a Major Component, but which is not part of that Major +Component, and (b) serves only to enable use of the work with that +Major Component, or to implement a Standard Interface for which an +implementation is available to the public in source code form. A +"Major Component", in this context, means a major essential component +(kernel, window system, and so on) of the specific operating system +(if any) on which the executable work runs, or a compiler used to +produce the work, or an object code interpreter used to run it. + + The "Corresponding Source" for a work in object code form means all +the source code needed to generate, install, and (for an executable +work) run the object code and to modify the work, including scripts to +control those activities. However, it does not include the work's +System Libraries, or general-purpose tools or generally available free +programs which are used unmodified in performing those activities but +which are not part of the work. For example, Corresponding Source +includes interface definition files associated with source files for +the work, and the source code for shared libraries and dynamically +linked subprograms that the work is specifically designed to require, +such as by intimate data communication or control flow between those +subprograms and other parts of the work. + + The Corresponding Source need not include anything that users +can regenerate automatically from other parts of the Corresponding +Source. + + The Corresponding Source for a work in source code form is that +same work. + + 2. Basic Permissions. + + All rights granted under this License are granted for the term of +copyright on the Program, and are irrevocable provided the stated +conditions are met. This License explicitly affirms your unlimited +permission to run the unmodified Program. The output from running a +covered work is covered by this License only if the output, given its +content, constitutes a covered work. This License acknowledges your +rights of fair use or other equivalent, as provided by copyright law. + + You may make, run and propagate covered works that you do not +convey, without conditions so long as your license otherwise remains +in force. You may convey covered works to others for the sole purpose +of having them make modifications exclusively for you, or provide you +with facilities for running those works, provided that you comply with +the terms of this License in conveying all material for which you do +not control copyright. Those thus making or running the covered works +for you must do so exclusively on your behalf, under your direction +and control, on terms that prohibit them from making any copies of +your copyrighted material outside their relationship with you. + + Conveying under any other circumstances is permitted solely under +the conditions stated below. Sublicensing is not allowed; section 10 +makes it unnecessary. + + 3. Protecting Users' Legal Rights From Anti-Circumvention Law. + + No covered work shall be deemed part of an effective technological +measure under any applicable law fulfilling obligations under article +11 of the WIPO copyright treaty adopted on 20 December 1996, or +similar laws prohibiting or restricting circumvention of such +measures. + + When you convey a covered work, you waive any legal power to forbid +circumvention of technological measures to the extent such circumvention +is effected by exercising rights under this License with respect to +the covered work, and you disclaim any intention to limit operation or +modification of the work as a means of enforcing, against the work's +users, your or third parties' legal rights to forbid circumvention of +technological measures. + + 4. Conveying Verbatim Copies. + + You may convey verbatim copies of the Program's source code as you +receive it, in any medium, provided that you conspicuously and +appropriately publish on each copy an appropriate copyright notice; +keep intact all notices stating that this License and any +non-permissive terms added in accord with section 7 apply to the code; +keep intact all notices of the absence of any warranty; and give all +recipients a copy of this License along with the Program. + + You may charge any price or no price for each copy that you convey, +and you may offer support or warranty protection for a fee. + + 5. Conveying Modified Source Versions. + + You may convey a work based on the Program, or the modifications to +produce it from the Program, in the form of source code under the +terms of section 4, provided that you also meet all of these conditions: + + a) The work must carry prominent notices stating that you modified + it, and giving a relevant date. + + b) The work must carry prominent notices stating that it is + released under this License and any conditions added under section + 7. This requirement modifies the requirement in section 4 to + "keep intact all notices". + + c) You must license the entire work, as a whole, under this + License to anyone who comes into possession of a copy. This + License will therefore apply, along with any applicable section 7 + additional terms, to the whole of the work, and all its parts, + regardless of how they are packaged. This License gives no + permission to license the work in any other way, but it does not + invalidate such permission if you have separately received it. + + d) If the work has interactive user interfaces, each must display + Appropriate Legal Notices; however, if the Program has interactive + interfaces that do not display Appropriate Legal Notices, your + work need not make them do so. + + A compilation of a covered work with other separate and independent +works, which are not by their nature extensions of the covered work, +and which are not combined with it such as to form a larger program, +in or on a volume of a storage or distribution medium, is called an +"aggregate" if the compilation and its resulting copyright are not +used to limit the access or legal rights of the compilation's users +beyond what the individual works permit. Inclusion of a covered work +in an aggregate does not cause this License to apply to the other +parts of the aggregate. + + 6. Conveying Non-Source Forms. + + You may convey a covered work in object code form under the terms +of sections 4 and 5, provided that you also convey the +machine-readable Corresponding Source under the terms of this License, +in one of these ways: + + a) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by the + Corresponding Source fixed on a durable physical medium + customarily used for software interchange. + + b) Convey the object code in, or embodied in, a physical product + (including a physical distribution medium), accompanied by a + written offer, valid for at least three years and valid for as + long as you offer spare parts or customer support for that product + model, to give anyone who possesses the object code either (1) a + copy of the Corresponding Source for all the software in the + product that is covered by this License, on a durable physical + medium customarily used for software interchange, for a price no + more than your reasonable cost of physically performing this + conveying of source, or (2) access to copy the + Corresponding Source from a network server at no charge. + + c) Convey individual copies of the object code with a copy of the + written offer to provide the Corresponding Source. This + alternative is allowed only occasionally and noncommercially, and + only if you received the object code with such an offer, in accord + with subsection 6b. + + d) Convey the object code by offering access from a designated + place (gratis or for a charge), and offer equivalent access to the + Corresponding Source in the same way through the same place at no + further charge. You need not require recipients to copy the + Corresponding Source along with the object code. If the place to + copy the object code is a network server, the Corresponding Source + may be on a different server (operated by you or a third party) + that supports equivalent copying facilities, provided you maintain + clear directions next to the object code saying where to find the + Corresponding Source. Regardless of what server hosts the + Corresponding Source, you remain obligated to ensure that it is + available for as long as needed to satisfy these requirements. + + e) Convey the object code using peer-to-peer transmission, provided + you inform other peers where the object code and Corresponding + Source of the work are being offered to the general public at no + charge under subsection 6d. + + A separable portion of the object code, whose source code is excluded +from the Corresponding Source as a System Library, need not be +included in conveying the object code work. + + A "User Product" is either (1) a "consumer product", which means any +tangible personal property which is normally used for personal, family, +or household purposes, or (2) anything designed or sold for incorporation +into a dwelling. In determining whether a product is a consumer product, +doubtful cases shall be resolved in favor of coverage. For a particular +product received by a particular user, "normally used" refers to a +typical or common use of that class of product, regardless of the status +of the particular user or of the way in which the particular user +actually uses, or expects or is expected to use, the product. A product +is a consumer product regardless of whether the product has substantial +commercial, industrial or non-consumer uses, unless such uses represent +the only significant mode of use of the product. + + "Installation Information" for a User Product means any methods, +procedures, authorization keys, or other information required to install +and execute modified versions of a covered work in that User Product from +a modified version of its Corresponding Source. The information must +suffice to ensure that the continued functioning of the modified object +code is in no case prevented or interfered with solely because +modification has been made. + + If you convey an object code work under this section in, or with, or +specifically for use in, a User Product, and the conveying occurs as +part of a transaction in which the right of possession and use of the +User Product is transferred to the recipient in perpetuity or for a +fixed term (regardless of how the transaction is characterized), the +Corresponding Source conveyed under this section must be accompanied +by the Installation Information. But this requirement does not apply +if neither you nor any third party retains the ability to install +modified object code on the User Product (for example, the work has +been installed in ROM). + + The requirement to provide Installation Information does not include a +requirement to continue to provide support service, warranty, or updates +for a work that has been modified or installed by the recipient, or for +the User Product in which it has been modified or installed. Access to a +network may be denied when the modification itself materially and +adversely affects the operation of the network or violates the rules and +protocols for communication across the network. + + Corresponding Source conveyed, and Installation Information provided, +in accord with this section must be in a format that is publicly +documented (and with an implementation available to the public in +source code form), and must require no special password or key for +unpacking, reading or copying. + + 7. Additional Terms. + + "Additional permissions" are terms that supplement the terms of this +License by making exceptions from one or more of its conditions. +Additional permissions that are applicable to the entire Program shall +be treated as though they were included in this License, to the extent +that they are valid under applicable law. If additional permissions +apply only to part of the Program, that part may be used separately +under those permissions, but the entire Program remains governed by +this License without regard to the additional permissions. + + When you convey a copy of a covered work, you may at your option +remove any additional permissions from that copy, or from any part of +it. (Additional permissions may be written to require their own +removal in certain cases when you modify the work.) You may place +additional permissions on material, added by you to a covered work, +for which you have or can give appropriate copyright permission. + + Notwithstanding any other provision of this License, for material you +add to a covered work, you may (if authorized by the copyright holders of +that material) supplement the terms of this License with terms: + + a) Disclaiming warranty or limiting liability differently from the + terms of sections 15 and 16 of this License; or + + b) Requiring preservation of specified reasonable legal notices or + author attributions in that material or in the Appropriate Legal + Notices displayed by works containing it; or + + c) Prohibiting misrepresentation of the origin of that material, or + requiring that modified versions of such material be marked in + reasonable ways as different from the original version; or + + d) Limiting the use for publicity purposes of names of licensors or + authors of the material; or + + e) Declining to grant rights under trademark law for use of some + trade names, trademarks, or service marks; or + + f) Requiring indemnification of licensors and authors of that + material by anyone who conveys the material (or modified versions of + it) with contractual assumptions of liability to the recipient, for + any liability that these contractual assumptions directly impose on + those licensors and authors. + + All other non-permissive additional terms are considered "further +restrictions" within the meaning of section 10. If the Program as you +received it, or any part of it, contains a notice stating that it is +governed by this License along with a term that is a further +restriction, you may remove that term. If a license document contains +a further restriction but permits relicensing or conveying under this +License, you may add to a covered work material governed by the terms +of that license document, provided that the further restriction does +not survive such relicensing or conveying. + + If you add terms to a covered work in accord with this section, you +must place, in the relevant source files, a statement of the +additional terms that apply to those files, or a notice indicating +where to find the applicable terms. + + Additional terms, permissive or non-permissive, may be stated in the +form of a separately written license, or stated as exceptions; +the above requirements apply either way. + + 8. Termination. + + You may not propagate or modify a covered work except as expressly +provided under this License. Any attempt otherwise to propagate or +modify it is void, and will automatically terminate your rights under +this License (including any patent licenses granted under the third +paragraph of section 11). + + However, if you cease all violation of this License, then your +license from a particular copyright holder is reinstated (a) +provisionally, unless and until the copyright holder explicitly and +finally terminates your license, and (b) permanently, if the copyright +holder fails to notify you of the violation by some reasonable means +prior to 60 days after the cessation. + + Moreover, your license from a particular copyright holder is +reinstated permanently if the copyright holder notifies you of the +violation by some reasonable means, this is the first time you have +received notice of violation of this License (for any work) from that +copyright holder, and you cure the violation prior to 30 days after +your receipt of the notice. + + Termination of your rights under this section does not terminate the +licenses of parties who have received copies or rights from you under +this License. If your rights have been terminated and not permanently +reinstated, you do not qualify to receive new licenses for the same +material under section 10. + + 9. Acceptance Not Required for Having Copies. + + You are not required to accept this License in order to receive or +run a copy of the Program. Ancillary propagation of a covered work +occurring solely as a consequence of using peer-to-peer transmission +to receive a copy likewise does not require acceptance. However, +nothing other than this License grants you permission to propagate or +modify any covered work. These actions infringe copyright if you do +not accept this License. Therefore, by modifying or propagating a +covered work, you indicate your acceptance of this License to do so. + + 10. Automatic Licensing of Downstream Recipients. + + Each time you convey a covered work, the recipient automatically +receives a license from the original licensors, to run, modify and +propagate that work, subject to this License. You are not responsible +for enforcing compliance by third parties with this License. + + An "entity transaction" is a transaction transferring control of an +organization, or substantially all assets of one, or subdividing an +organization, or merging organizations. If propagation of a covered +work results from an entity transaction, each party to that +transaction who receives a copy of the work also receives whatever +licenses to the work the party's predecessor in interest had or could +give under the previous paragraph, plus a right to possession of the +Corresponding Source of the work from the predecessor in interest, if +the predecessor has it or can get it with reasonable efforts. + + You may not impose any further restrictions on the exercise of the +rights granted or affirmed under this License. For example, you may +not impose a license fee, royalty, or other charge for exercise of +rights granted under this License, and you may not initiate litigation +(including a cross-claim or counterclaim in a lawsuit) alleging that +any patent claim is infringed by making, using, selling, offering for +sale, or importing the Program or any portion of it. + + 11. Patents. + + A "contributor" is a copyright holder who authorizes use under this +License of the Program or a work on which the Program is based. The +work thus licensed is called the contributor's "contributor version". + + A contributor's "essential patent claims" are all patent claims +owned or controlled by the contributor, whether already acquired or +hereafter acquired, that would be infringed by some manner, permitted +by this License, of making, using, or selling its contributor version, +but do not include claims that would be infringed only as a +consequence of further modification of the contributor version. For +purposes of this definition, "control" includes the right to grant +patent sublicenses in a manner consistent with the requirements of +this License. + + Each contributor grants you a non-exclusive, worldwide, royalty-free +patent license under the contributor's essential patent claims, to +make, use, sell, offer for sale, import and otherwise run, modify and +propagate the contents of its contributor version. + + In the following three paragraphs, a "patent license" is any express +agreement or commitment, however denominated, not to enforce a patent +(such as an express permission to practice a patent or covenant not to +sue for patent infringement). To "grant" such a patent license to a +party means to make such an agreement or commitment not to enforce a +patent against the party. + + If you convey a covered work, knowingly relying on a patent license, +and the Corresponding Source of the work is not available for anyone +to copy, free of charge and under the terms of this License, through a +publicly available network server or other readily accessible means, +then you must either (1) cause the Corresponding Source to be so +available, or (2) arrange to deprive yourself of the benefit of the +patent license for this particular work, or (3) arrange, in a manner +consistent with the requirements of this License, to extend the patent +license to downstream recipients. "Knowingly relying" means you have +actual knowledge that, but for the patent license, your conveying the +covered work in a country, or your recipient's use of the covered work +in a country, would infringe one or more identifiable patents in that +country that you have reason to believe are valid. + + If, pursuant to or in connection with a single transaction or +arrangement, you convey, or propagate by procuring conveyance of, a +covered work, and grant a patent license to some of the parties +receiving the covered work authorizing them to use, propagate, modify +or convey a specific copy of the covered work, then the patent license +you grant is automatically extended to all recipients of the covered +work and works based on it. + + A patent license is "discriminatory" if it does not include within +the scope of its coverage, prohibits the exercise of, or is +conditioned on the non-exercise of one or more of the rights that are +specifically granted under this License. You may not convey a covered +work if you are a party to an arrangement with a third party that is +in the business of distributing software, under which you make payment +to the third party based on the extent of your activity of conveying +the work, and under which the third party grants, to any of the +parties who would receive the covered work from you, a discriminatory +patent license (a) in connection with copies of the covered work +conveyed by you (or copies made from those copies), or (b) primarily +for and in connection with specific products or compilations that +contain the covered work, unless you entered into that arrangement, +or that patent license was granted, prior to 28 March 2007. + + Nothing in this License shall be construed as excluding or limiting +any implied license or other defenses to infringement that may +otherwise be available to you under applicable patent law. + + 12. No Surrender of Others' Freedom. + + If conditions are imposed on you (whether by court order, agreement or +otherwise) that contradict the conditions of this License, they do not +excuse you from the conditions of this License. If you cannot convey a +covered work so as to satisfy simultaneously your obligations under this +License and any other pertinent obligations, then as a consequence you may +not convey it at all. For example, if you agree to terms that obligate you +to collect a royalty for further conveying from those to whom you convey +the Program, the only way you could satisfy both those terms and this +License would be to refrain entirely from conveying the Program. + + 13. Use with the GNU Affero General Public License. + + Notwithstanding any other provision of this License, you have +permission to link or combine any covered work with a work licensed +under version 3 of the GNU Affero General Public License into a single +combined work, and to convey the resulting work. The terms of this +License will continue to apply to the part which is the covered work, +but the special requirements of the GNU Affero General Public License, +section 13, concerning interaction through a network will apply to the +combination as such. + + 14. Revised Versions of this License. + + The Free Software Foundation may publish revised and/or new versions of +the GNU General Public License from time to time. Such new versions will +be similar in spirit to the present version, but may differ in detail to +address new problems or concerns. + + Each version is given a distinguishing version number. If the +Program specifies that a certain numbered version of the GNU General +Public License "or any later version" applies to it, you have the +option of following the terms and conditions either of that numbered +version or of any later version published by the Free Software +Foundation. If the Program does not specify a version number of the +GNU General Public License, you may choose any version ever published +by the Free Software Foundation. + + If the Program specifies that a proxy can decide which future +versions of the GNU General Public License can be used, that proxy's +public statement of acceptance of a version permanently authorizes you +to choose that version for the Program. + + Later license versions may give you additional or different +permissions. However, no additional obligations are imposed on any +author or copyright holder as a result of your choosing to follow a +later version. + + 15. Disclaimer of Warranty. + + THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY +APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT +HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY +OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO, +THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR +PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM +IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF +ALL NECESSARY SERVICING, REPAIR OR CORRECTION. + + 16. Limitation of Liability. + + IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING +WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS +THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY +GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE +USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF +DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD +PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS), +EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF +SUCH DAMAGES. + + 17. Interpretation of Sections 15 and 16. + + If the disclaimer of warranty and limitation of liability provided +above cannot be given local legal effect according to their terms, +reviewing courts shall apply local law that most closely approximates +an absolute waiver of all civil liability in connection with the +Program, unless a warranty or assumption of liability accompanies a +copy of the Program in return for a fee. + + END OF TERMS AND CONDITIONS + + How to Apply These Terms to Your New Programs + + If you develop a new program, and you want it to be of the greatest +possible use to the public, the best way to achieve this is to make it +free software which everyone can redistribute and change under these terms. + + To do so, attach the following notices to the program. It is safest +to attach them to the start of each source file to most effectively +state the exclusion of warranty; and each file should have at least +the "copyright" line and a pointer to where the full notice is found. + + + Copyright (C) + + This program is free software: you can redistribute it and/or modify + it under the terms of the GNU General Public License as published by + the Free Software Foundation, either version 3 of the License, or + (at your option) any later version. + + This program is distributed in the hope that it will be useful, + but WITHOUT ANY WARRANTY; without even the implied warranty of + MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + GNU General Public License for more details. + + You should have received a copy of the GNU General Public License + along with this program. If not, see . + +Also add information on how to contact you by electronic and paper mail. + + If the program does terminal interaction, make it output a short +notice like this when it starts in an interactive mode: + + Copyright (C) + This program comes with ABSOLUTELY NO WARRANTY; for details type 'show w'. + This is free software, and you are welcome to redistribute it + under certain conditions; type 'show c' for details. + +The hypothetical commands 'show w' and 'show c' should show the appropriate +parts of the General Public License. Of course, your program's commands +might be different; for a GUI interface, you would use an "about box". + + You should also get your employer (if you work as a programmer) or school, +if any, to sign a "copyright disclaimer" for the program, if necessary. +For more information on this, and how to apply and follow the GNU GPL, see +. + + The GNU General Public License does not permit incorporating your program +into proprietary programs. If your program is a subroutine library, you +may consider it more useful to permit linking proprietary applications with +the library. If this is what you want to do, use the GNU Lesser General +Public License instead of this License. But first, please read +. +`, + } + + // Licenses["apache20"] = License{ + // Name: "Apache 2.0", + // PossibleMatches: []string{"apache", "apache20", ""}, + // Header: ` + // `, + // Text: ` + // `, + // } +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/root.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/root.go new file mode 100644 index 00000000..07687940 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd/root.go @@ -0,0 +1,84 @@ +// Copyright © 2015 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package cmd + +import ( + "fmt" + "os" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/viper" +) + +var cfgFile string +var userLicense string + +// This represents the base command when called without any subcommands +var RootCmd = &cobra.Command{ + Use: "cobra", + Short: "A generator for Cobra based Applications", + Long: `Cobra is a CLI library for Go that empowers applications. +This application is a tool to generate the needed files +to quickly create a Cobra application.`, +} + +//Execute adds all child commands to the root command sets flags appropriately. +func Execute() { + if err := RootCmd.Execute(); err != nil { + fmt.Println(err) + os.Exit(-1) + } +} + +func init() { + cobra.OnInitialize(initConfig) + RootCmd.PersistentFlags().StringVar(&cfgFile, "config", "", "config file (default is $HOME/.cobra.yaml)") + RootCmd.PersistentFlags().StringVarP(&projectBase, "projectbase", "b", "", "base project directory, e.g. github.com/spf13/") + RootCmd.PersistentFlags().StringP("author", "a", "YOUR NAME", "Author name for copyright attribution") + RootCmd.PersistentFlags().StringVarP(&userLicense, "license", "l", "", "Name of license for the project (can provide `licensetext` in config)") + RootCmd.PersistentFlags().Bool("viper", true, "Use Viper for configuration") + viper.BindPFlag("author", RootCmd.PersistentFlags().Lookup("author")) + viper.BindPFlag("projectbase", RootCmd.PersistentFlags().Lookup("projectbase")) + viper.BindPFlag("useViper", RootCmd.PersistentFlags().Lookup("viper")) + viper.SetDefault("author", "NAME HERE ") + viper.SetDefault("license", "apache") + viper.SetDefault("licenseText", ` +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. +`) +} + +// Read in config file and ENV variables if set. +func initConfig() { + if cfgFile != "" { // enable ability to specify config file via flag + viper.SetConfigFile(cfgFile) + } + + viper.SetConfigName(".cobra") // name of config file (without extension) + viper.AddConfigPath("$HOME") // adding home directory as first search path + viper.AutomaticEnv() // read in environment variables that match + + // If a config file is found, read it in. + if err := viper.ReadInConfig(); err == nil { + fmt.Println("Using config file:", viper.ConfigFileUsed()) + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra/main.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/main.go new file mode 100644 index 00000000..4c455b27 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra/main.go @@ -0,0 +1,20 @@ +// Copyright © 2015 Steve Francia . +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package main + +import "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra/cobra/cmd" + +func main() { + cmd.Execute() +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/cobra_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/cobra_test.go index 080b8ddd..c4bcbd8c 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/cobra_test.go +++ b/Godeps/_workspace/src/github.com/spf13/cobra/cobra_test.go @@ -4,18 +4,22 @@ import ( "bytes" "fmt" "os" + "reflect" "runtime" "strings" "testing" + "text/template" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" ) var _ = fmt.Println var _ = os.Stderr -var tp, te, tt, t1 []string +var tp, te, tt, t1, tr []string var rootPersPre, echoPre, echoPersPre, timesPersPre []string var flagb1, flagb2, flagb3, flagbr, flagbp bool -var flags1, flags2a, flags2b, flags3 string +var flags1, flags2a, flags2b, flags3, outs string var flagi1, flagi2, flagi3, flagir int var globalFlag1 bool var flagEcho, rootcalled bool @@ -24,6 +28,16 @@ var versionUsed int const strtwoParentHelp = "help message for parent flag strtwo" const strtwoChildHelp = "help message for child flag strtwo" +var cmdHidden = &Command{ + Use: "hide [secret string to print]", + Short: "Print anything to screen (if command is known)", + Long: `an absolutely utterly useless command for testing.`, + Run: func(cmd *Command, args []string) { + outs = "hidden" + }, + Hidden: true, +} + var cmdPrint = &Command{ Use: "print [string to print]", Short: "Print anything to the screen", @@ -68,9 +82,10 @@ var cmdDeprecated = &Command{ } var cmdTimes = &Command{ - Use: "times [# times] [string to echo]", - Short: "Echo anything to the screen more times", - Long: `a slightly useless command for testing.`, + Use: "times [# times] [string to echo]", + SuggestFor: []string{"counts"}, + Short: "Echo anything to the screen more times", + Long: `a slightly useless command for testing.`, PersistentPreRun: func(cmd *Command, args []string) { timesPersPre = args }, @@ -81,7 +96,7 @@ var cmdTimes = &Command{ var cmdRootNoRun = &Command{ Use: "cobra-test", - Short: "The root can run it's own function", + Short: "The root can run its own function", Long: "The root description for help", PersistentPreRun: func(cmd *Command, args []string) { rootPersPre = args @@ -96,9 +111,10 @@ var cmdRootSameName = &Command{ var cmdRootWithRun = &Command{ Use: "cobra-test", - Short: "The root can run it's own function", + Short: "The root can run its own function", Long: "The root description for help", Run: func(cmd *Command, args []string) { + tr = args rootcalled = true }, } @@ -127,6 +143,12 @@ var cmdVersion2 = &Command{ }, } +var cmdColon = &Command{ + Use: "cmd:colon", + Run: func(cmd *Command, args []string) { + }, +} + func flagInit() { cmdEcho.ResetFlags() cmdPrint.ResetFlags() @@ -181,7 +203,7 @@ func initializeWithSameName() *Command { func initializeWithRootCmd() *Command { cmdRootWithRun.ResetCommands() - tt, tp, te, rootcalled = nil, nil, nil, false + tt, tp, te, tr, rootcalled = nil, nil, nil, nil, false flagInit() cmdRootWithRun.Flags().BoolVarP(&flagbr, "boolroot", "b", false, "help message for flag boolroot") cmdRootWithRun.Flags().IntVarP(&flagir, "introot", "i", 321, "help message for flag introot") @@ -201,6 +223,13 @@ func fullSetupTest(input string) resulter { return fullTester(c, input) } +func noRRSetupTestSilenced(input string) resulter { + c := initialize() + c.SilenceErrors = true + c.SilenceUsage = true + return fullTester(c, input) +} + func noRRSetupTest(input string) resulter { c := initialize() @@ -225,6 +254,18 @@ func simpleTester(c *Command, input string) resulter { return resulter{err, output, c} } +func simpleTesterC(c *Command, input string) resulter { + buf := new(bytes.Buffer) + // Testing flag with invalid input + c.SetOutput(buf) + c.SetArgs(strings.Split(input, " ")) + + cmd, err := c.ExecuteC() + output := buf.String() + + return resulter{err, output, cmd} +} + func fullTester(c *Command, input string) resulter { buf := new(bytes.Buffer) // Testing flag with invalid input @@ -250,16 +291,24 @@ func logErr(t *testing.T, found, expected string) { t.Errorf(out.String()) } +func checkStringContains(t *testing.T, found, expected string) { + if !strings.Contains(found, expected) { + logErr(t, found, expected) + } +} + func checkResultContains(t *testing.T, x resulter, check string) { - if !strings.Contains(x.Output, check) { - logErr(t, x.Output, check) + checkStringContains(t, x.Output, check) +} + +func checkStringOmits(t *testing.T, found, expected string) { + if strings.Contains(found, expected) { + logErr(t, found, expected) } } func checkResultOmits(t *testing.T, x resulter, check string) { - if strings.Contains(x.Output, check) { - logErr(t, x.Output, check) - } + checkStringOmits(t, x.Output, check) } func checkOutputContains(t *testing.T, c *Command, check string) { @@ -374,9 +423,30 @@ func TestChildSameName(t *testing.T) { } } -func TestFlagLong(t *testing.T) { - noRRSetupTest("echo --intone=13 something here") +func TestGrandChildSameName(t *testing.T) { + c := initializeWithSameName() + cmdTimes.AddCommand(cmdPrint) + c.AddCommand(cmdTimes) + c.SetArgs(strings.Split("times print one two", " ")) + c.Execute() + if te != nil || tt != nil { + t.Error("Wrong command called") + } + if tp == nil { + t.Error("Wrong command called") + } + if strings.Join(tp, " ") != "one two" { + t.Error("Command didn't parse correctly") + } +} + +func TestFlagLong(t *testing.T) { + noRRSetupTest("echo --intone=13 something -- here") + + if cmdEcho.ArgsLenAtDash() != 1 { + t.Errorf("expected argsLenAtDash: %d but got %d", 1, cmdRootNoRun.ArgsLenAtDash()) + } if strings.Join(te, " ") != "something here" { t.Errorf("flags didn't leave proper args remaining..%s given", te) } @@ -389,8 +459,11 @@ func TestFlagLong(t *testing.T) { } func TestFlagShort(t *testing.T) { - noRRSetupTest("echo -i13 something here") + noRRSetupTest("echo -i13 -- something here") + if cmdEcho.ArgsLenAtDash() != 0 { + t.Errorf("expected argsLenAtDash: %d but got %d", 0, cmdRootNoRun.ArgsLenAtDash()) + } if strings.Join(te, " ") != "something here" { t.Errorf("flags didn't leave proper args remaining..%s given", te) } @@ -440,8 +513,8 @@ func TestChildCommandFlags(t *testing.T) { t.Errorf("invalid flag should generate error") } - if !strings.Contains(r.Output, "unknown shorthand") { - t.Errorf("Wrong error message displayed, \n %s", r.Output) + if !strings.Contains(r.Error.Error(), "unknown shorthand") { + t.Errorf("Wrong error message displayed, \n %s", r.Error) } if flagi2 != 99 { @@ -458,9 +531,8 @@ func TestChildCommandFlags(t *testing.T) { if r.Error == nil { t.Errorf("invalid flag should generate error") } - - if !strings.Contains(r.Output, "unknown shorthand flag") { - t.Errorf("Wrong error message displayed, \n %s", r.Output) + if !strings.Contains(r.Error.Error(), "unknown shorthand flag") { + t.Errorf("Wrong error message displayed, \n %s", r.Error) } // Testing with persistent flag overwritten by child @@ -480,9 +552,8 @@ func TestChildCommandFlags(t *testing.T) { if r.Error == nil { t.Errorf("invalid input should generate error") } - - if !strings.Contains(r.Output, "invalid argument \"10E\" for -i10E") { - t.Errorf("Wrong error message displayed, \n %s", r.Output) + if !strings.Contains(r.Error.Error(), "invalid argument \"10E\" for i10E") { + t.Errorf("Wrong error message displayed, \n %s", r.Error) } } @@ -494,21 +565,56 @@ func TestTrailingCommandFlags(t *testing.T) { } } -func TestInvalidSubCommandFlags(t *testing.T) { +func TestInvalidSubcommandFlags(t *testing.T) { cmd := initializeWithRootCmd() cmd.AddCommand(cmdTimes) result := simpleTester(cmd, "times --inttwo=2 --badflag=bar") - - checkResultContains(t, result, "unknown flag: --badflag") - - if strings.Contains(result.Output, "unknown flag: --inttwo") { + // given that we are not checking here result.Error we check for + // stock usage message + checkResultContains(t, result, "cobra-test times [# times]") + if strings.Contains(result.Error.Error(), "unknown flag: --inttwo") { t.Errorf("invalid --badflag flag shouldn't fail on 'unknown' --inttwo flag") } } -func TestSubCommandArgEvaluation(t *testing.T) { +func TestSubcommandExecuteC(t *testing.T) { + cmd := initializeWithRootCmd() + double := &Command{ + Use: "double message", + Run: func(c *Command, args []string) { + msg := strings.Join(args, " ") + c.Println(msg, msg) + }, + } + + echo := &Command{ + Use: "echo message", + Run: func(c *Command, args []string) { + msg := strings.Join(args, " ") + c.Println(msg, msg) + }, + } + + cmd.AddCommand(double, echo) + + result := simpleTesterC(cmd, "double hello world") + checkResultContains(t, result, "hello world hello world") + + if result.Command.Name() != "double" { + t.Errorf("invalid cmd returned from ExecuteC: should be 'double' but got %s", result.Command.Name()) + } + + result = simpleTesterC(cmd, "echo msg to be echoed") + checkResultContains(t, result, "msg to be echoed") + + if result.Command.Name() != "echo" { + t.Errorf("invalid cmd returned from ExecuteC: should be 'echo' but got %s", result.Command.Name()) + } +} + +func TestSubcommandArgEvaluation(t *testing.T) { cmd := initializeWithRootCmd() first := &Command{ @@ -537,7 +643,7 @@ func TestSubCommandArgEvaluation(t *testing.T) { func TestPersistentFlags(t *testing.T) { fullSetupTest("echo -s something -p more here") - // persistentFlag should act like normal flag on it's own command + // persistentFlag should act like normal flag on its own command if strings.Join(te, " ") != "more here" { t.Errorf("flags didn't leave proper args remaining..%s given", te) } @@ -548,7 +654,7 @@ func TestPersistentFlags(t *testing.T) { t.Errorf("persistent bool flag not parsed correctly. Expected true, had %v", flagbp) } - // persistentFlag should act like normal flag on it's own command + // persistentFlag should act like normal flag on its own command fullSetupTest("echo times -s again -c -p test here") if strings.Join(tt, " ") != "test here" { @@ -591,10 +697,38 @@ func TestNonRunChildHelp(t *testing.T) { } func TestRunnableRootCommand(t *testing.T) { - fullSetupTest("") + x := fullSetupTest("") if rootcalled != true { - t.Errorf("Root Function was not called") + t.Errorf("Root Function was not called\n out:%v", x.Error) + } +} + +func TestVisitParents(t *testing.T) { + c := &Command{Use: "app"} + sub := &Command{Use: "sub"} + dsub := &Command{Use: "dsub"} + sub.AddCommand(dsub) + c.AddCommand(sub) + total := 0 + add := func(x *Command) { + total++ + } + sub.VisitParents(add) + if total != 1 { + t.Errorf("Should have visited 1 parent but visited %d", total) + } + + total = 0 + dsub.VisitParents(add) + if total != 2 { + t.Errorf("Should have visited 2 parent but visited %d", total) + } + + total = 0 + c.VisitParents(add) + if total != 0 { + t.Errorf("Should have not visited any parent but visited %d", total) } } @@ -609,7 +743,10 @@ func TestRunnableRootCommandNilInput(t *testing.T) { c.AddCommand(cmdPrint, cmdEcho) c.SetArgs(empty_arg) - c.Execute() + err := c.Execute() + if err != nil { + t.Errorf("Execute() failed with %v", err) + } if rootcalled != true { t.Errorf("Root Function was not called") @@ -746,11 +883,69 @@ func TestRootNoCommandHelp(t *testing.T) { func TestRootUnknownCommand(t *testing.T) { r := noRRSetupTest("bogus") - s := "Error: unknown command \"bogus\"\nRun 'cobra-test help' for usage.\n" + s := "Error: unknown command \"bogus\" for \"cobra-test\"\nRun 'cobra-test --help' for usage.\n" if r.Output != s { t.Errorf("Unexpected response.\nExpecting to be:\n %q\nGot:\n %q\n", s, r.Output) } + + r = noRRSetupTest("--strtwo=a bogus") + if r.Output != s { + t.Errorf("Unexpected response.\nExpecting to be:\n %q\nGot:\n %q\n", s, r.Output) + } +} + +func TestRootUnknownCommandSilenced(t *testing.T) { + r := noRRSetupTestSilenced("bogus") + s := "Run 'cobra-test --help' for usage.\n" + + if r.Output != "" { + t.Errorf("Unexpected response.\nExpecting to be: \n\"\"\n Got:\n %q\n", s, r.Output) + } + + r = noRRSetupTestSilenced("--strtwo=a bogus") + if r.Output != "" { + t.Errorf("Unexpected response.\nExpecting to be:\n\"\"\nGot:\n %q\n", s, r.Output) + } +} + +func TestRootSuggestions(t *testing.T) { + outputWithSuggestions := "Error: unknown command \"%s\" for \"cobra-test\"\n\nDid you mean this?\n\t%s\n\nRun 'cobra-test --help' for usage.\n" + outputWithoutSuggestions := "Error: unknown command \"%s\" for \"cobra-test\"\nRun 'cobra-test --help' for usage.\n" + + cmd := initializeWithRootCmd() + cmd.AddCommand(cmdTimes) + + tests := map[string]string{ + "time": "times", + "tiems": "times", + "tims": "times", + "timeS": "times", + "rimes": "times", + "ti": "times", + "t": "times", + "timely": "times", + "ri": "", + "timezone": "", + "foo": "", + "counts": "times", + } + + for typo, suggestion := range tests { + for _, suggestionsDisabled := range []bool{false, true} { + cmd.DisableSuggestions = suggestionsDisabled + result := simpleTester(cmd, typo) + expected := "" + if len(suggestion) == 0 || suggestionsDisabled { + expected = fmt.Sprintf(outputWithoutSuggestions, typo) + } else { + expected = fmt.Sprintf(outputWithSuggestions, typo, suggestion) + } + if result.Output != expected { + t.Errorf("Unexpected response.\nExpecting to be:\n %q\nGot:\n %q\n", expected, result.Output) + } + } + } } func TestFlagsBeforeCommand(t *testing.T) { @@ -777,8 +972,8 @@ func TestFlagsBeforeCommand(t *testing.T) { // With parsing error properly reported x = fullSetupTest("-i10E echo") - if !strings.Contains(x.Output, "invalid argument \"10E\" for -i10E") { - t.Errorf("Wrong error message displayed, \n %s", x.Output) + if !strings.Contains(x.Error.Error(), "invalid argument \"10E\" for i10E") { + t.Errorf("Wrong error message displayed, \n %s", x.Error) } //With quotes @@ -819,6 +1014,31 @@ func TestRemoveCommand(t *testing.T) { } } +func TestCommandWithoutSubcommands(t *testing.T) { + c := initializeWithRootCmd() + + x := simpleTester(c, "") + if x.Error != nil { + t.Errorf("Calling command without subcommands should not have error: %v", x.Error) + return + } +} + +func TestCommandWithoutSubcommandsWithArg(t *testing.T) { + c := initializeWithRootCmd() + expectedArgs := []string{"arg"} + + x := simpleTester(c, "arg") + if x.Error != nil { + t.Errorf("Calling command without subcommands but with arg should not have error: %v", x.Error) + return + } + if !reflect.DeepEqual(expectedArgs, tr) { + t.Errorf("Calling command without subcommands but with arg has wrong args: expected: %v, actual: %v", expectedArgs, tr) + return + } +} + func TestReplaceCommandWithRemove(t *testing.T) { versionUsed = 0 c := initializeWithRootCmd() @@ -869,3 +1089,81 @@ func TestPreRun(t *testing.T) { t.Error("Wrong *Pre functions called!") } } + +// Check if cmdEchoSub gets PersistentPreRun from rootCmd even if is added last +func TestPeristentPreRunPropagation(t *testing.T) { + rootCmd := initialize() + + // First add the cmdEchoSub to cmdPrint + cmdPrint.AddCommand(cmdEchoSub) + // Now add cmdPrint to rootCmd + rootCmd.AddCommand(cmdPrint) + + rootCmd.SetArgs(strings.Split("print echosub lala", " ")) + rootCmd.Execute() + + if rootPersPre == nil || len(rootPersPre) == 0 || rootPersPre[0] != "lala" { + t.Error("RootCmd PersistentPreRun not called but should have been") + } +} + +func TestGlobalNormFuncPropagation(t *testing.T) { + normFunc := func(f *pflag.FlagSet, name string) pflag.NormalizedName { + return pflag.NormalizedName(name) + } + + rootCmd := initialize() + rootCmd.SetGlobalNormalizationFunc(normFunc) + if reflect.ValueOf(normFunc) != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()) { + t.Error("rootCmd seems to have a wrong normalization function") + } + + // First add the cmdEchoSub to cmdPrint + cmdPrint.AddCommand(cmdEchoSub) + if cmdPrint.GlobalNormalizationFunc() != nil && cmdEchoSub.GlobalNormalizationFunc() != nil { + t.Error("cmdPrint and cmdEchoSub should had no normalization functions") + } + + // Now add cmdPrint to rootCmd + rootCmd.AddCommand(cmdPrint) + if reflect.ValueOf(cmdPrint.GlobalNormalizationFunc()).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() || + reflect.ValueOf(cmdEchoSub.GlobalNormalizationFunc()).Pointer() != reflect.ValueOf(rootCmd.GlobalNormalizationFunc()).Pointer() { + t.Error("cmdPrint and cmdEchoSub should had the normalization function of rootCmd") + } +} + +func TestFlagOnPflagCommandLine(t *testing.T) { + flagName := "flagOnCommandLine" + pflag.CommandLine.String(flagName, "", "about my flag") + r := fullSetupTest("--help") + + checkResultContains(t, r, flagName) +} + +func TestAddTemplateFunctions(t *testing.T) { + AddTemplateFunc("t", func() bool { return true }) + AddTemplateFuncs(template.FuncMap{ + "f": func() bool { return false }, + "h": func() string { return "Hello," }, + "w": func() string { return "world." }}) + + const usage = "Hello, world." + + c := &Command{} + c.SetUsageTemplate(`{{if t}}{{h}}{{end}}{{if f}}{{h}}{{end}} {{w}}`) + + if us := c.UsageString(); us != usage { + t.Errorf("c.UsageString() != \"%s\", is \"%s\"", usage, us) + } +} + +func TestUsageIsNotPrintedTwice(t *testing.T) { + var cmd = &Command{Use: "root"} + var sub = &Command{Use: "sub"} + cmd.AddCommand(sub) + + r := simpleTester(cmd, "") + if strings.Count(r.Output, "Usage:") != 1 { + t.Error("Usage output is not printed exactly once") + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/command.go b/Godeps/_workspace/src/github.com/spf13/cobra/command.go index f578f365..4e3fce33 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/command.go +++ b/Godeps/_workspace/src/github.com/spf13/cobra/command.go @@ -18,13 +18,12 @@ package cobra import ( "bytes" "fmt" - "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/inconshreveable/mousetrap" - flag "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" "io" "os" - "runtime" + "path/filepath" "strings" - "time" + + flag "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" ) // Command is just that, a command for your application. @@ -38,6 +37,8 @@ type Command struct { Use string // An array of aliases that can be used instead of the first word in Use. Aliases []string + // An array of command names for which this command will be suggested - similar to aliases but only suggests. + SuggestFor []string // The short description shown in the 'help' output. Short string // The long message shown in the 'help ' output. @@ -50,12 +51,18 @@ type Command struct { BashCompletionFunction string // Is this command deprecated and should print this string when used? Deprecated string + // Is this command hidden and should NOT show up in the list of available commands? + Hidden bool // Full set of flags flags *flag.FlagSet // Set of flags childrens of this command will inherit pflags *flag.FlagSet // Flags that are declared specifically by this command (not inherited). lflags *flag.FlagSet + // SilenceErrors is an option to quiet errors down stream + SilenceErrors bool + // Silence Usage is an option to silence usage when an error occurs. + SilenceUsage bool // The *Run functions are executed in the following order: // * PersistentPreRun() // * PreRun() @@ -65,14 +72,26 @@ type Command struct { // All functions get the same args, the arguments after the command name // PersistentPreRun: children of this command will inherit and execute PersistentPreRun func(cmd *Command, args []string) + // PersistentPreRunE: PersistentPreRun but returns an error + PersistentPreRunE func(cmd *Command, args []string) error // PreRun: children of this command will not inherit. PreRun func(cmd *Command, args []string) + // PreRunE: PreRun but returns an error + PreRunE func(cmd *Command, args []string) error // Run: Typically the actual work function. Most commands will only implement this Run func(cmd *Command, args []string) + // RunE: Run but returns an error + RunE func(cmd *Command, args []string) error // PostRun: run after the Run command. PostRun func(cmd *Command, args []string) + // PostRunE: PostRun but returns an error + PostRunE func(cmd *Command, args []string) error // PersistentPostRun: children of this command will inherit and execute after PostRun PersistentPostRun func(cmd *Command, args []string) + // PersistentPostRunE: PersistentPostRun but returns an error + PersistentPostRunE func(cmd *Command, args []string) error + // DisableAutoGenTag remove + DisableAutoGenTag bool // Commands is the list of commands supported by this program. commands []*Command // Parent Command for this command @@ -83,7 +102,6 @@ type Command struct { commandsMaxNameLen int flagErrorBuf *bytes.Buffer - cmdErrorBuf *bytes.Buffer args []string // actual args parsed from flags output *io.Writer // nil means stderr; use Out() method instead @@ -92,7 +110,13 @@ type Command struct { helpTemplate string // Can be defined by Application helpFunc func(*Command, []string) // Help can be defined by application helpCommand *Command // The help command - helpFlagVal bool + // The global normalization function that we can use on every pFlag set and children commands + globNormFunc func(f *flag.FlagSet, name string) flag.NormalizedName + + // Disable the suggestions based on Levenshtein distance that go along with 'unknown command' messages + DisableSuggestions bool + // If displaying suggestions, allows to set the minimum levenshtein distance to display, must be > 0 + SuggestionsMinimumDistance int } // os.Args[1:] by default, if desired, can be overridden @@ -151,6 +175,18 @@ func (c *Command) SetHelpTemplate(s string) { c.helpTemplate = s } +// SetGlobalNormalizationFunc sets a normalization function to all flag sets and also to child commands. +// The user should not have a cyclic dependency on commands. +func (c *Command) SetGlobalNormalizationFunc(n func(f *flag.FlagSet, name string) flag.NormalizedName) { + c.Flags().SetNormalizeFunc(n) + c.PersistentFlags().SetNormalizeFunc(n) + c.globNormFunc = n + + for _, command := range c.commands { + command.SetGlobalNormalizationFunc(n) + } +} + func (c *Command) UsageFunc() (f func(*Command) error) { if c.usageFunc != nil { return c.usageFunc @@ -161,36 +197,28 @@ func (c *Command) UsageFunc() (f func(*Command) error) { } else { return func(c *Command) error { err := tmpl(c.Out(), c.UsageTemplate(), c) + if err != nil { + fmt.Print(err) + } return err } } } + +// HelpFunc returns either the function set by SetHelpFunc for this command +// or a parent, or it returns a function which calls c.Help() func (c *Command) HelpFunc() func(*Command, []string) { - if c.helpFunc != nil { - return c.helpFunc + cmd := c + for cmd != nil { + if cmd.helpFunc != nil { + return cmd.helpFunc + } + cmd = cmd.parent } - - if c.HasParent() { - return c.parent.HelpFunc() - } else { - return func(c *Command, args []string) { - if len(args) == 0 { - // Help called without any topic, calling on root - c.Root().Help() - return - } - - cmd, _, e := c.Root().Find(args) - if cmd == nil || e != nil { - c.Printf("Unknown help topic %#q.", args) - - c.Root().Usage() - } else { - err := cmd.Help() - if err != nil { - c.Println(err) - } - } + return func(*Command, []string) { + err := c.Help() + if err != nil { + c.Println(err) } } } @@ -234,8 +262,7 @@ func (c *Command) UsageTemplate() string { if c.HasParent() { return c.parent.UsageTemplate() } else { - return `{{ $cmd := . }} -Usage: {{if .Runnable}} + return `Usage:{{if .Runnable}} {{.UseLine}}{{if .HasFlags}} [flags]{{end}}{{end}}{{if .HasSubCommands}} {{ .CommandPath}} [command]{{end}}{{if gt .Aliases 0}} @@ -244,22 +271,22 @@ Aliases: {{end}}{{if .HasExample}} Examples: -{{ .Example }} -{{end}}{{ if .HasRunnableSubCommands}} +{{ .Example }}{{end}}{{ if .HasAvailableSubCommands}} -Available Commands: {{range .Commands}}{{if and (.Runnable) (not .Deprecated)}} - {{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}} -{{end}} -{{ if .HasLocalFlags}}Flags: -{{.LocalFlags.FlagUsages}}{{end}} -{{ if .HasInheritedFlags}}Global Flags: -{{.InheritedFlags.FlagUsages}}{{end}}{{if or (.HasHelpSubCommands) (.HasRunnableSiblings)}} -Additional help topics: -{{if .HasHelpSubCommands}}{{range .Commands}}{{if and (not .Runnable) (not .Deprecated)}} {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{if .HasRunnableSiblings }}{{range .Parent.Commands}}{{if and (not .Runnable) (not .Deprecated)}}{{if not (eq .Name $cmd.Name) }} - {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{end}} -{{end}}{{ if .HasSubCommands }} -Use "{{.Root.Name}} help [command]" for more information about a command. -{{end}}` +Available Commands:{{range .Commands}}{{if .IsAvailableCommand}} + {{rpad .Name .NamePadding }} {{.Short}}{{end}}{{end}}{{end}}{{ if .HasLocalFlags}} + +Flags: +{{.LocalFlags.FlagUsages | trimRightSpace}}{{end}}{{ if .HasInheritedFlags}} + +Global Flags: +{{.InheritedFlags.FlagUsages | trimRightSpace}}{{end}}{{if .HasHelpSubCommands}} + +Additional help topics:{{range .Commands}}{{if .IsHelpCommand}} + {{rpad .CommandPath .CommandPathPadding}} {{.Short}}{{end}}{{end}}{{end}}{{ if .HasSubCommands }} + +Use "{{.CommandPath}} [command] --help" for more information about a command.{{end}} +` } } @@ -271,9 +298,9 @@ func (c *Command) HelpTemplate() string { if c.HasParent() { return c.parent.HelpTemplate() } else { - return `{{with or .Long .Short }}{{. | trim}}{{end}} -{{if or .Runnable .HasSubCommands}}{{.UsageString}}{{end}} -` + return `{{with or .Long .Short }}{{. | trim}} + +{{end}}{{if or .Runnable .HasSubCommands}}{{.UsageString}}{{end}}` } } @@ -360,71 +387,128 @@ func argsMinusFirstX(args []string, x string) []string { // find the target command given the args and command tree // Meant to be run on the highest node. Only searches down. -func (c *Command) Find(arrs []string) (*Command, []string, error) { +func (c *Command) Find(args []string) (*Command, []string, error) { if c == nil { return nil, nil, fmt.Errorf("Called find() on a nil Command") } - if len(arrs) == 0 { - return c.Root(), arrs, nil - } - var innerfind func(*Command, []string) (*Command, []string) - innerfind = func(c *Command, args []string) (*Command, []string) { - if len(args) > 0 && c.HasSubCommands() { - argsWOflags := stripFlags(args, c) - if len(argsWOflags) > 0 { - matches := make([]*Command, 0) - for _, cmd := range c.commands { - if cmd.Name() == argsWOflags[0] || cmd.HasAlias(argsWOflags[0]) { // exact name or alias match - return innerfind(cmd, argsMinusFirstX(args, argsWOflags[0])) - } else if EnablePrefixMatching { - if strings.HasPrefix(cmd.Name(), argsWOflags[0]) { // prefix match - matches = append(matches, cmd) - } - for _, x := range cmd.Aliases { - if strings.HasPrefix(x, argsWOflags[0]) { - matches = append(matches, cmd) - } - } - } + innerfind = func(c *Command, innerArgs []string) (*Command, []string) { + argsWOflags := stripFlags(innerArgs, c) + if len(argsWOflags) == 0 { + return c, innerArgs + } + nextSubCmd := argsWOflags[0] + matches := make([]*Command, 0) + for _, cmd := range c.commands { + if cmd.Name() == nextSubCmd || cmd.HasAlias(nextSubCmd) { // exact name or alias match + return innerfind(cmd, argsMinusFirstX(innerArgs, nextSubCmd)) + } + if EnablePrefixMatching { + if strings.HasPrefix(cmd.Name(), nextSubCmd) { // prefix match + matches = append(matches, cmd) } - - // only accept a single prefix match - multiple matches would be ambiguous - if len(matches) == 1 { - return innerfind(matches[0], argsMinusFirstX(args, argsWOflags[0])) + for _, x := range cmd.Aliases { + if strings.HasPrefix(x, nextSubCmd) { + matches = append(matches, cmd) + } } } } - return c, args + // only accept a single prefix match - multiple matches would be ambiguous + if len(matches) == 1 { + return innerfind(matches[0], argsMinusFirstX(innerArgs, argsWOflags[0])) + } + + return c, innerArgs } - commandFound, a := innerfind(c, arrs) + commandFound, a := innerfind(c, args) + argsWOflags := stripFlags(a, commandFound) - // If we matched on the root, but we asked for a subcommand, return an error - if commandFound.Name() == c.Name() && len(stripFlags(arrs, c)) > 0 && commandFound.Name() != arrs[0] { - return nil, a, fmt.Errorf("unknown command %q", a[0]) + // no subcommand, always take args + if !commandFound.HasSubCommands() { + return commandFound, a, nil + } + + // root command with subcommands, do subcommand checking + if commandFound == c && len(argsWOflags) > 0 { + suggestionsString := "" + if !c.DisableSuggestions { + if c.SuggestionsMinimumDistance <= 0 { + c.SuggestionsMinimumDistance = 2 + } + if suggestions := c.SuggestionsFor(argsWOflags[0]); len(suggestions) > 0 { + suggestionsString += "\n\nDid you mean this?\n" + for _, s := range suggestions { + suggestionsString += fmt.Sprintf("\t%v\n", s) + } + } + } + return commandFound, a, fmt.Errorf("unknown command %q for %q%s", argsWOflags[0], commandFound.CommandPath(), suggestionsString) } return commandFound, a, nil } +func (c *Command) SuggestionsFor(typedName string) []string { + suggestions := []string{} + for _, cmd := range c.commands { + if cmd.IsAvailableCommand() { + levenshteinDistance := ld(typedName, cmd.Name(), true) + suggestByLevenshtein := levenshteinDistance <= c.SuggestionsMinimumDistance + suggestByPrefix := strings.HasPrefix(strings.ToLower(cmd.Name()), strings.ToLower(typedName)) + if suggestByLevenshtein || suggestByPrefix { + suggestions = append(suggestions, cmd.Name()) + } + for _, explicitSuggestion := range cmd.SuggestFor { + if strings.EqualFold(typedName, explicitSuggestion) { + suggestions = append(suggestions, cmd.Name()) + } + } + } + } + return suggestions +} + +func (c *Command) VisitParents(fn func(*Command)) { + var traverse func(*Command) *Command + + traverse = func(x *Command) *Command { + if x != c { + fn(x) + } + if x.HasParent() { + return traverse(x.parent) + } + return x + } + traverse(c) +} + func (c *Command) Root() *Command { var findRoot func(*Command) *Command findRoot = func(x *Command) *Command { if x.HasParent() { return findRoot(x.parent) - } else { - return x } + return x } return findRoot(c) } +// ArgsLenAtDash will return the length of f.Args at the moment when a -- was +// found during arg parsing. This allows your program to know which args were +// before the -- and which came after. (Description from +// https://godoc.org/github.com/spf13/pflag#FlagSet.ArgsLenAtDash). +func (c *Command) ArgsLenAtDash() int { + return c.Flags().ArgsLenAtDash() +} + func (c *Command) execute(a []string) (err error) { if c == nil { return fmt.Errorf("Called Execute() on a nil Command") @@ -434,49 +518,73 @@ func (c *Command) execute(a []string) (err error) { c.Printf("Command %q is deprecated, %s\n", c.Name(), c.Deprecated) } + // initialize help flag as the last point possible to allow for user + // overriding + c.initHelpFlag() + err = c.ParseFlags(a) - if err == flag.ErrHelp { - c.Help() - return nil - } if err != nil { - // We're writing subcommand usage to root command's error buffer to have it displayed to the user - r := c.Root() - if r.cmdErrorBuf == nil { - r.cmdErrorBuf = new(bytes.Buffer) - } - // for writing the usage to the buffer we need to switch the output temporarily - // since Out() returns root output, you also need to revert that on root - out := r.Out() - r.SetOutput(r.cmdErrorBuf) - c.Usage() - r.SetOutput(out) return err } - // If help is called, regardless of other flags, we print that. - // Print help also if c.Run is nil. - if c.helpFlagVal || !c.Runnable() { - c.Help() - return nil + // If help is called, regardless of other flags, return we want help + // Also say we need help if the command isn't runnable. + helpVal, err := c.Flags().GetBool("help") + if err != nil { + // should be impossible to get here as we always declare a help + // flag in initHelpFlag() + c.Println("\"help\" flag declared as non-bool. Please correct your code") + return err + } + if helpVal || !c.Runnable() { + return flag.ErrHelp } c.preRun() argWoFlags := c.Flags().Args() - if c.PersistentPreRun != nil { - c.PersistentPreRun(c, argWoFlags) + for p := c; p != nil; p = p.Parent() { + if p.PersistentPreRunE != nil { + if err := p.PersistentPreRunE(c, argWoFlags); err != nil { + return err + } + break + } else if p.PersistentPreRun != nil { + p.PersistentPreRun(c, argWoFlags) + break + } } - if c.PreRun != nil { + if c.PreRunE != nil { + if err := c.PreRunE(c, argWoFlags); err != nil { + return err + } + } else if c.PreRun != nil { c.PreRun(c, argWoFlags) } - c.Run(c, argWoFlags) - - if c.PostRun != nil { + if c.RunE != nil { + if err := c.RunE(c, argWoFlags); err != nil { + return err + } + } else { + c.Run(c, argWoFlags) + } + if c.PostRunE != nil { + if err := c.PostRunE(c, argWoFlags); err != nil { + return err + } + } else if c.PostRun != nil { c.PostRun(c, argWoFlags) } - if c.PersistentPostRun != nil { - c.PersistentPostRun(c, argWoFlags) + for p := c; p != nil; p = p.Parent() { + if p.PersistentPostRunE != nil { + if err := p.PersistentPostRunE(c, argWoFlags); err != nil { + return err + } + break + } else if p.PersistentPostRun != nil { + p.PersistentPostRun(c, argWoFlags) + break + } } return nil @@ -503,52 +611,80 @@ func (c *Command) errorMsgFromParse() string { // Call execute to use the args (os.Args[1:] by default) // and run through the command tree finding appropriate matches // for commands and then corresponding flags. -func (c *Command) Execute() (err error) { +func (c *Command) Execute() error { + _, err := c.ExecuteC() + return err +} + +func (c *Command) ExecuteC() (cmd *Command, err error) { // Regardless of what command execute is called on, run on Root only if c.HasParent() { - return c.Root().Execute() + return c.Root().ExecuteC() } - if EnableWindowsMouseTrap && runtime.GOOS == "windows" { - if mousetrap.StartedByExplorer() { - c.Print(MousetrapHelpText) - time.Sleep(5 * time.Second) - os.Exit(1) - } + // windows hook + if preExecHookFn != nil { + preExecHookFn(c) } // initialize help as the last point possible to allow for user // overriding - c.initHelp() + c.initHelpCmd() var args []string - if len(c.args) == 0 { + // Workaround FAIL with "go test -v" or "cobra.test -test.v", see #155 + if len(c.args) == 0 && filepath.Base(os.Args[0]) != "cobra.test" { args = os.Args[1:] } else { args = c.args } cmd, flags, err := c.Find(args) - if err == nil { - err = cmd.execute(flags) - } - if err != nil { - if err == flag.ErrHelp { - c.Help() - - } else { - c.Println("Error:", err.Error()) - c.Printf("Run '%v help' for usage.\n", c.Root().Name()) + // If found parse to a subcommand and then failed, talk about the subcommand + if cmd != nil { + c = cmd } + if !c.SilenceErrors { + c.Println("Error:", err.Error()) + c.Printf("Run '%v --help' for usage.\n", c.CommandPath()) + } + return c, err } + err = cmd.execute(flags) + if err != nil { + // Always show help if requested, even if SilenceErrors is in + // effect + if err == flag.ErrHelp { + cmd.HelpFunc()(cmd, args) + return cmd, nil + } - return + // If root command has SilentErrors flagged, + // all subcommands should respect it + if !cmd.SilenceErrors && !c.SilenceErrors { + c.Println("Error:", err.Error()) + } + + // If root command has SilentUsage flagged, + // all subcommands should respect it + if !cmd.SilenceUsage && !c.SilenceUsage { + c.Println(cmd.UsageString()) + } + return cmd, err + } + return cmd, nil } -func (c *Command) initHelp() { +func (c *Command) initHelpFlag() { + if c.Flags().Lookup("help") == nil { + c.Flags().BoolP("help", "h", false, "help for "+c.Name()) + } +} + +func (c *Command) initHelpCmd() { if c.helpCommand == nil { if !c.HasSubCommands() { return @@ -559,9 +695,19 @@ func (c *Command) initHelp() { Short: "Help about any command", Long: `Help provides help for any command in the application. Simply type ` + c.Name() + ` help [path to command] for full details.`, - Run: c.HelpFunc(), PersistentPreRun: func(cmd *Command, args []string) {}, PersistentPostRun: func(cmd *Command, args []string) {}, + + Run: func(c *Command, args []string) { + cmd, _, e := c.Root().Find(args) + if cmd == nil || e != nil { + c.Printf("Unknown help topic %#q.", args) + c.Root().Usage() + } else { + helpFunc := cmd.HelpFunc() + helpFunc(cmd, args) + } + }, } } c.AddCommand(c.helpCommand) @@ -571,8 +717,6 @@ func (c *Command) initHelp() { func (c *Command) ResetCommands() { c.commands = nil c.helpCommand = nil - c.cmdErrorBuf = new(bytes.Buffer) - c.cmdErrorBuf.Reset() } //Commands returns a slice of child commands. @@ -600,15 +744,11 @@ func (c *Command) AddCommand(cmds ...*Command) { if nameLen > c.commandsMaxNameLen { c.commandsMaxNameLen = nameLen } + // If global normalization function exists, update all children + if c.globNormFunc != nil { + x.SetGlobalNormalizationFunc(c.globNormFunc) + } c.commands = append(c.commands, x) - - // Pass on peristent pre/post functions to children - if x.PersistentPreRun == nil { - x.PersistentPreRun = c.PersistentPreRun - } - if x.PersistentPostRun == nil { - x.PersistentPostRun = c.PersistentPostRun - } } } @@ -788,7 +928,7 @@ func (c *Command) HasExample() bool { // Determine if the command is itself runnable func (c *Command) Runnable() bool { - return c.Run != nil + return c.Run != nil || c.RunE != nil } // Determine if the command has children commands @@ -796,34 +936,75 @@ func (c *Command) HasSubCommands() bool { return len(c.commands) > 0 } -func (c *Command) HasRunnableSiblings() bool { - if !c.HasParent() { +// IsAvailableCommand determines if a command is available as a non-help command +// (this includes all non deprecated/hidden commands) +func (c *Command) IsAvailableCommand() bool { + if len(c.Deprecated) != 0 || c.Hidden { return false } - for _, sub := range c.parent.commands { - if sub.Runnable() { - return true - } + + if c.HasParent() && c.Parent().helpCommand == c { + return false } + + if c.Runnable() || c.HasAvailableSubCommands() { + return true + } + return false } +// IsHelpCommand determines if a command is a 'help' command; a help command is +// determined by the fact that it is NOT runnable/hidden/deprecated, and has no +// sub commands that are runnable/hidden/deprecated +func (c *Command) IsHelpCommand() bool { + + // if a command is runnable, deprecated, or hidden it is not a 'help' command + if c.Runnable() || len(c.Deprecated) != 0 || c.Hidden { + return false + } + + // if any non-help sub commands are found, the command is not a 'help' command + for _, sub := range c.commands { + if !sub.IsHelpCommand() { + return false + } + } + + // the command either has no sub commands, or no non-help sub commands + return true +} + +// HasHelpSubCommands determines if a command has any avilable 'help' sub commands +// that need to be shown in the usage/help default template under 'additional help +// topics' func (c *Command) HasHelpSubCommands() bool { + + // return true on the first found available 'help' sub command for _, sub := range c.commands { - if !sub.Runnable() { + if sub.IsHelpCommand() { return true } } + + // the command either has no sub commands, or no available 'help' sub commands return false } -// Determine if the command has runnable children commands -func (c *Command) HasRunnableSubCommands() bool { +// HasAvailableSubCommands determines if a command has available sub commands that +// need to be shown in the usage/help default template under 'available commands' +func (c *Command) HasAvailableSubCommands() bool { + + // return true on the first found available (non deprecated/help/hidden) + // sub command for _, sub := range c.commands { - if sub.Runnable() { + if sub.IsAvailableCommand() { return true } } + + // the command either has no sub comamnds, or no available (non deprecated/help/hidden) + // sub commands return false } @@ -832,6 +1013,11 @@ func (c *Command) HasParent() bool { return c.parent != nil } +// GlobalNormalizationFunc returns the global normalization function or nil if doesn't exists +func (c *Command) GlobalNormalizationFunc() func(f *flag.FlagSet, name string) flag.NormalizedName { + return c.globNormFunc +} + // Get the complete FlagSet that applies to this command (local and persistent declared here and by all parents) func (c *Command) Flags() *flag.FlagSet { if c.flags == nil { @@ -840,7 +1026,6 @@ func (c *Command) Flags() *flag.FlagSet { c.flagErrorBuf = new(bytes.Buffer) } c.flags.SetOutput(c.flagErrorBuf) - c.PersistentFlags().BoolVarP(&c.helpFlagVal, "help", "h", false, "help for "+c.Name()) } return c.flags } @@ -853,6 +1038,13 @@ func (c *Command) LocalFlags() *flag.FlagSet { c.lflags.VisitAll(func(f *flag.Flag) { local.AddFlag(f) }) + if !c.HasParent() { + flag.CommandLine.VisitAll(func(f *flag.Flag) { + if local.Lookup(f.Name) == nil { + local.AddFlag(f) + } + }) + } return local } diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/command_notwin.go b/Godeps/_workspace/src/github.com/spf13/cobra/command_notwin.go new file mode 100644 index 00000000..073dd353 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/command_notwin.go @@ -0,0 +1,5 @@ +// +build !windows + +package cobra + +var preExecHookFn func(*Command) = nil diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/command_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/command_test.go index 477d84e7..43ed7a34 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/command_test.go +++ b/Godeps/_workspace/src/github.com/spf13/cobra/command_test.go @@ -5,6 +5,30 @@ import ( "testing" ) +// test to ensure hidden commands run as intended +func TestHiddenCommandExecutes(t *testing.T) { + + // ensure that outs does not already equal what the command will be setting it + // to, if it did this test would not actually be testing anything... + if outs == "hidden" { + t.Errorf("outs should NOT EQUAL hidden") + } + + cmdHidden.Execute() + + // upon running the command, the value of outs should now be 'hidden' + if outs != "hidden" { + t.Errorf("Hidden command failed to run!") + } +} + +// test to ensure hidden commands do not show up in usage/help text +func TestHiddenCommandIsHidden(t *testing.T) { + if cmdHidden.IsAvailableCommand() { + t.Errorf("Hidden command found!") + } +} + func TestStripFlags(t *testing.T) { tests := []struct { input []string diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/command_win.go b/Godeps/_workspace/src/github.com/spf13/cobra/command_win.go new file mode 100644 index 00000000..09d59db5 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/command_win.go @@ -0,0 +1,26 @@ +// +build windows + +package cobra + +import ( + "os" + "time" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/inconshreveable/mousetrap" +) + +var preExecHookFn = preExecHook + +// enables an information splash screen on Windows if the CLI is started from explorer.exe. +var MousetrapHelpText string = `This is a command line tool + +You need to open cmd.exe and run it from there. +` + +func preExecHook(c *Command) { + if mousetrap.StartedByExplorer() { + c.Print(MousetrapHelpText) + time.Sleep(5 * time.Second) + os.Exit(1) + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/cmd_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/doc/cmd_test.go new file mode 100644 index 00000000..48934183 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/cmd_test.go @@ -0,0 +1,145 @@ +package doc + +import ( + "bytes" + "fmt" + "runtime" + "strings" + "testing" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" +) + +var flagb1, flagb2, flagb3, flagbr, flagbp bool +var flags1, flags2a, flags2b, flags3 string +var flagi1, flagi2, flagi3, flagir int + +const strtwoParentHelp = "help message for parent flag strtwo" +const strtwoChildHelp = "help message for child flag strtwo" + +var cmdEcho = &cobra.Command{ + Use: "echo [string to echo]", + Aliases: []string{"say"}, + Short: "Echo anything to the screen", + Long: `an utterly useless command for testing.`, + Example: "Just run cobra-test echo", +} + +var cmdEchoSub = &cobra.Command{ + Use: "echosub [string to print]", + Short: "second sub command for echo", + Long: `an absolutely utterly useless command for testing gendocs!.`, + Run: func(cmd *cobra.Command, args []string) {}, +} + +var cmdDeprecated = &cobra.Command{ + Use: "deprecated [can't do anything here]", + Short: "A command which is deprecated", + Long: `an absolutely utterly useless command for testing deprecation!.`, + Deprecated: "Please use echo instead", +} + +var cmdTimes = &cobra.Command{ + Use: "times [# times] [string to echo]", + SuggestFor: []string{"counts"}, + Short: "Echo anything to the screen more times", + Long: `a slightly useless command for testing.`, + PersistentPreRun: func(cmd *cobra.Command, args []string) {}, + Run: func(cmd *cobra.Command, args []string) {}, +} + +var cmdPrint = &cobra.Command{ + Use: "print [string to print]", + Short: "Print anything to the screen", + Long: `an absolutely utterly useless command for testing.`, +} + +var cmdRootNoRun = &cobra.Command{ + Use: "cobra-test", + Short: "The root can run its own function", + Long: "The root description for help", +} + +var cmdRootSameName = &cobra.Command{ + Use: "print", + Short: "Root with the same name as a subcommand", + Long: "The root description for help", +} + +var cmdRootWithRun = &cobra.Command{ + Use: "cobra-test", + Short: "The root can run its own function", + Long: "The root description for help", +} + +var cmdSubNoRun = &cobra.Command{ + Use: "subnorun", + Short: "A subcommand without a Run function", + Long: "A long output about a subcommand without a Run function", +} + +var cmdVersion1 = &cobra.Command{ + Use: "version", + Short: "Print the version number", + Long: `First version of the version command`, +} + +var cmdVersion2 = &cobra.Command{ + Use: "version", + Short: "Print the version number", + Long: `Second version of the version command`, +} + +func flagInit() { + cmdEcho.ResetFlags() + cmdPrint.ResetFlags() + cmdTimes.ResetFlags() + cmdRootNoRun.ResetFlags() + cmdRootSameName.ResetFlags() + cmdRootWithRun.ResetFlags() + cmdSubNoRun.ResetFlags() + cmdRootNoRun.PersistentFlags().StringVarP(&flags2a, "strtwo", "t", "two", strtwoParentHelp) + cmdEcho.Flags().IntVarP(&flagi1, "intone", "i", 123, "help message for flag intone") + cmdTimes.Flags().IntVarP(&flagi2, "inttwo", "j", 234, "help message for flag inttwo") + cmdPrint.Flags().IntVarP(&flagi3, "intthree", "i", 345, "help message for flag intthree") + cmdEcho.PersistentFlags().StringVarP(&flags1, "strone", "s", "one", "help message for flag strone") + cmdEcho.PersistentFlags().BoolVarP(&flagbp, "persistentbool", "p", false, "help message for flag persistentbool") + cmdTimes.PersistentFlags().StringVarP(&flags2b, "strtwo", "t", "2", strtwoChildHelp) + cmdPrint.PersistentFlags().StringVarP(&flags3, "strthree", "s", "three", "help message for flag strthree") + cmdEcho.Flags().BoolVarP(&flagb1, "boolone", "b", true, "help message for flag boolone") + cmdTimes.Flags().BoolVarP(&flagb2, "booltwo", "c", false, "help message for flag booltwo") + cmdPrint.Flags().BoolVarP(&flagb3, "boolthree", "b", true, "help message for flag boolthree") + cmdVersion1.ResetFlags() + cmdVersion2.ResetFlags() +} + +func initializeWithRootCmd() *cobra.Command { + cmdRootWithRun.ResetCommands() + flagInit() + cmdRootWithRun.Flags().BoolVarP(&flagbr, "boolroot", "b", false, "help message for flag boolroot") + cmdRootWithRun.Flags().IntVarP(&flagir, "introot", "i", 321, "help message for flag introot") + return cmdRootWithRun +} + +func checkStringContains(t *testing.T, found, expected string) { + if !strings.Contains(found, expected) { + logErr(t, found, expected) + } +} + +func checkStringOmits(t *testing.T, found, expected string) { + if strings.Contains(found, expected) { + logErr(t, found, expected) + } +} + +func logErr(t *testing.T, found, expected string) { + out := new(bytes.Buffer) + + _, _, line, ok := runtime.Caller(2) + if ok { + fmt.Fprintf(out, "Line: %d ", line) + } + fmt.Fprintf(out, "Unexpected response.\nExpecting to contain: \n %q\nGot:\n %q\n", expected, found) + t.Errorf(out.String()) +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs.go b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs.go new file mode 100644 index 00000000..129ed809 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs.go @@ -0,0 +1,217 @@ +// Copyright 2015 Red Hat Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doc + +import ( + "bytes" + "fmt" + "io" + "os" + "sort" + "strings" + "time" + + mangen "github.com/cpuguy83/go-md2man/md2man" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" +) + +// GenManTree will generate a man page for this command and all decendants +// in the directory given. The header may be nil. This function may not work +// correctly if your command names have - in them. If you have `cmd` with two +// subcmds, `sub` and `sub-third`. And `sub` has a subcommand called `third` +// it is undefined which help output will be in the file `cmd-sub-third.1`. +func GenManTree(cmd *cobra.Command, header *GenManHeader, dir string) error { + if header == nil { + header = &GenManHeader{} + } + for _, c := range cmd.Commands() { + if !c.IsAvailableCommand() || c.IsHelpCommand() { + continue + } + if err := GenManTree(c, header, dir); err != nil { + return err + } + } + needToResetTitle := header.Title == "" + + filename := cmd.CommandPath() + filename = dir + strings.Replace(filename, " ", "-", -1) + ".1" + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + if err := GenMan(cmd, header, f); err != nil { + return err + } + + if needToResetTitle { + header.Title = "" + } + return nil +} + +// GenManHeader is a lot like the .TH header at the start of man pages. These +// include the title, section, date, source, and manual. We will use the +// current time if Date if unset and will use "Auto generated by spf13/cobra" +// if the Source is unset. +type GenManHeader struct { + Title string + Section string + Date *time.Time + date string + Source string + Manual string +} + +// GenMan will generate a man page for the given command and write it to +// w. The header argument may be nil, however obviously w may not. +func GenMan(cmd *cobra.Command, header *GenManHeader, w io.Writer) error { + if header == nil { + header = &GenManHeader{} + } + b := genMan(cmd, header) + final := mangen.Render(b) + _, err := w.Write(final) + return err +} + +func fillHeader(header *GenManHeader, name string) { + if header.Title == "" { + header.Title = strings.ToUpper(strings.Replace(name, " ", "\\-", -1)) + } + if header.Section == "" { + header.Section = "1" + } + if header.Date == nil { + now := time.Now() + header.Date = &now + } + header.date = (*header.Date).Format("Jan 2006") + if header.Source == "" { + header.Source = "Auto generated by spf13/cobra" + } +} + +func manPreamble(out io.Writer, header *GenManHeader, name, short, long string) { + dashName := strings.Replace(name, " ", "-", -1) + fmt.Fprintf(out, `%% %s(%s)%s +%% %s +%% %s +# NAME +`, header.Title, header.Section, header.date, header.Source, header.Manual) + fmt.Fprintf(out, "%s \\- %s\n\n", dashName, short) + fmt.Fprintf(out, "# SYNOPSIS\n") + fmt.Fprintf(out, "**%s** [OPTIONS]\n\n", name) + fmt.Fprintf(out, "# DESCRIPTION\n") + fmt.Fprintf(out, "%s\n\n", long) +} + +func manPrintFlags(out io.Writer, flags *pflag.FlagSet) { + flags.VisitAll(func(flag *pflag.Flag) { + if len(flag.Deprecated) > 0 || flag.Hidden { + return + } + format := "" + if len(flag.Shorthand) > 0 { + format = "**-%s**, **--%s**" + } else { + format = "%s**--%s**" + } + if len(flag.NoOptDefVal) > 0 { + format = format + "[" + } + if flag.Value.Type() == "string" { + // put quotes on the value + format = format + "=%q" + } else { + format = format + "=%s" + } + if len(flag.NoOptDefVal) > 0 { + format = format + "]" + } + format = format + "\n\t%s\n\n" + fmt.Fprintf(out, format, flag.Shorthand, flag.Name, flag.DefValue, flag.Usage) + }) +} + +func manPrintOptions(out io.Writer, command *cobra.Command) { + flags := command.NonInheritedFlags() + if flags.HasFlags() { + fmt.Fprintf(out, "# OPTIONS\n") + manPrintFlags(out, flags) + fmt.Fprintf(out, "\n") + } + flags = command.InheritedFlags() + if flags.HasFlags() { + fmt.Fprintf(out, "# OPTIONS INHERITED FROM PARENT COMMANDS\n") + manPrintFlags(out, flags) + fmt.Fprintf(out, "\n") + } +} + +func genMan(cmd *cobra.Command, header *GenManHeader) []byte { + // something like `rootcmd subcmd1 subcmd2` + commandName := cmd.CommandPath() + // something like `rootcmd-subcmd1-subcmd2` + dashCommandName := strings.Replace(commandName, " ", "-", -1) + + fillHeader(header, commandName) + + buf := new(bytes.Buffer) + + short := cmd.Short + long := cmd.Long + if len(long) == 0 { + long = short + } + + manPreamble(buf, header, commandName, short, long) + manPrintOptions(buf, cmd) + if len(cmd.Example) > 0 { + fmt.Fprintf(buf, "# EXAMPLE\n") + fmt.Fprintf(buf, "```\n%s\n```\n", cmd.Example) + } + if hasSeeAlso(cmd) { + fmt.Fprintf(buf, "# SEE ALSO\n") + if cmd.HasParent() { + parentPath := cmd.Parent().CommandPath() + dashParentPath := strings.Replace(parentPath, " ", "-", -1) + fmt.Fprintf(buf, "**%s(%s)**", dashParentPath, header.Section) + cmd.VisitParents(func(c *cobra.Command) { + if c.DisableAutoGenTag { + cmd.DisableAutoGenTag = c.DisableAutoGenTag + } + }) + } + children := cmd.Commands() + sort.Sort(byName(children)) + for i, c := range children { + if !c.IsAvailableCommand() || c.IsHelpCommand() { + continue + } + if cmd.HasParent() || i > 0 { + fmt.Fprintf(buf, ", ") + } + fmt.Fprintf(buf, "**%s-%s(%s)**", dashCommandName, c.Name(), header.Section) + } + fmt.Fprintf(buf, "\n") + } + if !cmd.DisableAutoGenTag { + fmt.Fprintf(buf, "# HISTORY\n%s Auto generated by spf13/cobra\n", header.Date.Format("2-Jan-2006")) + } + return buf.Bytes() +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs.md b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs.md new file mode 100644 index 00000000..3408c301 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs.md @@ -0,0 +1,25 @@ +# Generating Man Pages For Your Own cobra.Command + +Generating man pages from a cobra command is incredibly easy. An example is as follows: + +```go +package main + +import ( + "github.com/spf13/cobra" +) + +func main() { + cmd := &cobra.Command{ + Use: "test", + Short: "my test program", + } + header := &cobra.GenManHeader{ + Title: "MINE", + Section: "3", + } + cmd.GenManTree(header, "/tmp") +} +``` + +That will get you a man page `/tmp/test.1` diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs_test.go new file mode 100644 index 00000000..3083125a --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_docs_test.go @@ -0,0 +1,97 @@ +package doc + +import ( + "bytes" + "fmt" + "os" + "strings" + "testing" +) + +var _ = fmt.Println +var _ = os.Stderr + +func translate(in string) string { + return strings.Replace(in, "-", "\\-", -1) +} + +func TestGenManDoc(t *testing.T) { + c := initializeWithRootCmd() + // Need two commands to run the command alphabetical sort + cmdEcho.AddCommand(cmdTimes, cmdEchoSub, cmdDeprecated) + c.AddCommand(cmdPrint, cmdEcho) + cmdRootWithRun.PersistentFlags().StringVarP(&flags2a, "rootflag", "r", "two", strtwoParentHelp) + + out := new(bytes.Buffer) + + header := &GenManHeader{ + Title: "Project", + Section: "2", + } + // We generate on a subcommand so we have both subcommands and parents + if err := GenMan(cmdEcho, header, out); err != nil { + t.Fatal(err) + } + found := out.String() + + // Make sure parent has - in CommandPath() in SEE ALSO: + parentPath := cmdEcho.Parent().CommandPath() + dashParentPath := strings.Replace(parentPath, " ", "-", -1) + expected := translate(dashParentPath) + expected = expected + "(" + header.Section + ")" + checkStringContains(t, found, expected) + + // Our description + expected = translate(cmdEcho.Name()) + checkStringContains(t, found, expected) + + // Better have our example + expected = translate(cmdEcho.Name()) + checkStringContains(t, found, expected) + + // A local flag + expected = "boolone" + checkStringContains(t, found, expected) + + // persistent flag on parent + expected = "rootflag" + checkStringContains(t, found, expected) + + // We better output info about our parent + expected = translate(cmdRootWithRun.Name()) + checkStringContains(t, found, expected) + + // And about subcommands + expected = translate(cmdEchoSub.Name()) + checkStringContains(t, found, expected) + + unexpected := translate(cmdDeprecated.Name()) + checkStringOmits(t, found, unexpected) + + // auto generated + expected = translate("Auto generated") + checkStringContains(t, found, expected) +} + +func TestGenManNoGenTag(t *testing.T) { + c := initializeWithRootCmd() + // Need two commands to run the command alphabetical sort + cmdEcho.AddCommand(cmdTimes, cmdEchoSub, cmdDeprecated) + c.AddCommand(cmdPrint, cmdEcho) + cmdRootWithRun.PersistentFlags().StringVarP(&flags2a, "rootflag", "r", "two", strtwoParentHelp) + cmdEcho.DisableAutoGenTag = true + out := new(bytes.Buffer) + + header := &GenManHeader{ + Title: "Project", + Section: "2", + } + // We generate on a subcommand so we have both subcommands and parents + if err := GenMan(cmdEcho, header, out); err != nil { + t.Fatal(err) + } + found := out.String() + + unexpected := translate("#HISTORY") + checkStringOmits(t, found, unexpected) +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_examples_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_examples_test.go new file mode 100644 index 00000000..6d040136 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/man_examples_test.go @@ -0,0 +1,35 @@ +package doc_test + +import ( + "bytes" + "fmt" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra/doc" +) + +func ExampleCommand_GenManTree() { + cmd := &cobra.Command{ + Use: "test", + Short: "my test program", + } + header := &doc.GenManHeader{ + Title: "MINE", + Section: "3", + } + doc.GenManTree(cmd, header, "/tmp") +} + +func ExampleCommand_GenMan() { + cmd := &cobra.Command{ + Use: "test", + Short: "my test program", + } + header := &doc.GenManHeader{ + Title: "MINE", + Section: "3", + } + out := new(bytes.Buffer) + doc.GenMan(cmd, header, out) + fmt.Print(out.String()) +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs.go b/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs.go new file mode 100644 index 00000000..3d6c9e7a --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs.go @@ -0,0 +1,174 @@ +//Copyright 2015 Red Hat Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doc + +import ( + "fmt" + "io" + "os" + "sort" + "strings" + "time" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" +) + +func printOptions(w io.Writer, cmd *cobra.Command, name string) error { + flags := cmd.NonInheritedFlags() + flags.SetOutput(w) + if flags.HasFlags() { + if _, err := fmt.Fprintf(w, "### Options\n\n```\n"); err != nil { + return err + } + flags.PrintDefaults() + if _, err := fmt.Fprintf(w, "```\n\n"); err != nil { + return err + } + } + + parentFlags := cmd.InheritedFlags() + parentFlags.SetOutput(w) + if parentFlags.HasFlags() { + if _, err := fmt.Fprintf(w, "### Options inherited from parent commands\n\n```\n"); err != nil { + return err + } + parentFlags.PrintDefaults() + if _, err := fmt.Fprintf(w, "```\n\n"); err != nil { + return err + } + } + return nil +} + +func GenMarkdown(cmd *cobra.Command, w io.Writer) error { + return GenMarkdownCustom(cmd, w, func(s string) string { return s }) +} + +func GenMarkdownCustom(cmd *cobra.Command, w io.Writer, linkHandler func(string) string) error { + name := cmd.CommandPath() + + short := cmd.Short + long := cmd.Long + if len(long) == 0 { + long = short + } + + if _, err := fmt.Fprintf(w, "## %s\n\n", name); err != nil { + return err + } + if _, err := fmt.Fprintf(w, "%s\n\n", short); err != nil { + return err + } + if _, err := fmt.Fprintf(w, "### Synopsis\n\n"); err != nil { + return err + } + if _, err := fmt.Fprintf(w, "\n%s\n\n", long); err != nil { + return err + } + + if cmd.Runnable() { + if _, err := fmt.Fprintf(w, "```\n%s\n```\n\n", cmd.UseLine()); err != nil { + return err + } + } + + if len(cmd.Example) > 0 { + if _, err := fmt.Fprintf(w, "### Examples\n\n"); err != nil { + return err + } + if _, err := fmt.Fprintf(w, "```\n%s\n```\n\n", cmd.Example); err != nil { + return err + } + } + + if err := printOptions(w, cmd, name); err != nil { + return err + } + if hasSeeAlso(cmd) { + if _, err := fmt.Fprintf(w, "### SEE ALSO\n"); err != nil { + return err + } + if cmd.HasParent() { + parent := cmd.Parent() + pname := parent.CommandPath() + link := pname + ".md" + link = strings.Replace(link, " ", "_", -1) + if _, err := fmt.Fprintf(w, "* [%s](%s)\t - %s\n", pname, linkHandler(link), parent.Short); err != nil { + return err + } + cmd.VisitParents(func(c *cobra.Command) { + if c.DisableAutoGenTag { + cmd.DisableAutoGenTag = c.DisableAutoGenTag + } + }) + } + + children := cmd.Commands() + sort.Sort(byName(children)) + + for _, child := range children { + if !child.IsAvailableCommand() || child.IsHelpCommand() { + continue + } + cname := name + " " + child.Name() + link := cname + ".md" + link = strings.Replace(link, " ", "_", -1) + if _, err := fmt.Fprintf(w, "* [%s](%s)\t - %s\n", cname, linkHandler(link), child.Short); err != nil { + return err + } + } + if _, err := fmt.Fprintf(w, "\n"); err != nil { + return err + } + } + if !cmd.DisableAutoGenTag { + if _, err := fmt.Fprintf(w, "###### Auto generated by spf13/cobra on %s\n", time.Now().Format("2-Jan-2006")); err != nil { + return err + } + } + return nil +} + +func GenMarkdownTree(cmd *cobra.Command, dir string) error { + identity := func(s string) string { return s } + emptyStr := func(s string) string { return "" } + return GenMarkdownTreeCustom(cmd, dir, emptyStr, identity) +} + +func GenMarkdownTreeCustom(cmd *cobra.Command, dir string, filePrepender, linkHandler func(string) string) error { + for _, c := range cmd.Commands() { + if !c.IsAvailableCommand() || c.IsHelpCommand() { + continue + } + if err := GenMarkdownTreeCustom(c, dir, filePrepender, linkHandler); err != nil { + return err + } + } + + filename := cmd.CommandPath() + filename = dir + strings.Replace(filename, " ", "_", -1) + ".md" + f, err := os.Create(filename) + if err != nil { + return err + } + defer f.Close() + + if _, err := io.WriteString(f, filePrepender(filename)); err != nil { + return err + } + if err := GenMarkdownCustom(cmd, f, linkHandler); err != nil { + return err + } + return nil +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs.md b/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs.md new file mode 100644 index 00000000..da35f92e --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs.md @@ -0,0 +1,81 @@ +# Generating Markdown Docs For Your Own cobra.Command + +## Generate markdown docs for the entire command tree + +This program can actually generate docs for the kubectl command in the kubernetes project + +```go +package main + +import ( + "io/ioutil" + "os" + + "github.com/GoogleCloudPlatform/kubernetes/pkg/kubectl/cmd" + "github.com/spf13/cobra/cobra" +) + +func main() { + kubectl := cmd.NewFactory(nil).NewKubectlCommand(os.Stdin, ioutil.Discard, ioutil.Discard) + doc.GenMarkdownTree(kubectl, "./") +} +``` + +This will generate a whole series of files, one for each command in the tree, in the directory specified (in this case "./") + +## Generate markdown docs for a single command + +You may wish to have more control over the output, or only generate for a single command, instead of the entire command tree. If this is the case you may prefer to `GenMarkdown` instead of `GenMarkdownTree` + +```go + out := new(bytes.Buffer) + doc.GenMarkdown(cmd, out) +``` + +This will write the markdown doc for ONLY "cmd" into the out, buffer. + +## Customize the output + +Both `GenMarkdown` and `GenMarkdownTree` have alternate versions with callbacks to get some control of the output: + +```go +func GenMarkdownTreeCustom(cmd *Command, dir string, filePrepender, linkHandler func(string) string) error { + //... +} +``` + +```go +func GenMarkdownCustom(cmd *Command, out *bytes.Buffer, linkHandler func(string) string) error { + //... +} +``` + +The `filePrepender` will prepend the return value given the full filepath to the rendered Markdown file. A common use case is to add front matter to use the generated documentation with [Hugo](http://gohugo.io/): + +```go +const fmTemplate = `--- +date: %s +title: "%s" +slug: %s +url: %s +--- +` + +filePrepender := func(filename string) string { + now := time.Now().Format(time.RFC3339) + name := filepath.Base(filename) + base := strings.TrimSuffix(name, path.Ext(name)) + url := "/commands/" + strings.ToLower(base) + "/" + return fmt.Sprintf(fmTemplate, now, strings.Replace(base, "_", " ", -1), base, url) +} +``` + +The `linkHandler` can be used to customize the rendered internal links to the commands, given a filename: + +```go +linkHandler := func(name string) string { + base := strings.TrimSuffix(name, path.Ext(name)) + return "/commands/" + strings.ToLower(base) + "/" +} +``` + diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/md_docs_test.go b/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs_test.go similarity index 75% rename from Godeps/_workspace/src/github.com/spf13/cobra/md_docs_test.go rename to Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs_test.go index defc9411..86ee0293 100644 --- a/Godeps/_workspace/src/github.com/spf13/cobra/md_docs_test.go +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/md_docs_test.go @@ -1,4 +1,4 @@ -package cobra +package doc import ( "bytes" @@ -21,7 +21,9 @@ func TestGenMdDoc(t *testing.T) { out := new(bytes.Buffer) // We generate on s subcommand so we have both subcommands and parents - GenMarkdown(cmdEcho, out) + if err := GenMarkdown(cmdEcho, out); err != nil { + t.Fatal(err) + } found := out.String() // Our description @@ -65,3 +67,22 @@ func TestGenMdDoc(t *testing.T) { t.Errorf("Unexpected response.\nFound: %v\nBut should not have!!\n", unexpected) } } + +func TestGenMdNoTag(t *testing.T) { + c := initializeWithRootCmd() + // Need two commands to run the command alphabetical sort + cmdEcho.AddCommand(cmdTimes, cmdEchoSub, cmdDeprecated) + c.AddCommand(cmdPrint, cmdEcho) + c.DisableAutoGenTag = true + cmdRootWithRun.PersistentFlags().StringVarP(&flags2a, "rootflag", "r", "two", strtwoParentHelp) + out := new(bytes.Buffer) + + if err := GenMarkdown(c, out); err != nil { + t.Fatal(err) + } + found := out.String() + + unexpected := "Auto generated" + checkStringOmits(t, found, unexpected) + +} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/doc/util.go b/Godeps/_workspace/src/github.com/spf13/cobra/doc/util.go new file mode 100644 index 00000000..54bfe676 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/cobra/doc/util.go @@ -0,0 +1,38 @@ +// Copyright 2015 Red Hat Inc. All rights reserved. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package doc + +import "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cobra" + +// Test to see if we have a reason to print See Also information in docs +// Basically this is a test for a parent commend or a subcommand which is +// both not deprecated and not the autogenerated help command. +func hasSeeAlso(cmd *cobra.Command) bool { + if cmd.HasParent() { + return true + } + for _, c := range cmd.Commands() { + if !c.IsAvailableCommand() || c.IsHelpCommand() { + continue + } + return true + } + return false +} + +type byName []*cobra.Command + +func (s byName) Len() int { return len(s) } +func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] } +func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() } diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/md_docs.go b/Godeps/_workspace/src/github.com/spf13/cobra/md_docs.go deleted file mode 100644 index 4a57ebd0..00000000 --- a/Godeps/_workspace/src/github.com/spf13/cobra/md_docs.go +++ /dev/null @@ -1,124 +0,0 @@ -//Copyright 2015 Red Hat Inc. All rights reserved. -// -// Licensed under the Apache License, Version 2.0 (the "License"); -// you may not use this file except in compliance with the License. -// You may obtain a copy of the License at -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, software -// distributed under the License is distributed on an "AS IS" BASIS, -// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -// See the License for the specific language governing permissions and -// limitations under the License. - -package cobra - -import ( - "bytes" - "fmt" - "os" - "sort" - "strings" - "time" -) - -func printOptions(out *bytes.Buffer, cmd *Command, name string) { - flags := cmd.NonInheritedFlags() - flags.SetOutput(out) - if flags.HasFlags() { - fmt.Fprintf(out, "### Options\n\n```\n") - flags.PrintDefaults() - fmt.Fprintf(out, "```\n\n") - } - - parentFlags := cmd.InheritedFlags() - parentFlags.SetOutput(out) - if parentFlags.HasFlags() { - fmt.Fprintf(out, "### Options inherited from parent commands\n\n```\n") - parentFlags.PrintDefaults() - fmt.Fprintf(out, "```\n\n") - } -} - -type byName []*Command - -func (s byName) Len() int { return len(s) } -func (s byName) Swap(i, j int) { s[i], s[j] = s[j], s[i] } -func (s byName) Less(i, j int) bool { return s[i].Name() < s[j].Name() } - -func GenMarkdown(cmd *Command, out *bytes.Buffer) { - name := cmd.CommandPath() - - short := cmd.Short - long := cmd.Long - if len(long) == 0 { - long = short - } - - fmt.Fprintf(out, "## %s\n\n", name) - fmt.Fprintf(out, "%s\n\n", short) - fmt.Fprintf(out, "### Synopsis\n\n") - fmt.Fprintf(out, "\n%s\n\n", long) - - if cmd.Runnable() { - fmt.Fprintf(out, "```\n%s\n```\n\n", cmd.UseLine()) - } - - if len(cmd.Example) > 0 { - fmt.Fprintf(out, "### Examples\n\n") - fmt.Fprintf(out, "```\n%s\n```\n\n", cmd.Example) - } - - printOptions(out, cmd, name) - - if len(cmd.Commands()) > 0 || cmd.HasParent() { - fmt.Fprintf(out, "### SEE ALSO\n") - if cmd.HasParent() { - parent := cmd.Parent() - pname := parent.CommandPath() - link := pname + ".md" - link = strings.Replace(link, " ", "_", -1) - fmt.Fprintf(out, "* [%s](%s)\t - %s\n", pname, link, parent.Short) - } - - children := cmd.Commands() - sort.Sort(byName(children)) - - for _, child := range children { - if len(child.Deprecated) > 0 { - continue - } - cname := name + " " + child.Name() - link := cname + ".md" - link = strings.Replace(link, " ", "_", -1) - fmt.Fprintf(out, "* [%s](%s)\t - %s\n", cname, link, child.Short) - } - fmt.Fprintf(out, "\n") - } - - fmt.Fprintf(out, "###### Auto generated by spf13/cobra at %s\n", time.Now().UTC()) -} - -func GenMarkdownTree(cmd *Command, dir string) { - for _, c := range cmd.Commands() { - GenMarkdownTree(c, dir) - } - - out := new(bytes.Buffer) - - GenMarkdown(cmd, out) - - filename := cmd.CommandPath() - filename = dir + strings.Replace(filename, " ", "_", -1) + ".md" - outFile, err := os.Create(filename) - if err != nil { - fmt.Println(err) - os.Exit(1) - } - defer outFile.Close() - _, err = outFile.Write(out.Bytes()) - if err != nil { - fmt.Println(err) - os.Exit(1) - } -} diff --git a/Godeps/_workspace/src/github.com/spf13/cobra/md_docs.md b/Godeps/_workspace/src/github.com/spf13/cobra/md_docs.md deleted file mode 100644 index 43b6c994..00000000 --- a/Godeps/_workspace/src/github.com/spf13/cobra/md_docs.md +++ /dev/null @@ -1,35 +0,0 @@ -# Generating Markdown Docs For Your Own cobra.Command - -## Generate markdown docs for the entire command tree - -This program can actually generate docs for the kubectl command in the kubernetes project - -```go -package main - -import ( - "io/ioutil" - "os" - - "github.com/GoogleCloudPlatform/kubernetes/pkg/kubectl/cmd" - "github.com/spf13/cobra" -) - -func main() { - kubectl := cmd.NewFactory(nil).NewKubectlCommand(os.Stdin, ioutil.Discard, ioutil.Discard) - cobra.GenMarkdownTree(kubectl, "./") -} -``` - -This will generate a whole series of files, one for each command in the tree, in the directory specified (in this case "./") - -## Generate markdown docs for a single command - -You may wish to have more control over the output, or only generate for a single command, instead of the entire command tree. If this is the case you may prefer to `GenMarkdown()` instead of `GenMarkdownTree` - -```go - out := new(bytes.Buffer) - cobra.GenMarkdown(cmd, out) -``` - -This will write the markdown doc for ONLY "cmd" into the out, buffer. diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/.travis.yml b/Godeps/_workspace/src/github.com/spf13/pflag/.travis.yml new file mode 100644 index 00000000..c7d8e05d --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/.travis.yml @@ -0,0 +1,18 @@ +sudo: false + +language: go + +go: + - 1.3 + - 1.4 + - 1.5 + - tip + +install: + - go get github.com/golang/lint/golint + - export PATH=$GOPATH/bin:$PATH + - go install ./... + +script: + - verify/all.sh -v + - go test ./... diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/README.md b/Godeps/_workspace/src/github.com/spf13/pflag/README.md index 4eef10b7..e74dd50b 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/README.md +++ b/Godeps/_workspace/src/github.com/spf13/pflag/README.md @@ -1,3 +1,5 @@ +[![Build Status](https://travis-ci.org/spf13/pflag.svg?branch=master)](https://travis-ci.org/spf13/pflag) + ## Description pflag is a drop-in replacement for Go's flag package, implementing @@ -18,11 +20,11 @@ pflag is available using the standard `go get` command. Install by running: - go get github.com/ogier/pflag + go get github.com/spf13/pflag Run tests by running: - go test github.com/ogier/pflag + go test github.com/spf13/pflag ## Usage @@ -31,7 +33,7 @@ pflag under the name "flag" then all code should continue to function with no changes. ``` go -import flag "github.com/ogier/pflag" +import flag "github.com/spf13/pflag" ``` There is one exception to this: if you directly instantiate the Flag struct @@ -82,6 +84,16 @@ fmt.Println("ip has value ", *ip) fmt.Println("flagvar has value ", flagvar) ``` +There are helpers function to get values later if you have the FlagSet but +it was difficult to keep up with all of the the flag pointers in your code. +If you have a pflag.FlagSet with a flag called 'flagname' of type int you +can use GetInt() to get the int value. But notice that 'flagname' must exist +and it must be an int. GetString("flagname") will fail. + +``` go +i, err := flagset.GetInt("flagname") +``` + After parsing, the arguments after the flag are available as the slice flag.Args() or individually as flag.Arg(i). The arguments are indexed from 0 through flag.NArg()-1. @@ -109,29 +121,56 @@ in a command-line interface. The methods of FlagSet are analogous to the top-level functions for the command-line flag set. +## Setting no option default values for flags + +After you create a flag it is possible to set the pflag.NoOptDefVal for +the given flag. Doing this changes the meaning of the flag slightly. If +a flag has a NoOptDefVal and the flag is set on the command line without +an option the flag will be set to the NoOptDefVal. For example given: + +``` go +var ip = flag.IntP("flagname", "f", 1234, "help message") +flag.Lookup("flagname").NoOptDefVal = "4321" +``` + +Would result in something like + +| Parsed Arguments | Resulting Value | +| ------------- | ------------- | +| --flagname=1357 | ip=1357 | +| --flagname | ip=4321 | +| [nothing] | ip=1234 | + ## Command line flag syntax ``` ---flag // boolean flags only +--flag // boolean flags, or flags with no option default values +--flag x // only on flags without a default value --flag=x ``` Unlike the flag package, a single dash before an option means something different than a double dash. Single dashes signify a series of shorthand -letters for flags. All but the last shorthand letter must be boolean flags. +letters for flags. All but the last shorthand letter must be boolean flags +or a flag with a default value ``` -// boolean flags +// boolean or flags where the 'no option default value' is set -f +-f=true -abc +but +-b true is INVALID -// non-boolean flags +// non-boolean and flags without a 'no option default value' -n 1234 --Ifile +-n=1234 +-n1234 // mixed -abcs "hello" --abcn1234 +-absd="hello" +-abcs1234 ``` Flag parsing stops after the terminator "--". Unlike the flag package, @@ -149,7 +188,7 @@ It is possible to set a custom flag name 'normalization function.' It allows fla **Example #1**: You want -, _, and . in flags to compare the same. aka --my-flag == --my_flag == --my.flag -```go +``` go func wordSepNormalizeFunc(f *pflag.FlagSet, name string) pflag.NormalizedName { from := []string{"-", "_"} to := "." @@ -164,7 +203,7 @@ myFlagSet.SetNormalizeFunc(wordSepNormalizeFunc) **Example #2**: You want to alias two flags. aka --old-flag-name == --new-flag-name -```go +``` go func aliasNormalizeFunc(f *pflag.FlagSet, name string) pflag.NormalizedName { switch name { case "old-flag-name": @@ -177,6 +216,34 @@ func aliasNormalizeFunc(f *pflag.FlagSet, name string) pflag.NormalizedName { myFlagSet.SetNormalizeFunc(aliasNormalizeFunc) ``` +## Deprecating a flag or its shorthand +It is possible to deprecate a flag, or just its shorthand. Deprecating a flag/shorthand hides it from help text and prints a usage message when the deprecated flag/shorthand is used. + +**Example #1**: You want to deprecate a flag named "badflag" as well as inform the users what flag they should use instead. +```go +// deprecate a flag by specifying its name and a usage message +flags.MarkDeprecated("badflag", "please use --good-flag instead") +``` +This hides "badflag" from help text, and prints `Flag --badflag has been deprecated, please use --good-flag instead` when "badflag" is used. + +**Example #2**: You want to keep a flag name "noshorthandflag" but deprecate its shortname "n". +```go +// deprecate a flag shorthand by specifying its flag name and a usage message +flags.MarkShorthandDeprecated("noshorthandflag", "please use --noshorthandflag only") +``` +This hides the shortname "n" from help text, and prints `Flag shorthand -n has been deprecated, please use --noshorthandflag only` when the shorthand "n" is used. + +Note that usage message is essential here, and it should not be empty. + +## Hidden flags +It is possible to mark a flag as hidden, meaning it will still function as normal, however will not show up in usage/help text. + +**Example**: You have a flag named "secretFlag" that you need for internal use only and don't want it showing up in help text, or for its usage text to be available. +```go +// hide a flag by specifying its name +flags.MarkHidden("secretFlag") +``` + ## More info You can see the full reference documentation of the pflag package diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/bool.go b/Godeps/_workspace/src/github.com/spf13/pflag/bool.go index 70e2e0a6..d272e40b 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/bool.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/bool.go @@ -34,37 +34,50 @@ func (b *boolValue) String() string { return fmt.Sprintf("%v", *b) } func (b *boolValue) IsBoolFlag() bool { return true } +func boolConv(sval string) (interface{}, error) { + return strconv.ParseBool(sval) +} + +// GetBool return the bool value of a flag with the given name +func (f *FlagSet) GetBool(name string) (bool, error) { + val, err := f.getFlagType(name, "bool", boolConv) + if err != nil { + return false, err + } + return val.(bool), nil +} + // BoolVar defines a bool flag with specified name, default value, and usage string. // The argument p points to a bool variable in which to store the value of the flag. func (f *FlagSet) BoolVar(p *bool, name string, value bool, usage string) { - f.VarP(newBoolValue(value, p), name, "", usage) + f.BoolVarP(p, name, "", value, usage) } -// Like BoolVar, but accepts a shorthand letter that can be used after a single dash. +// BoolVarP is like BoolVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) BoolVarP(p *bool, name, shorthand string, value bool, usage string) { - f.VarP(newBoolValue(value, p), name, shorthand, usage) + flag := f.VarPF(newBoolValue(value, p), name, shorthand, usage) + flag.NoOptDefVal = "true" } // BoolVar defines a bool flag with specified name, default value, and usage string. // The argument p points to a bool variable in which to store the value of the flag. func BoolVar(p *bool, name string, value bool, usage string) { - CommandLine.VarP(newBoolValue(value, p), name, "", usage) + BoolVarP(p, name, "", value, usage) } -// Like BoolVar, but accepts a shorthand letter that can be used after a single dash. +// BoolVarP is like BoolVar, but accepts a shorthand letter that can be used after a single dash. func BoolVarP(p *bool, name, shorthand string, value bool, usage string) { - CommandLine.VarP(newBoolValue(value, p), name, shorthand, usage) + flag := CommandLine.VarPF(newBoolValue(value, p), name, shorthand, usage) + flag.NoOptDefVal = "true" } // Bool defines a bool flag with specified name, default value, and usage string. // The return value is the address of a bool variable that stores the value of the flag. func (f *FlagSet) Bool(name string, value bool, usage string) *bool { - p := new(bool) - f.BoolVarP(p, name, "", value, usage) - return p + return f.BoolP(name, "", value, usage) } -// Like Bool, but accepts a shorthand letter that can be used after a single dash. +// BoolP is like Bool, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool { p := new(bool) f.BoolVarP(p, name, shorthand, value, usage) @@ -74,10 +87,11 @@ func (f *FlagSet) BoolP(name, shorthand string, value bool, usage string) *bool // Bool defines a bool flag with specified name, default value, and usage string. // The return value is the address of a bool variable that stores the value of the flag. func Bool(name string, value bool, usage string) *bool { - return CommandLine.BoolP(name, "", value, usage) + return BoolP(name, "", value, usage) } -// Like Bool, but accepts a shorthand letter that can be used after a single dash. +// BoolP is like Bool, but accepts a shorthand letter that can be used after a single dash. func BoolP(name, shorthand string, value bool, usage string) *bool { - return CommandLine.BoolP(name, shorthand, value, usage) + b := CommandLine.BoolP(name, shorthand, value, usage) + return b } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/bool_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/bool_test.go index 200a19a7..afd25ae2 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/bool_test.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/bool_test.go @@ -2,14 +2,13 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package pflag_test +package pflag import ( + "bytes" "fmt" "strconv" "testing" - - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" ) // This value can be a boolean ("true", "false") or "maybe" @@ -52,7 +51,7 @@ func (v *triStateValue) String() string { return fmt.Sprintf("%v", bool(*v == triStateTrue)) } -// The type of the flag as requred by the pflag.Value interface +// The type of the flag as required by the pflag.Value interface func (v *triStateValue) Type() string { return "version" } @@ -60,7 +59,8 @@ func (v *triStateValue) Type() string { func setUpFlagSet(tristate *triStateValue) *FlagSet { f := NewFlagSet("test", ContinueOnError) *tristate = triStateFalse - f.VarP(tristate, "tristate", "t", "tristate value (true, maybe or false)") + flag := f.VarPF(tristate, "tristate", "t", "tristate value (true, maybe or false)") + flag.NoOptDefVal = "true" return f } @@ -156,9 +156,25 @@ func TestImplicitFalse(t *testing.T) { func TestInvalidValue(t *testing.T) { var tristate triStateValue f := setUpFlagSet(&tristate) - args := []string{"--tristate=invalid"} - _, err := parseReturnStderr(t, f, args) + var buf bytes.Buffer + f.SetOutput(&buf) + err := f.Parse([]string{"--tristate=invalid"}) if err == nil { t.Fatal("expected an error but did not get any, tristate has value", tristate) } } + +func TestBoolP(t *testing.T) { + b := BoolP("bool", "b", false, "bool value in CommandLine") + c := BoolP("c", "c", false, "other bool value") + args := []string{"--bool"} + if err := CommandLine.Parse(args); err != nil { + t.Error("expected no error, got ", err) + } + if *b != true { + t.Errorf("expected b=true got b=%s", b) + } + if *c != false { + t.Errorf("expect c=false got c=%s", c) + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/count.go b/Godeps/_workspace/src/github.com/spf13/pflag/count.go new file mode 100644 index 00000000..7b1f142e --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/count.go @@ -0,0 +1,97 @@ +package pflag + +import ( + "fmt" + "strconv" +) + +// -- count Value +type countValue int + +func newCountValue(val int, p *int) *countValue { + *p = val + return (*countValue)(p) +} + +func (i *countValue) Set(s string) error { + v, err := strconv.ParseInt(s, 0, 64) + // -1 means that no specific value was passed, so increment + if v == -1 { + *i = countValue(*i + 1) + } else { + *i = countValue(v) + } + return err +} + +func (i *countValue) Type() string { + return "count" +} + +func (i *countValue) String() string { return fmt.Sprintf("%v", *i) } + +func countConv(sval string) (interface{}, error) { + i, err := strconv.Atoi(sval) + if err != nil { + return nil, err + } + return i, nil +} + +// GetCount return the int value of a flag with the given name +func (f *FlagSet) GetCount(name string) (int, error) { + val, err := f.getFlagType(name, "count", countConv) + if err != nil { + return 0, err + } + return val.(int), nil +} + +// CountVar defines a count flag with specified name, default value, and usage string. +// The argument p points to an int variable in which to store the value of the flag. +// A count flag will add 1 to its value evey time it is found on the command line +func (f *FlagSet) CountVar(p *int, name string, usage string) { + f.CountVarP(p, name, "", usage) +} + +// CountVarP is like CountVar only take a shorthand for the flag name. +func (f *FlagSet) CountVarP(p *int, name, shorthand string, usage string) { + flag := f.VarPF(newCountValue(0, p), name, shorthand, usage) + flag.NoOptDefVal = "-1" +} + +// CountVar like CountVar only the flag is placed on the CommandLine instead of a given flag set +func CountVar(p *int, name string, usage string) { + CommandLine.CountVar(p, name, usage) +} + +// CountVarP is like CountVar only take a shorthand for the flag name. +func CountVarP(p *int, name, shorthand string, usage string) { + CommandLine.CountVarP(p, name, shorthand, usage) +} + +// Count defines a count flag with specified name, default value, and usage string. +// The return value is the address of an int variable that stores the value of the flag. +// A count flag will add 1 to its value evey time it is found on the command line +func (f *FlagSet) Count(name string, usage string) *int { + p := new(int) + f.CountVarP(p, name, "", usage) + return p +} + +// CountP is like Count only takes a shorthand for the flag name. +func (f *FlagSet) CountP(name, shorthand string, usage string) *int { + p := new(int) + f.CountVarP(p, name, shorthand, usage) + return p +} + +// Count like Count only the flag is placed on the CommandLine isntead of a given flag set +func Count(name string, usage string) *int { + return CommandLine.CountP(name, "", usage) +} + +// CountP is like Count only takes a shorthand for the flag name. +func CountP(name, shorthand string, usage string) *int { + return CommandLine.CountP(name, shorthand, usage) +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/count_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/count_test.go new file mode 100644 index 00000000..716765cb --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/count_test.go @@ -0,0 +1,55 @@ +package pflag + +import ( + "fmt" + "os" + "testing" +) + +var _ = fmt.Printf + +func setUpCount(c *int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.CountVarP(c, "verbose", "v", "a counter") + return f +} + +func TestCount(t *testing.T) { + testCases := []struct { + input []string + success bool + expected int + }{ + {[]string{"-vvv"}, true, 3}, + {[]string{"-v", "-v", "-v"}, true, 3}, + {[]string{"-v", "--verbose", "-v"}, true, 3}, + {[]string{"-v=3", "-v"}, true, 4}, + {[]string{"-v=a"}, false, 0}, + } + + devnull, _ := os.Open(os.DevNull) + os.Stderr = devnull + for i := range testCases { + var count int + f := setUpCount(&count) + + tc := &testCases[i] + + err := f.Parse(tc.input) + if err != nil && tc.success == true { + t.Errorf("expected success, got %q", err) + continue + } else if err == nil && tc.success == false { + t.Errorf("expected failure, got success") + continue + } else if tc.success { + c, err := f.GetCount("verbose") + if err != nil { + t.Errorf("Got error trying to fetch the counter flag") + } + if c != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, c) + } + } + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/duration.go b/Godeps/_workspace/src/github.com/spf13/pflag/duration.go index 66ed7ac9..e9debef8 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/duration.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/duration.go @@ -1,6 +1,8 @@ package pflag -import "time" +import ( + "time" +) // -- time.Duration Value type durationValue time.Duration @@ -22,13 +24,26 @@ func (d *durationValue) Type() string { func (d *durationValue) String() string { return (*time.Duration)(d).String() } +func durationConv(sval string) (interface{}, error) { + return time.ParseDuration(sval) +} + +// GetDuration return the duration value of a flag with the given name +func (f *FlagSet) GetDuration(name string) (time.Duration, error) { + val, err := f.getFlagType(name, "duration", durationConv) + if err != nil { + return 0, err + } + return val.(time.Duration), nil +} + // DurationVar defines a time.Duration flag with specified name, default value, and usage string. // The argument p points to a time.Duration variable in which to store the value of the flag. func (f *FlagSet) DurationVar(p *time.Duration, name string, value time.Duration, usage string) { f.VarP(newDurationValue(value, p), name, "", usage) } -// Like DurationVar, but accepts a shorthand letter that can be used after a single dash. +// DurationVarP is like DurationVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) DurationVarP(p *time.Duration, name, shorthand string, value time.Duration, usage string) { f.VarP(newDurationValue(value, p), name, shorthand, usage) } @@ -39,7 +54,7 @@ func DurationVar(p *time.Duration, name string, value time.Duration, usage strin CommandLine.VarP(newDurationValue(value, p), name, "", usage) } -// Like DurationVar, but accepts a shorthand letter that can be used after a single dash. +// DurationVarP is like DurationVar, but accepts a shorthand letter that can be used after a single dash. func DurationVarP(p *time.Duration, name, shorthand string, value time.Duration, usage string) { CommandLine.VarP(newDurationValue(value, p), name, shorthand, usage) } @@ -52,7 +67,7 @@ func (f *FlagSet) Duration(name string, value time.Duration, usage string) *time return p } -// Like Duration, but accepts a shorthand letter that can be used after a single dash. +// DurationP is like Duration, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) DurationP(name, shorthand string, value time.Duration, usage string) *time.Duration { p := new(time.Duration) f.DurationVarP(p, name, shorthand, value, usage) @@ -65,7 +80,7 @@ func Duration(name string, value time.Duration, usage string) *time.Duration { return CommandLine.DurationP(name, "", value, usage) } -// Like Duration, but accepts a shorthand letter that can be used after a single dash. +// DurationP is like Duration, but accepts a shorthand letter that can be used after a single dash. func DurationP(name, shorthand string, value time.Duration, usage string) *time.Duration { return CommandLine.DurationP(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/flag.go b/Godeps/_workspace/src/github.com/spf13/pflag/flag.go index 55594df4..deac3af1 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/flag.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/flag.go @@ -3,98 +3,98 @@ // license that can be found in the LICENSE file. /* - pflag is a drop-in replacement for Go's flag package, implementing - POSIX/GNU-style --flags. +Package pflag is a drop-in replacement for Go's flag package, implementing +POSIX/GNU-style --flags. - pflag is compatible with the GNU extensions to the POSIX recommendations - for command-line options. See - http://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html +pflag is compatible with the GNU extensions to the POSIX recommendations +for command-line options. See +http://www.gnu.org/software/libc/manual/html_node/Argument-Syntax.html - Usage: +Usage: - pflag is a drop-in replacement of Go's native flag package. If you import - pflag under the name "flag" then all code should continue to function - with no changes. +pflag is a drop-in replacement of Go's native flag package. If you import +pflag under the name "flag" then all code should continue to function +with no changes. - import flag "github.com/ogier/pflag" + import flag "github.com/ogier/pflag" There is one exception to this: if you directly instantiate the Flag struct - there is one more field "Shorthand" that you will need to set. - Most code never instantiates this struct directly, and instead uses - functions such as String(), BoolVar(), and Var(), and is therefore - unaffected. +there is one more field "Shorthand" that you will need to set. +Most code never instantiates this struct directly, and instead uses +functions such as String(), BoolVar(), and Var(), and is therefore +unaffected. - Define flags using flag.String(), Bool(), Int(), etc. +Define flags using flag.String(), Bool(), Int(), etc. - This declares an integer flag, -flagname, stored in the pointer ip, with type *int. - var ip = flag.Int("flagname", 1234, "help message for flagname") - If you like, you can bind the flag to a variable using the Var() functions. - var flagvar int - func init() { - flag.IntVar(&flagvar, "flagname", 1234, "help message for flagname") - } - Or you can create custom flags that satisfy the Value interface (with - pointer receivers) and couple them to flag parsing by - flag.Var(&flagVal, "name", "help message for flagname") - For such flags, the default value is just the initial value of the variable. +This declares an integer flag, -flagname, stored in the pointer ip, with type *int. + var ip = flag.Int("flagname", 1234, "help message for flagname") +If you like, you can bind the flag to a variable using the Var() functions. + var flagvar int + func init() { + flag.IntVar(&flagvar, "flagname", 1234, "help message for flagname") + } +Or you can create custom flags that satisfy the Value interface (with +pointer receivers) and couple them to flag parsing by + flag.Var(&flagVal, "name", "help message for flagname") +For such flags, the default value is just the initial value of the variable. - After all flags are defined, call - flag.Parse() - to parse the command line into the defined flags. +After all flags are defined, call + flag.Parse() +to parse the command line into the defined flags. - Flags may then be used directly. If you're using the flags themselves, - they are all pointers; if you bind to variables, they're values. - fmt.Println("ip has value ", *ip) - fmt.Println("flagvar has value ", flagvar) +Flags may then be used directly. If you're using the flags themselves, +they are all pointers; if you bind to variables, they're values. + fmt.Println("ip has value ", *ip) + fmt.Println("flagvar has value ", flagvar) - After parsing, the arguments after the flag are available as the - slice flag.Args() or individually as flag.Arg(i). - The arguments are indexed from 0 through flag.NArg()-1. +After parsing, the arguments after the flag are available as the +slice flag.Args() or individually as flag.Arg(i). +The arguments are indexed from 0 through flag.NArg()-1. - The pflag package also defines some new functions that are not in flag, - that give one-letter shorthands for flags. You can use these by appending - 'P' to the name of any function that defines a flag. - var ip = flag.IntP("flagname", "f", 1234, "help message") - var flagvar bool - func init() { - flag.BoolVarP("boolname", "b", true, "help message") - } - flag.VarP(&flagVar, "varname", "v", 1234, "help message") - Shorthand letters can be used with single dashes on the command line. - Boolean shorthand flags can be combined with other shorthand flags. +The pflag package also defines some new functions that are not in flag, +that give one-letter shorthands for flags. You can use these by appending +'P' to the name of any function that defines a flag. + var ip = flag.IntP("flagname", "f", 1234, "help message") + var flagvar bool + func init() { + flag.BoolVarP("boolname", "b", true, "help message") + } + flag.VarP(&flagVar, "varname", "v", 1234, "help message") +Shorthand letters can be used with single dashes on the command line. +Boolean shorthand flags can be combined with other shorthand flags. - Command line flag syntax: - --flag // boolean flags only - --flag=x +Command line flag syntax: + --flag // boolean flags only + --flag=x - Unlike the flag package, a single dash before an option means something - different than a double dash. Single dashes signify a series of shorthand - letters for flags. All but the last shorthand letter must be boolean flags. - // boolean flags - -f - -abc - // non-boolean flags - -n 1234 - -Ifile - // mixed - -abcs "hello" - -abcn1234 +Unlike the flag package, a single dash before an option means something +different than a double dash. Single dashes signify a series of shorthand +letters for flags. All but the last shorthand letter must be boolean flags. + // boolean flags + -f + -abc + // non-boolean flags + -n 1234 + -Ifile + // mixed + -abcs "hello" + -abcn1234 - Flag parsing stops after the terminator "--". Unlike the flag package, - flags can be interspersed with arguments anywhere on the command line - before this terminator. +Flag parsing stops after the terminator "--". Unlike the flag package, +flags can be interspersed with arguments anywhere on the command line +before this terminator. - Integer flags accept 1234, 0664, 0x1234 and may be negative. - Boolean flags (in their long form) accept 1, 0, t, f, true, false, - TRUE, FALSE, True, False. - Duration flags accept any input valid for time.ParseDuration. +Integer flags accept 1234, 0664, 0x1234 and may be negative. +Boolean flags (in their long form) accept 1, 0, t, f, true, false, +TRUE, FALSE, True, False. +Duration flags accept any input valid for time.ParseDuration. - The default set of command-line flags is controlled by - top-level functions. The FlagSet type allows one to define - independent sets of flags, such as to implement subcommands - in a command-line interface. The methods of FlagSet are - analogous to the top-level functions for the command-line - flag set. +The default set of command-line flags is controlled by +top-level functions. The FlagSet type allows one to define +independent sets of flags, such as to implement subcommands +in a command-line interface. The methods of FlagSet are +analogous to the top-level functions for the command-line +flag set. */ package pflag @@ -115,8 +115,11 @@ var ErrHelp = errors.New("pflag: help requested") type ErrorHandling int const ( + // ContinueOnError will return an err from Parse() if an error is found ContinueOnError ErrorHandling = iota + // ExitOnError will call os.Exit(2) if an error is found when parsing ExitOnError + // PanicOnError will panic() if an error is found when parsing flags PanicOnError ) @@ -137,6 +140,7 @@ type FlagSet struct { formal map[NormalizedName]*Flag shorthands map[byte]*Flag args []string // arguments after flags + argsLenAtDash int // len(args) when a '--' was located when parsing, or -1 if no -- exitOnError bool // does the program exit if there's an error? errorHandling ErrorHandling output io.Writer // nil means stderr; use out() accessor @@ -146,14 +150,17 @@ type FlagSet struct { // A Flag represents the state of a flag. type Flag struct { - Name string // name as it appears on command line - Shorthand string // one-letter abbreviated flag - Usage string // help message - Value Value // value as set - DefValue string // default value (as text); for usage message - Changed bool // If the user set the value (or if left to default) - Deprecated string // If this flag is deprecated, this string is the new or now thing to use - Annotations map[string][]string // used by cobra.Command bash autocomple code + Name string // name as it appears on command line + Shorthand string // one-letter abbreviated flag + Usage string // help message + Value Value // value as set + DefValue string // default value (as text); for usage message + Changed bool // If the user set the value (or if left to default) + NoOptDefVal string //default value (as text); if the flag is on the command line without any options + Deprecated string // If this flag is deprecated, this string is the new or now thing to use + Hidden bool // used by cobra.Command to allow flags to be hidden from help/usage text + ShorthandDeprecated string // If the shorthand of this flag is deprecated, this string is the new or now thing to use + Annotations map[string][]string // used by cobra.Command bash autocomple code } // Value is the interface to the dynamic value stored in a flag. @@ -180,14 +187,23 @@ func sortFlags(flags map[NormalizedName]*Flag) []*Flag { return result } +// SetNormalizeFunc allows you to add a function which can translate flag names. +// Flags added to the FlagSet will be translated and then when anything tries to +// look up the flag that will also be translated. So it would be possible to create +// a flag named "getURL" and have it translated to "geturl". A user could then pass +// "--getUrl" which may also be translated to "geturl" and everything will work. func (f *FlagSet) SetNormalizeFunc(n func(f *FlagSet, name string) NormalizedName) { f.normalizeNameFunc = n for k, v := range f.formal { delete(f.formal, k) - f.formal[f.normalizeFlagName(string(k))] = v + nname := f.normalizeFlagName(string(k)) + f.formal[nname] = v + v.Name = string(nname) } } +// GetNormalizeFunc returns the previously set NormalizeFunc of a function which +// does no translation, if not set previously. func (f *FlagSet) GetNormalizeFunc() func(f *FlagSet, name string) NormalizedName { if f.normalizeNameFunc != nil { return f.normalizeNameFunc @@ -221,6 +237,7 @@ func (f *FlagSet) VisitAll(fn func(*Flag)) { } } +// HasFlags returns a bool to indicate if the FlagSet has any flags definied. func (f *FlagSet) HasFlags() bool { return len(f.formal) > 0 } @@ -255,16 +272,75 @@ func (f *FlagSet) lookup(name NormalizedName) *Flag { return f.formal[name] } -// Mark a flag deprecated in your program +// func to return a given type for a given flag name +func (f *FlagSet) getFlagType(name string, ftype string, convFunc func(sval string) (interface{}, error)) (interface{}, error) { + flag := f.Lookup(name) + if flag == nil { + err := fmt.Errorf("flag accessed but not defined: %s", name) + return nil, err + } + + if flag.Value.Type() != ftype { + err := fmt.Errorf("trying to get %s value of flag of type %s", ftype, flag.Value.Type()) + return nil, err + } + + sval := flag.Value.String() + result, err := convFunc(sval) + if err != nil { + return nil, err + } + return result, nil +} + +// ArgsLenAtDash will return the length of f.Args at the moment when a -- was +// found during arg parsing. This allows your program to know which args were +// before the -- and which came after. +func (f *FlagSet) ArgsLenAtDash() int { + return f.argsLenAtDash +} + +// MarkDeprecated indicated that a flag is deprecated in your program. It will +// continue to function but will not show up in help or usage messages. Using +// this flag will also print the given usageMessage. func (f *FlagSet) MarkDeprecated(name string, usageMessage string) error { flag := f.Lookup(name) if flag == nil { return fmt.Errorf("flag %q does not exist", name) } + if len(usageMessage) == 0 { + return fmt.Errorf("deprecated message for flag %q must be set", name) + } flag.Deprecated = usageMessage return nil } +// MarkShorthandDeprecated will mark the shorthand of a flag deprecated in your +// program. It will continue to function but will not show up in help or usage +// messages. Using this flag will also print the given usageMessage. +func (f *FlagSet) MarkShorthandDeprecated(name string, usageMessage string) error { + flag := f.Lookup(name) + if flag == nil { + return fmt.Errorf("flag %q does not exist", name) + } + if len(usageMessage) == 0 { + return fmt.Errorf("deprecated message for flag %q must be set", name) + } + flag.ShorthandDeprecated = usageMessage + return nil +} + +// MarkHidden sets a flag to 'hidden' in your program. It will continue to +// function but will not show up in help or usage messages. +func (f *FlagSet) MarkHidden(name string) error { + flag := f.Lookup(name) + if flag == nil { + return fmt.Errorf("flag %q does not exist", name) + } + flag.Hidden = true + return nil +} + // Lookup returns the Flag structure of the named command-line flag, // returning nil if none exists. func Lookup(name string) *Flag { @@ -293,6 +369,33 @@ func (f *FlagSet) Set(name, value string) error { return nil } +// SetAnnotation allows one to set arbitrary annotations on a flag in the FlagSet. +// This is sometimes used by spf13/cobra programs which want to generate additional +// bash completion information. +func (f *FlagSet) SetAnnotation(name, key string, values []string) error { + normalName := f.normalizeFlagName(name) + flag, ok := f.formal[normalName] + if !ok { + return fmt.Errorf("no such flag -%v", name) + } + if flag.Annotations == nil { + flag.Annotations = map[string][]string{} + } + flag.Annotations[key] = values + return nil +} + +// Changed returns true if the flag was explicitly set during Parse() and false +// otherwise +func (f *FlagSet) Changed(name string) bool { + flag := f.Lookup(name) + // If a flag doesn't exist, it wasn't changed.... + if flag == nil { + return false + } + return flag.Changed +} + // Set sets the value of the named command-line flag. func Set(name, value string) error { return CommandLine.Set(name, value) @@ -301,44 +404,127 @@ func Set(name, value string) error { // PrintDefaults prints, to standard error unless configured // otherwise, the default values of all defined flags in the set. func (f *FlagSet) PrintDefaults() { - f.VisitAll(func(flag *Flag) { - if len(flag.Deprecated) > 0 { - return - } - format := "--%s=%s: %s\n" - if _, ok := flag.Value.(*stringValue); ok { - // put quotes on the value - format = "--%s=%q: %s\n" - } - if len(flag.Shorthand) > 0 { - format = " -%s, " + format - } else { - format = " %s " + format - } - fmt.Fprintf(f.out(), format, flag.Shorthand, flag.Name, flag.DefValue, flag.Usage) - }) + usages := f.FlagUsages() + fmt.Fprintf(f.out(), "%s", usages) } +// isZeroValue guesses whether the string represents the zero +// value for a flag. It is not accurate but in practice works OK. +func isZeroValue(value string) bool { + switch value { + case "false": + return true + case "": + return true + case "": + return true + case "0": + return true + } + return false +} + +// UnquoteUsage extracts a back-quoted name from the usage +// string for a flag and returns it and the un-quoted usage. +// Given "a `name` to show" it returns ("name", "a name to show"). +// If there are no back quotes, the name is an educated guess of the +// type of the flag's value, or the empty string if the flag is boolean. +func UnquoteUsage(flag *Flag) (name string, usage string) { + // Look for a back-quoted name, but avoid the strings package. + usage = flag.Usage + for i := 0; i < len(usage); i++ { + if usage[i] == '`' { + for j := i + 1; j < len(usage); j++ { + if usage[j] == '`' { + name = usage[i+1 : j] + usage = usage[:i] + name + usage[j+1:] + return name, usage + } + } + break // Only one back quote; use type name. + } + } + // No explicit name, so use type if we can find one. + name = "value" + switch flag.Value.(type) { + case boolFlag: + name = "" + case *durationValue: + name = "duration" + case *float64Value: + name = "float" + case *intValue, *int64Value: + name = "int" + case *stringValue: + name = "string" + case *uintValue, *uint64Value: + name = "uint" + } + return +} + +// FlagUsages Returns a string containing the usage information for all flags in +// the FlagSet func (f *FlagSet) FlagUsages() string { x := new(bytes.Buffer) + lines := make([]string, 0, len(f.formal)) + + maxlen := 0 f.VisitAll(func(flag *Flag) { - if len(flag.Deprecated) > 0 { + if len(flag.Deprecated) > 0 || flag.Hidden { return } - format := "--%s=%s: %s\n" - if _, ok := flag.Value.(*stringValue); ok { - // put quotes on the value - format = "--%s=%q: %s\n" - } - if len(flag.Shorthand) > 0 { - format = " -%s, " + format + + line := "" + if len(flag.Shorthand) > 0 && len(flag.ShorthandDeprecated) == 0 { + line = fmt.Sprintf(" -%s, --%s", flag.Shorthand, flag.Name) } else { - format = " %s " + format + line = fmt.Sprintf(" --%s", flag.Name) } - fmt.Fprintf(x, format, flag.Shorthand, flag.Name, flag.DefValue, flag.Usage) + + varname, usage := UnquoteUsage(flag) + if len(varname) > 0 { + line += " " + varname + } + if len(flag.NoOptDefVal) > 0 { + switch flag.Value.Type() { + case "string": + line += fmt.Sprintf("[=%q]", flag.NoOptDefVal) + case "bool": + if flag.NoOptDefVal != "true" { + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + default: + line += fmt.Sprintf("[=%s]", flag.NoOptDefVal) + } + } + + // This special character will be replaced with spacing once the + // correct alignment is calculated + line += "\x00" + if len(line) > maxlen { + maxlen = len(line) + } + + line += usage + if !isZeroValue(flag.DefValue) { + if flag.Value.Type() == "string" { + line += fmt.Sprintf(" (default %q)", flag.DefValue) + } else { + line += fmt.Sprintf(" (default %s)", flag.DefValue) + } + } + + lines = append(lines, line) }) + for _, line := range lines { + sidx := strings.Index(line, "\x00") + spacing := strings.Repeat(" ", maxlen-sidx) + fmt.Fprintln(x, line[:sidx], spacing, line[sidx+1:]) + } + return x.String() } @@ -359,6 +545,8 @@ func defaultUsage(f *FlagSet) { // Usage prints to standard error a usage message documenting all defined command-line flags. // The function is a variable that may be changed to point to a custom function. +// By default it prints a simple header and calls PrintDefaults; for details about the +// format of the output and how to control it, see the documentation for PrintDefaults. var Usage = func() { fmt.Fprintf(os.Stderr, "Usage of %s:\n", os.Args[0]) PrintDefaults() @@ -407,8 +595,8 @@ func (f *FlagSet) Var(value Value, name string, usage string) { f.VarP(value, name, "", usage) } -// Like Var, but accepts a shorthand letter that can be used after a single dash. -func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { +// VarPF is like VarP, but returns the flag created +func (f *FlagSet) VarPF(value Value, name, shorthand, usage string) *Flag { // Remember the default value as a string; it won't change. flag := &Flag{ Name: name, @@ -418,10 +606,20 @@ func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { DefValue: value.String(), } f.AddFlag(flag) + return flag } +// VarP is like Var, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) VarP(value Value, name, shorthand, usage string) { + _ = f.VarPF(value, name, shorthand, usage) +} + +// AddFlag will add the flag to the FlagSet func (f *FlagSet) AddFlag(flag *Flag) { - _, alreadythere := f.formal[f.normalizeFlagName(flag.Name)] + // Call normalizeFlagName function only once + normalizedFlagName := f.normalizeFlagName(flag.Name) + + _, alreadythere := f.formal[normalizedFlagName] if alreadythere { msg := fmt.Sprintf("%s flag redefined: %s", f.name, flag.Name) fmt.Fprintln(f.out(), msg) @@ -430,7 +628,9 @@ func (f *FlagSet) AddFlag(flag *Flag) { if f.formal == nil { f.formal = make(map[NormalizedName]*Flag) } - f.formal[f.normalizeFlagName(flag.Name)] = flag + + flag.Name = string(normalizedFlagName) + f.formal[normalizedFlagName] = flag if len(flag.Shorthand) == 0 { return @@ -451,6 +651,19 @@ func (f *FlagSet) AddFlag(flag *Flag) { f.shorthands[c] = flag } +// AddFlagSet adds one FlagSet to another. If a flag is already present in f +// the flag from newSet will be ignored +func (f *FlagSet) AddFlagSet(newSet *FlagSet) { + if newSet == nil { + return + } + newSet.VisitAll(func(flag *Flag) { + if f.Lookup(flag.Name) == nil { + f.AddFlag(flag) + } + }) +} + // Var defines a flag with the specified name and usage string. The type and // value of the flag are represented by the first argument, of type Value, which // typically holds a user-defined implementation of Value. For instance, the @@ -461,7 +674,7 @@ func Var(value Value, name string, usage string) { CommandLine.VarP(value, name, "", usage) } -// Like Var, but accepts a shorthand letter that can be used after a single dash. +// VarP is like Var, but accepts a shorthand letter that can be used after a single dash. func VarP(value Value, name, shorthand, usage string) { CommandLine.VarP(value, name, shorthand, usage) } @@ -500,15 +713,23 @@ func (f *FlagSet) setFlag(flag *Flag, value string, origArg string) error { if len(flag.Deprecated) > 0 { fmt.Fprintf(os.Stderr, "Flag --%s has been deprecated, %s\n", flag.Name, flag.Deprecated) } + if len(flag.ShorthandDeprecated) > 0 && containsShorthand(origArg, flag.Shorthand) { + fmt.Fprintf(os.Stderr, "Flag shorthand -%s has been deprecated, %s\n", flag.Shorthand, flag.ShorthandDeprecated) + } return nil } +func containsShorthand(arg, shorthand string) bool { + // filter out flags -- + if strings.HasPrefix(arg, "-") { + return false + } + arg = strings.SplitN(arg, "=", 2)[0] + return strings.Contains(arg, shorthand) +} + func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error) { a = args - if len(s) == 2 { // "--" terminates the flags - f.args = append(f.args, args...) - return - } name := s[2:] if len(name) == 0 || name[0] == '-' || name[0] == '=' { err = f.failf("bad flag syntax: %s", s) @@ -516,75 +737,80 @@ func (f *FlagSet) parseLongArg(s string, args []string) (a []string, err error) } split := strings.SplitN(name, "=", 2) name = split[0] - m := f.formal - flag, alreadythere := m[f.normalizeFlagName(name)] // BUG + flag, alreadythere := f.formal[f.normalizeFlagName(name)] if !alreadythere { if name == "help" { // special case for nice help message. f.usage() - return args, ErrHelp + return a, ErrHelp } err = f.failf("unknown flag: --%s", name) return } - if len(split) == 1 { - if bv, ok := flag.Value.(boolFlag); !ok || !bv.IsBoolFlag() { - err = f.failf("flag needs an argument: %s", s) - return - } - f.setFlag(flag, "true", s) + var value string + if len(split) == 2 { + // '--flag=arg' + value = split[1] + } else if len(flag.NoOptDefVal) > 0 { + // '--flag' (arg was optional) + value = flag.NoOptDefVal + } else if len(a) > 0 { + // '--flag arg' + value = a[0] + a = a[1:] } else { - if e := f.setFlag(flag, split[1], s); e != nil { - err = e + // '--flag' (arg was required) + err = f.failf("flag needs an argument: %s", s) + return + } + err = f.setFlag(flag, value, s) + return +} + +func (f *FlagSet) parseSingleShortArg(shorthands string, args []string) (outShorts string, outArgs []string, err error) { + outArgs = args + outShorts = shorthands[1:] + c := shorthands[0] + + flag, alreadythere := f.shorthands[c] + if !alreadythere { + if c == 'h' { // special case for nice help message. + f.usage() + err = ErrHelp return } + //TODO continue on error + err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) + return } - return args, nil + var value string + if len(shorthands) > 2 && shorthands[1] == '=' { + value = shorthands[2:] + outShorts = "" + } else if len(flag.NoOptDefVal) > 0 { + value = flag.NoOptDefVal + } else if len(shorthands) > 1 { + value = shorthands[1:] + outShorts = "" + } else if len(args) > 0 { + value = args[0] + outArgs = args[1:] + } else { + err = f.failf("flag needs an argument: %q in -%s", c, shorthands) + return + } + err = f.setFlag(flag, value, shorthands) + return } func (f *FlagSet) parseShortArg(s string, args []string) (a []string, err error) { a = args shorthands := s[1:] - for i := 0; i < len(shorthands); i++ { - c := shorthands[i] - flag, alreadythere := f.shorthands[c] - if !alreadythere { - if c == 'h' { // special case for nice help message. - f.usage() - err = ErrHelp - return - } - //TODO continue on error - err = f.failf("unknown shorthand flag: %q in -%s", c, shorthands) - if len(args) == 0 { - return - } + for len(shorthands) > 0 { + shorthands, a, err = f.parseSingleShortArg(shorthands, args) + if err != nil { return } - if alreadythere { - if bv, ok := flag.Value.(boolFlag); ok && bv.IsBoolFlag() { - f.setFlag(flag, "true", s) - continue - } - if i < len(shorthands)-1 { - v := strings.TrimPrefix(shorthands[i+1:], "=") - if e := f.setFlag(flag, v, s); e != nil { - err = e - return - } - break - } - if len(args) == 0 { - err = f.failf("flag needs an argument: %q in -%s", c, shorthands) - return - } - if e := f.setFlag(flag, args[0], s); e != nil { - err = e - return - } - } - a = args[1:] - break // should be unnecessary } return @@ -605,12 +831,12 @@ func (f *FlagSet) parseArgs(args []string) (err error) { } if s[1] == '-' { - args, err = f.parseLongArg(s, args) - - if len(s) == 2 { - // stop parsing after -- + if len(s) == 2 { // "--" terminates the flags + f.argsLenAtDash = len(f.args) + f.args = append(f.args, args...) break } + args, err = f.parseLongArg(s, args) } else { args, err = f.parseShortArg(s, args) } @@ -654,7 +880,7 @@ func Parse() { CommandLine.Parse(os.Args[1:]) } -// Whether to support interspersed option/non-option arguments. +// SetInterspersed sets whether to support interspersed option/non-option arguments. func SetInterspersed(interspersed bool) { CommandLine.SetInterspersed(interspersed) } @@ -664,7 +890,7 @@ func Parsed() bool { return CommandLine.Parsed() } -// The default set of command-line flags, parsed from os.Args. +// CommandLine is the default set of command-line flags, parsed from os.Args. var CommandLine = NewFlagSet(os.Args[0], ExitOnError) // NewFlagSet returns a new, empty flag set with the specified name and @@ -673,12 +899,13 @@ func NewFlagSet(name string, errorHandling ErrorHandling) *FlagSet { f := &FlagSet{ name: name, errorHandling: errorHandling, + argsLenAtDash: -1, interspersed: true, } return f } -// Whether to support interspersed option/non-option arguments. +// SetInterspersed sets whether to support interspersed option/non-option arguments. func (f *FlagSet) SetInterspersed(interspersed bool) { f.interspersed = interspersed } @@ -689,4 +916,5 @@ func (f *FlagSet) SetInterspersed(interspersed bool) { func (f *FlagSet) Init(name string, errorHandling ErrorHandling) { f.name = name f.errorHandling = errorHandling + f.argsLenAtDash = -1 } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/flag_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/flag_test.go index eb876044..0ae2e4ff 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/flag_test.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/flag_test.go @@ -2,31 +2,33 @@ // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. -package pflag_test +package pflag import ( "bytes" "fmt" "io" "io/ioutil" + "net" "os" + "reflect" "sort" "strings" "testing" "time" - - . "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" ) var ( - test_bool = Bool("test_bool", false, "bool value") - test_int = Int("test_int", 0, "int value") - test_int64 = Int64("test_int64", 0, "int64 value") - test_uint = Uint("test_uint", 0, "uint value") - test_uint64 = Uint64("test_uint64", 0, "uint64 value") - test_string = String("test_string", "0", "string value") - test_float64 = Float64("test_float64", 0, "float64 value") - test_duration = Duration("test_duration", 0, "time.Duration value") + testBool = Bool("test_bool", false, "bool value") + testInt = Int("test_int", 0, "int value") + testInt64 = Int64("test_int64", 0, "int64 value") + testUint = Uint("test_uint", 0, "uint value") + testUint64 = Uint64("test_uint64", 0, "uint64 value") + testString = String("test_string", "0", "string value") + testFloat = Float64("test_float64", 0, "float64 value") + testDuration = Duration("test_duration", 0, "time.Duration value") + testOptionalInt = Int("test_optional_int", 0, "optional int value") + normalizeFlagNameInvocations = 0 ) func boolString(s string) string { @@ -57,7 +59,7 @@ func TestEverything(t *testing.T) { } } VisitAll(visitor) - if len(m) != 8 { + if len(m) != 9 { t.Error("VisitAll misses some flags") for k, v := range m { t.Log(k, *v) @@ -80,9 +82,10 @@ func TestEverything(t *testing.T) { Set("test_string", "1") Set("test_float64", "1") Set("test_duration", "1s") + Set("test_optional_int", "1") desired = "1" Visit(visitor) - if len(m) != 8 { + if len(m) != 9 { t.Error("Visit fails after set") for k, v := range m { t.Log(k, *v) @@ -107,6 +110,54 @@ func TestUsage(t *testing.T) { } } +func TestAddFlagSet(t *testing.T) { + oldSet := NewFlagSet("old", ContinueOnError) + newSet := NewFlagSet("new", ContinueOnError) + + oldSet.String("flag1", "flag1", "flag1") + oldSet.String("flag2", "flag2", "flag2") + + newSet.String("flag2", "flag2", "flag2") + newSet.String("flag3", "flag3", "flag3") + + oldSet.AddFlagSet(newSet) + + if len(oldSet.formal) != 3 { + t.Errorf("Unexpected result adding a FlagSet to a FlagSet %v", oldSet) + } +} + +func TestAnnotation(t *testing.T) { + f := NewFlagSet("shorthand", ContinueOnError) + + if err := f.SetAnnotation("missing-flag", "key", nil); err == nil { + t.Errorf("Expected error setting annotation on non-existent flag") + } + + f.StringP("stringa", "a", "", "string value") + if err := f.SetAnnotation("stringa", "key", nil); err != nil { + t.Errorf("Unexpected error setting new nil annotation: %v", err) + } + if annotation := f.Lookup("stringa").Annotations["key"]; annotation != nil { + t.Errorf("Unexpected annotation: %v", annotation) + } + + f.StringP("stringb", "b", "", "string2 value") + if err := f.SetAnnotation("stringb", "key", []string{"value1"}); err != nil { + t.Errorf("Unexpected error setting new annotation: %v", err) + } + if annotation := f.Lookup("stringb").Annotations["key"]; !reflect.DeepEqual(annotation, []string{"value1"}) { + t.Errorf("Unexpected annotation: %v", annotation) + } + + if err := f.SetAnnotation("stringb", "key", []string{"value2"}); err != nil { + t.Errorf("Unexpected error updating annotation: %v", err) + } + if annotation := f.Lookup("stringb").Annotations["key"]; !reflect.DeepEqual(annotation, []string{"value2"}) { + t.Errorf("Unexpected annotation: %v", annotation) + } +} + func testParse(f *FlagSet, t *testing.T) { if f.Parsed() { t.Error("f.Parse() = true before Parse") @@ -115,24 +166,46 @@ func testParse(f *FlagSet, t *testing.T) { bool2Flag := f.Bool("bool2", false, "bool2 value") bool3Flag := f.Bool("bool3", false, "bool3 value") intFlag := f.Int("int", 0, "int value") + int8Flag := f.Int8("int8", 0, "int value") + int32Flag := f.Int32("int32", 0, "int value") int64Flag := f.Int64("int64", 0, "int64 value") uintFlag := f.Uint("uint", 0, "uint value") + uint8Flag := f.Uint8("uint8", 0, "uint value") + uint16Flag := f.Uint16("uint16", 0, "uint value") + uint32Flag := f.Uint32("uint32", 0, "uint value") uint64Flag := f.Uint64("uint64", 0, "uint64 value") stringFlag := f.String("string", "0", "string value") + float32Flag := f.Float32("float32", 0, "float32 value") float64Flag := f.Float64("float64", 0, "float64 value") + ipFlag := f.IP("ip", net.ParseIP("127.0.0.1"), "ip value") + maskFlag := f.IPMask("mask", ParseIPv4Mask("0.0.0.0"), "mask value") durationFlag := f.Duration("duration", 5*time.Second, "time.Duration value") + optionalIntNoValueFlag := f.Int("optional-int-no-value", 0, "int value") + f.Lookup("optional-int-no-value").NoOptDefVal = "9" + optionalIntWithValueFlag := f.Int("optional-int-with-value", 0, "int value") + f.Lookup("optional-int-no-value").NoOptDefVal = "9" extra := "one-extra-argument" args := []string{ "--bool", "--bool2=true", "--bool3=false", "--int=22", + "--int8=-8", + "--int32=-32", "--int64=0x23", - "--uint=24", + "--uint", "24", + "--uint8=8", + "--uint16=16", + "--uint32=32", "--uint64=25", "--string=hello", + "--float32=-172e12", "--float64=2718e28", + "--ip=10.11.12.13", + "--mask=255.255.255.0", "--duration=2m", + "--optional-int-no-value", + "--optional-int-with-value=42", extra, } if err := f.Parse(args); err != nil { @@ -144,6 +217,9 @@ func testParse(f *FlagSet, t *testing.T) { if *boolFlag != true { t.Error("bool flag should be true, is ", *boolFlag) } + if v, err := f.GetBool("bool"); err != nil || v != *boolFlag { + t.Error("GetBool does not work.") + } if *bool2Flag != true { t.Error("bool2 flag should be true, is ", *bool2Flag) } @@ -153,24 +229,102 @@ func testParse(f *FlagSet, t *testing.T) { if *intFlag != 22 { t.Error("int flag should be 22, is ", *intFlag) } + if v, err := f.GetInt("int"); err != nil || v != *intFlag { + t.Error("GetInt does not work.") + } + if *int8Flag != -8 { + t.Error("int8 flag should be 0x23, is ", *int8Flag) + } + if v, err := f.GetInt8("int8"); err != nil || v != *int8Flag { + t.Error("GetInt8 does not work.") + } + if *int32Flag != -32 { + t.Error("int32 flag should be 0x23, is ", *int32Flag) + } + if v, err := f.GetInt32("int32"); err != nil || v != *int32Flag { + t.Error("GetInt32 does not work.") + } if *int64Flag != 0x23 { t.Error("int64 flag should be 0x23, is ", *int64Flag) } + if v, err := f.GetInt64("int64"); err != nil || v != *int64Flag { + t.Error("GetInt64 does not work.") + } if *uintFlag != 24 { t.Error("uint flag should be 24, is ", *uintFlag) } + if v, err := f.GetUint("uint"); err != nil || v != *uintFlag { + t.Error("GetUint does not work.") + } + if *uint8Flag != 8 { + t.Error("uint8 flag should be 8, is ", *uint8Flag) + } + if v, err := f.GetUint8("uint8"); err != nil || v != *uint8Flag { + t.Error("GetUint8 does not work.") + } + if *uint16Flag != 16 { + t.Error("uint16 flag should be 16, is ", *uint16Flag) + } + if v, err := f.GetUint16("uint16"); err != nil || v != *uint16Flag { + t.Error("GetUint16 does not work.") + } + if *uint32Flag != 32 { + t.Error("uint32 flag should be 32, is ", *uint32Flag) + } + if v, err := f.GetUint32("uint32"); err != nil || v != *uint32Flag { + t.Error("GetUint32 does not work.") + } if *uint64Flag != 25 { t.Error("uint64 flag should be 25, is ", *uint64Flag) } + if v, err := f.GetUint64("uint64"); err != nil || v != *uint64Flag { + t.Error("GetUint64 does not work.") + } if *stringFlag != "hello" { t.Error("string flag should be `hello`, is ", *stringFlag) } + if v, err := f.GetString("string"); err != nil || v != *stringFlag { + t.Error("GetString does not work.") + } + if *float32Flag != -172e12 { + t.Error("float32 flag should be -172e12, is ", *float32Flag) + } + if v, err := f.GetFloat32("float32"); err != nil || v != *float32Flag { + t.Errorf("GetFloat32 returned %v but float32Flag was %v", v, *float32Flag) + } if *float64Flag != 2718e28 { t.Error("float64 flag should be 2718e28, is ", *float64Flag) } + if v, err := f.GetFloat64("float64"); err != nil || v != *float64Flag { + t.Errorf("GetFloat64 returned %v but float64Flag was %v", v, *float64Flag) + } + if !(*ipFlag).Equal(net.ParseIP("10.11.12.13")) { + t.Error("ip flag should be 10.11.12.13, is ", *ipFlag) + } + if v, err := f.GetIP("ip"); err != nil || !v.Equal(*ipFlag) { + t.Errorf("GetIP returned %v but ipFlag was %v", v, *ipFlag) + } + if (*maskFlag).String() != ParseIPv4Mask("255.255.255.0").String() { + t.Error("mask flag should be 255.255.255.0, is ", (*maskFlag).String()) + } + if v, err := f.GetIPv4Mask("mask"); err != nil || v.String() != (*maskFlag).String() { + t.Errorf("GetIP returned %v maskFlag was %v error was %v", v, *maskFlag, err) + } if *durationFlag != 2*time.Minute { t.Error("duration flag should be 2m, is ", *durationFlag) } + if v, err := f.GetDuration("duration"); err != nil || v != *durationFlag { + t.Error("GetDuration does not work.") + } + if _, err := f.GetInt("duration"); err == nil { + t.Error("GetInt parsed a time.Duration?!?!") + } + if *optionalIntNoValueFlag != 9 { + t.Error("optional int flag should be the default value, is ", *optionalIntNoValueFlag) + } + if *optionalIntWithValueFlag != 42 { + t.Error("optional int flag should be 42, is ", *optionalIntWithValueFlag) + } if len(f.Args()) != 1 { t.Error("expected one argument, got", len(f.Args())) } else if f.Args()[0] != extra { @@ -186,6 +340,7 @@ func TestShorthand(t *testing.T) { boolaFlag := f.BoolP("boola", "a", false, "bool value") boolbFlag := f.BoolP("boolb", "b", false, "bool2 value") boolcFlag := f.BoolP("boolc", "c", false, "bool3 value") + booldFlag := f.BoolP("boold", "d", false, "bool4 value") stringaFlag := f.StringP("stringa", "s", "0", "string value") stringzFlag := f.StringP("stringz", "z", "0", "string value") extra := "interspersed-argument" @@ -196,6 +351,7 @@ func TestShorthand(t *testing.T) { "-cs", "hello", "-z=something", + "-d=true", "--", notaflag, } @@ -215,6 +371,9 @@ func TestShorthand(t *testing.T) { if *boolcFlag != true { t.Error("boolc flag should be true, is ", *boolcFlag) } + if *booldFlag != true { + t.Error("boold flag should be true, is ", *booldFlag) + } if *stringaFlag != "hello" { t.Error("stringa flag should be `hello`, is ", *stringaFlag) } @@ -228,6 +387,9 @@ func TestShorthand(t *testing.T) { } else if f.Args()[1] != notaflag { t.Errorf("expected argument %q got %q", notaflag, f.Args()[1]) } + if f.ArgsLenAtDash() != 1 { + t.Errorf("expected argsLenAtDash %d got %d", f.ArgsLenAtDash(), 1) + } } func TestParse(t *testing.T) { @@ -239,6 +401,37 @@ func TestFlagSetParse(t *testing.T) { testParse(NewFlagSet("test", ContinueOnError), t) } +func TestChangedHelper(t *testing.T) { + f := NewFlagSet("changedtest", ContinueOnError) + _ = f.Bool("changed", false, "changed bool") + _ = f.Bool("settrue", true, "true to true") + _ = f.Bool("setfalse", false, "false to false") + _ = f.Bool("unchanged", false, "unchanged bool") + + args := []string{"--changed", "--settrue", "--setfalse=false"} + if err := f.Parse(args); err != nil { + t.Error("f.Parse() = false after Parse") + } + if !f.Changed("changed") { + t.Errorf("--changed wasn't changed!") + } + if !f.Changed("settrue") { + t.Errorf("--settrue wasn't changed!") + } + if !f.Changed("setfalse") { + t.Errorf("--setfalse wasn't changed!") + } + if f.Changed("unchanged") { + t.Errorf("--unchanged was changed!") + } + if f.Changed("invalid") { + t.Errorf("--invalid was changed!") + } + if f.ArgsLenAtDash() != -1 { + t.Errorf("Expected argsLenAtDash: %d but got %d", -1, f.ArgsLenAtDash()) + } +} + func replaceSeparators(name string, from []string, to string) string { result := name for _, sep := range from { @@ -251,6 +444,8 @@ func replaceSeparators(name string, from []string, to string) string { func wordSepNormalizeFunc(f *FlagSet, name string) NormalizedName { seps := []string{"-", "_"} name = replaceSeparators(name, seps, ".") + normalizeFlagNameInvocations++ + return NormalizedName(name) } @@ -343,6 +538,31 @@ func TestCustomNormalizedNames(t *testing.T) { } } +// Every flag we add, the name (displayed also in usage) should normalized +func TestNormalizationFuncShouldChangeFlagName(t *testing.T) { + // Test normalization after addition + f := NewFlagSet("normalized", ContinueOnError) + + f.Bool("valid_flag", false, "bool value") + if f.Lookup("valid_flag").Name != "valid_flag" { + t.Error("The new flag should have the name 'valid_flag' instead of ", f.Lookup("valid_flag").Name) + } + + f.SetNormalizeFunc(wordSepNormalizeFunc) + if f.Lookup("valid_flag").Name != "valid.flag" { + t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name) + } + + // Test normalization before addition + f = NewFlagSet("normalized", ContinueOnError) + f.SetNormalizeFunc(wordSepNormalizeFunc) + + f.Bool("valid_flag", false, "bool value") + if f.Lookup("valid_flag").Name != "valid.flag" { + t.Error("The new flag should have the name 'valid.flag' instead of ", f.Lookup("valid_flag").Name) + } +} + // Declare a user-defined flag type. type flagVar []string @@ -499,6 +719,9 @@ func TestTermination(t *testing.T) { if f.Args()[1] != arg2 { t.Errorf("expected argument %q got %q", arg2, f.Args()[1]) } + if f.ArgsLenAtDash() != 0 { + t.Errorf("expected argsLenAtDash %d got %d", 0, f.ArgsLenAtDash()) + } } func TestDeprecatedFlagInDocs(t *testing.T) { @@ -515,6 +738,21 @@ func TestDeprecatedFlagInDocs(t *testing.T) { } } +func TestDeprecatedFlagShorthandInDocs(t *testing.T) { + f := NewFlagSet("bob", ContinueOnError) + name := "noshorthandflag" + f.BoolP(name, "n", true, "always true") + f.MarkShorthandDeprecated("noshorthandflag", fmt.Sprintf("use --%s instead", name)) + + out := new(bytes.Buffer) + f.SetOutput(out) + f.PrintDefaults() + + if strings.Contains(out.String(), "-n,") { + t.Errorf("found deprecated flag shorthand in usage!") + } +} + func parseReturnStderr(t *testing.T, f *FlagSet, args []string) (string, error) { oldStderr := os.Stderr r, w, _ := os.Pipe() @@ -554,6 +792,24 @@ func TestDeprecatedFlagUsage(t *testing.T) { } } +func TestDeprecatedFlagShorthandUsage(t *testing.T) { + f := NewFlagSet("bob", ContinueOnError) + name := "noshorthandflag" + f.BoolP(name, "n", true, "always true") + usageMsg := fmt.Sprintf("use --%s instead", name) + f.MarkShorthandDeprecated(name, usageMsg) + + args := []string{"-n"} + out, err := parseReturnStderr(t, f, args) + if err != nil { + t.Fatal("expected no error; got ", err) + } + + if !strings.Contains(out, usageMsg) { + t.Errorf("usageMsg not printed when using a deprecated flag!") + } +} + func TestDeprecatedFlagUsageNormalized(t *testing.T) { f := NewFlagSet("bob", ContinueOnError) f.Bool("bad-double_flag", true, "always true") @@ -571,3 +827,87 @@ func TestDeprecatedFlagUsageNormalized(t *testing.T) { t.Errorf("usageMsg not printed when using a deprecated flag!") } } + +// Name normalization function should be called only once on flag addition +func TestMultipleNormalizeFlagNameInvocations(t *testing.T) { + normalizeFlagNameInvocations = 0 + + f := NewFlagSet("normalized", ContinueOnError) + f.SetNormalizeFunc(wordSepNormalizeFunc) + f.Bool("with_under_flag", false, "bool value") + + if normalizeFlagNameInvocations != 1 { + t.Fatal("Expected normalizeFlagNameInvocations to be 1; got ", normalizeFlagNameInvocations) + } +} + +// +func TestHiddenFlagInUsage(t *testing.T) { + f := NewFlagSet("bob", ContinueOnError) + f.Bool("secretFlag", true, "shhh") + f.MarkHidden("secretFlag") + + out := new(bytes.Buffer) + f.SetOutput(out) + f.PrintDefaults() + + if strings.Contains(out.String(), "secretFlag") { + t.Errorf("found hidden flag in usage!") + } +} + +// +func TestHiddenFlagUsage(t *testing.T) { + f := NewFlagSet("bob", ContinueOnError) + f.Bool("secretFlag", true, "shhh") + f.MarkHidden("secretFlag") + + args := []string{"--secretFlag"} + out, err := parseReturnStderr(t, f, args) + if err != nil { + t.Fatal("expected no error; got ", err) + } + + if strings.Contains(out, "shhh") { + t.Errorf("usage message printed when using a hidden flag!") + } +} + +const defaultOutput = ` --A for bootstrapping, allow 'any' type + --Alongflagname disable bounds checking + -C, --CCC a boolean defaulting to true (default true) + --D path set relative path for local imports + --F number a non-zero number (default 2.7) + --G float a float that defaults to zero + --N int a non-zero int (default 27) + --ND1 string[="bar"] a string with NoOptDefVal (default "foo") + --ND2 num[=4321] a num with NoOptDefVal (default 1234) + --Z int an int that defaults to zero + --maxT timeout set timeout for dial +` + +func TestPrintDefaults(t *testing.T) { + fs := NewFlagSet("print defaults test", ContinueOnError) + var buf bytes.Buffer + fs.SetOutput(&buf) + fs.Bool("A", false, "for bootstrapping, allow 'any' type") + fs.Bool("Alongflagname", false, "disable bounds checking") + fs.BoolP("CCC", "C", true, "a boolean defaulting to true") + fs.String("D", "", "set relative `path` for local imports") + fs.Float64("F", 2.7, "a non-zero `number`") + fs.Float64("G", 0, "a float that defaults to zero") + fs.Int("N", 27, "a non-zero int") + fs.Int("Z", 0, "an int that defaults to zero") + fs.Duration("maxT", 0, "set `timeout` for dial") + fs.String("ND1", "foo", "a string with NoOptDefVal") + fs.Lookup("ND1").NoOptDefVal = "bar" + fs.Int("ND2", 1234, "a `num` with NoOptDefVal") + fs.Lookup("ND2").NoOptDefVal = "4321" + fs.PrintDefaults() + got := buf.String() + if got != defaultOutput { + fmt.Println("\n" + got) + fmt.Println("\n" + defaultOutput) + t.Errorf("got %q want %q\n", got, defaultOutput) + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/float32.go b/Godeps/_workspace/src/github.com/spf13/pflag/float32.go index b7ad67d9..7683fae1 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/float32.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/float32.go @@ -25,13 +25,30 @@ func (f *float32Value) Type() string { func (f *float32Value) String() string { return fmt.Sprintf("%v", *f) } +func float32Conv(sval string) (interface{}, error) { + v, err := strconv.ParseFloat(sval, 32) + if err != nil { + return 0, err + } + return float32(v), nil +} + +// GetFloat32 return the float32 value of a flag with the given name +func (f *FlagSet) GetFloat32(name string) (float32, error) { + val, err := f.getFlagType(name, "float32", float32Conv) + if err != nil { + return 0, err + } + return val.(float32), nil +} + // Float32Var defines a float32 flag with specified name, default value, and usage string. // The argument p points to a float32 variable in which to store the value of the flag. func (f *FlagSet) Float32Var(p *float32, name string, value float32, usage string) { f.VarP(newFloat32Value(value, p), name, "", usage) } -// Like Float32Var, but accepts a shorthand letter that can be used after a single dash. +// Float32VarP is like Float32Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Float32VarP(p *float32, name, shorthand string, value float32, usage string) { f.VarP(newFloat32Value(value, p), name, shorthand, usage) } @@ -42,7 +59,7 @@ func Float32Var(p *float32, name string, value float32, usage string) { CommandLine.VarP(newFloat32Value(value, p), name, "", usage) } -// Like Float32Var, but accepts a shorthand letter that can be used after a single dash. +// Float32VarP is like Float32Var, but accepts a shorthand letter that can be used after a single dash. func Float32VarP(p *float32, name, shorthand string, value float32, usage string) { CommandLine.VarP(newFloat32Value(value, p), name, shorthand, usage) } @@ -55,7 +72,7 @@ func (f *FlagSet) Float32(name string, value float32, usage string) *float32 { return p } -// Like Float32, but accepts a shorthand letter that can be used after a single dash. +// Float32P is like Float32, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Float32P(name, shorthand string, value float32, usage string) *float32 { p := new(float32) f.Float32VarP(p, name, shorthand, value, usage) @@ -68,7 +85,7 @@ func Float32(name string, value float32, usage string) *float32 { return CommandLine.Float32P(name, "", value, usage) } -// Like Float32, but accepts a shorthand letter that can be used after a single dash. +// Float32P is like Float32, but accepts a shorthand letter that can be used after a single dash. func Float32P(name, shorthand string, value float32, usage string) *float32 { return CommandLine.Float32P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/float64.go b/Godeps/_workspace/src/github.com/spf13/pflag/float64.go index 03155123..50fbf8cc 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/float64.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/float64.go @@ -25,13 +25,26 @@ func (f *float64Value) Type() string { func (f *float64Value) String() string { return fmt.Sprintf("%v", *f) } +func float64Conv(sval string) (interface{}, error) { + return strconv.ParseFloat(sval, 64) +} + +// GetFloat64 return the float64 value of a flag with the given name +func (f *FlagSet) GetFloat64(name string) (float64, error) { + val, err := f.getFlagType(name, "float64", float64Conv) + if err != nil { + return 0, err + } + return val.(float64), nil +} + // Float64Var defines a float64 flag with specified name, default value, and usage string. // The argument p points to a float64 variable in which to store the value of the flag. func (f *FlagSet) Float64Var(p *float64, name string, value float64, usage string) { f.VarP(newFloat64Value(value, p), name, "", usage) } -// Like Float64Var, but accepts a shorthand letter that can be used after a single dash. +// Float64VarP is like Float64Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Float64VarP(p *float64, name, shorthand string, value float64, usage string) { f.VarP(newFloat64Value(value, p), name, shorthand, usage) } @@ -42,7 +55,7 @@ func Float64Var(p *float64, name string, value float64, usage string) { CommandLine.VarP(newFloat64Value(value, p), name, "", usage) } -// Like Float64Var, but accepts a shorthand letter that can be used after a single dash. +// Float64VarP is like Float64Var, but accepts a shorthand letter that can be used after a single dash. func Float64VarP(p *float64, name, shorthand string, value float64, usage string) { CommandLine.VarP(newFloat64Value(value, p), name, shorthand, usage) } @@ -55,7 +68,7 @@ func (f *FlagSet) Float64(name string, value float64, usage string) *float64 { return p } -// Like Float64, but accepts a shorthand letter that can be used after a single dash. +// Float64P is like Float64, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Float64P(name, shorthand string, value float64, usage string) *float64 { p := new(float64) f.Float64VarP(p, name, shorthand, value, usage) @@ -68,7 +81,7 @@ func Float64(name string, value float64, usage string) *float64 { return CommandLine.Float64P(name, "", value, usage) } -// Like Float64, but accepts a shorthand letter that can be used after a single dash. +// Float64P is like Float64, but accepts a shorthand letter that can be used after a single dash. func Float64P(name, shorthand string, value float64, usage string) *float64 { return CommandLine.Float64P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/golangflag.go b/Godeps/_workspace/src/github.com/spf13/pflag/golangflag.go new file mode 100644 index 00000000..b056147f --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/golangflag.go @@ -0,0 +1,104 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pflag + +import ( + goflag "flag" + "fmt" + "reflect" + "strings" +) + +var _ = fmt.Print + +// flagValueWrapper implements pflag.Value around a flag.Value. The main +// difference here is the addition of the Type method that returns a string +// name of the type. As this is generally unknown, we approximate that with +// reflection. +type flagValueWrapper struct { + inner goflag.Value + flagType string +} + +// We are just copying the boolFlag interface out of goflag as that is what +// they use to decide if a flag should get "true" when no arg is given. +type goBoolFlag interface { + goflag.Value + IsBoolFlag() bool +} + +func wrapFlagValue(v goflag.Value) Value { + // If the flag.Value happens to also be a pflag.Value, just use it directly. + if pv, ok := v.(Value); ok { + return pv + } + + pv := &flagValueWrapper{ + inner: v, + } + + t := reflect.TypeOf(v) + if t.Kind() == reflect.Interface || t.Kind() == reflect.Ptr { + t = t.Elem() + } + + pv.flagType = strings.TrimSuffix(t.Name(), "Value") + return pv +} + +func (v *flagValueWrapper) String() string { + return v.inner.String() +} + +func (v *flagValueWrapper) Set(s string) error { + return v.inner.Set(s) +} + +func (v *flagValueWrapper) Type() string { + return v.flagType +} + +// PFlagFromGoFlag will return a *pflag.Flag given a *flag.Flag +// If the *flag.Flag.Name was a single character (ex: `v`) it will be accessiblei +// with both `-v` and `--v` in flags. If the golang flag was more than a single +// character (ex: `verbose`) it will only be accessible via `--verbose` +func PFlagFromGoFlag(goflag *goflag.Flag) *Flag { + // Remember the default value as a string; it won't change. + flag := &Flag{ + Name: goflag.Name, + Usage: goflag.Usage, + Value: wrapFlagValue(goflag.Value), + // Looks like golang flags don't set DefValue correctly :-( + //DefValue: goflag.DefValue, + DefValue: goflag.Value.String(), + } + // Ex: if the golang flag was -v, allow both -v and --v to work + if len(flag.Name) == 1 { + flag.Shorthand = flag.Name + } + if fv, ok := goflag.Value.(goBoolFlag); ok && fv.IsBoolFlag() { + flag.NoOptDefVal = "true" + } + return flag +} + +// AddGoFlag will add the given *flag.Flag to the pflag.FlagSet +func (f *FlagSet) AddGoFlag(goflag *goflag.Flag) { + if f.Lookup(goflag.Name) != nil { + return + } + newflag := PFlagFromGoFlag(goflag) + f.AddFlag(newflag) +} + +// AddGoFlagSet will add the given *flag.FlagSet to the pflag.FlagSet +func (f *FlagSet) AddGoFlagSet(newSet *goflag.FlagSet) { + if newSet == nil { + return + } + newSet.VisitAll(func(goflag *goflag.Flag) { + f.AddGoFlag(goflag) + }) +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/golangflag_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/golangflag_test.go new file mode 100644 index 00000000..77e2d7d8 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/golangflag_test.go @@ -0,0 +1,39 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pflag + +import ( + goflag "flag" + "testing" +) + +func TestGoflags(t *testing.T) { + goflag.String("stringFlag", "stringFlag", "stringFlag") + goflag.Bool("boolFlag", false, "boolFlag") + + f := NewFlagSet("test", ContinueOnError) + + f.AddGoFlagSet(goflag.CommandLine) + err := f.Parse([]string{"--stringFlag=bob", "--boolFlag"}) + if err != nil { + t.Fatal("expected no error; get", err) + } + + getString, err := f.GetString("stringFlag") + if err != nil { + t.Fatal("expected no error; get", err) + } + if getString != "bob" { + t.Fatalf("expected getString=bob but got getString=%s", getString) + } + + getBool, err := f.GetBool("boolFlag") + if err != nil { + t.Fatal("expected no error; get", err) + } + if getBool != true { + t.Fatalf("expected getBool=true but got getBool=%v", getBool) + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/int.go b/Godeps/_workspace/src/github.com/spf13/pflag/int.go index dca9da6e..b6560368 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/int.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/int.go @@ -25,13 +25,26 @@ func (i *intValue) Type() string { func (i *intValue) String() string { return fmt.Sprintf("%v", *i) } +func intConv(sval string) (interface{}, error) { + return strconv.Atoi(sval) +} + +// GetInt return the int value of a flag with the given name +func (f *FlagSet) GetInt(name string) (int, error) { + val, err := f.getFlagType(name, "int", intConv) + if err != nil { + return 0, err + } + return val.(int), nil +} + // IntVar defines an int flag with specified name, default value, and usage string. // The argument p points to an int variable in which to store the value of the flag. func (f *FlagSet) IntVar(p *int, name string, value int, usage string) { f.VarP(newIntValue(value, p), name, "", usage) } -// Like IntVar, but accepts a shorthand letter that can be used after a single dash. +// IntVarP is like IntVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) IntVarP(p *int, name, shorthand string, value int, usage string) { f.VarP(newIntValue(value, p), name, shorthand, usage) } @@ -42,7 +55,7 @@ func IntVar(p *int, name string, value int, usage string) { CommandLine.VarP(newIntValue(value, p), name, "", usage) } -// Like IntVar, but accepts a shorthand letter that can be used after a single dash. +// IntVarP is like IntVar, but accepts a shorthand letter that can be used after a single dash. func IntVarP(p *int, name, shorthand string, value int, usage string) { CommandLine.VarP(newIntValue(value, p), name, shorthand, usage) } @@ -55,7 +68,7 @@ func (f *FlagSet) Int(name string, value int, usage string) *int { return p } -// Like Int, but accepts a shorthand letter that can be used after a single dash. +// IntP is like Int, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) IntP(name, shorthand string, value int, usage string) *int { p := new(int) f.IntVarP(p, name, shorthand, value, usage) @@ -68,7 +81,7 @@ func Int(name string, value int, usage string) *int { return CommandLine.IntP(name, "", value, usage) } -// Like Int, but accepts a shorthand letter that can be used after a single dash. +// IntP is like Int, but accepts a shorthand letter that can be used after a single dash. func IntP(name, shorthand string, value int, usage string) *int { return CommandLine.IntP(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/int32.go b/Godeps/_workspace/src/github.com/spf13/pflag/int32.go index 18eaacd6..41659a9a 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/int32.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/int32.go @@ -25,13 +25,30 @@ func (i *int32Value) Type() string { func (i *int32Value) String() string { return fmt.Sprintf("%v", *i) } +func int32Conv(sval string) (interface{}, error) { + v, err := strconv.ParseInt(sval, 0, 32) + if err != nil { + return 0, err + } + return int32(v), nil +} + +// GetInt32 return the int32 value of a flag with the given name +func (f *FlagSet) GetInt32(name string) (int32, error) { + val, err := f.getFlagType(name, "int32", int32Conv) + if err != nil { + return 0, err + } + return val.(int32), nil +} + // Int32Var defines an int32 flag with specified name, default value, and usage string. // The argument p points to an int32 variable in which to store the value of the flag. func (f *FlagSet) Int32Var(p *int32, name string, value int32, usage string) { f.VarP(newInt32Value(value, p), name, "", usage) } -// Like Int32Var, but accepts a shorthand letter that can be used after a single dash. +// Int32VarP is like Int32Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Int32VarP(p *int32, name, shorthand string, value int32, usage string) { f.VarP(newInt32Value(value, p), name, shorthand, usage) } @@ -42,7 +59,7 @@ func Int32Var(p *int32, name string, value int32, usage string) { CommandLine.VarP(newInt32Value(value, p), name, "", usage) } -// Like Int32Var, but accepts a shorthand letter that can be used after a single dash. +// Int32VarP is like Int32Var, but accepts a shorthand letter that can be used after a single dash. func Int32VarP(p *int32, name, shorthand string, value int32, usage string) { CommandLine.VarP(newInt32Value(value, p), name, shorthand, usage) } @@ -55,7 +72,7 @@ func (f *FlagSet) Int32(name string, value int32, usage string) *int32 { return p } -// Like Int32, but accepts a shorthand letter that can be used after a single dash. +// Int32P is like Int32, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Int32P(name, shorthand string, value int32, usage string) *int32 { p := new(int32) f.Int32VarP(p, name, shorthand, value, usage) @@ -68,7 +85,7 @@ func Int32(name string, value int32, usage string) *int32 { return CommandLine.Int32P(name, "", value, usage) } -// Like Int32, but accepts a shorthand letter that can be used after a single dash. +// Int32P is like Int32, but accepts a shorthand letter that can be used after a single dash. func Int32P(name, shorthand string, value int32, usage string) *int32 { return CommandLine.Int32P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/int64.go b/Godeps/_workspace/src/github.com/spf13/pflag/int64.go index 0114aaaa..6e67e380 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/int64.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/int64.go @@ -25,13 +25,26 @@ func (i *int64Value) Type() string { func (i *int64Value) String() string { return fmt.Sprintf("%v", *i) } +func int64Conv(sval string) (interface{}, error) { + return strconv.ParseInt(sval, 0, 64) +} + +// GetInt64 return the int64 value of a flag with the given name +func (f *FlagSet) GetInt64(name string) (int64, error) { + val, err := f.getFlagType(name, "int64", int64Conv) + if err != nil { + return 0, err + } + return val.(int64), nil +} + // Int64Var defines an int64 flag with specified name, default value, and usage string. // The argument p points to an int64 variable in which to store the value of the flag. func (f *FlagSet) Int64Var(p *int64, name string, value int64, usage string) { f.VarP(newInt64Value(value, p), name, "", usage) } -// Like Int64Var, but accepts a shorthand letter that can be used after a single dash. +// Int64VarP is like Int64Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Int64VarP(p *int64, name, shorthand string, value int64, usage string) { f.VarP(newInt64Value(value, p), name, shorthand, usage) } @@ -42,7 +55,7 @@ func Int64Var(p *int64, name string, value int64, usage string) { CommandLine.VarP(newInt64Value(value, p), name, "", usage) } -// Like Int64Var, but accepts a shorthand letter that can be used after a single dash. +// Int64VarP is like Int64Var, but accepts a shorthand letter that can be used after a single dash. func Int64VarP(p *int64, name, shorthand string, value int64, usage string) { CommandLine.VarP(newInt64Value(value, p), name, shorthand, usage) } @@ -55,7 +68,7 @@ func (f *FlagSet) Int64(name string, value int64, usage string) *int64 { return p } -// Like Int64, but accepts a shorthand letter that can be used after a single dash. +// Int64P is like Int64, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Int64P(name, shorthand string, value int64, usage string) *int64 { p := new(int64) f.Int64VarP(p, name, shorthand, value, usage) @@ -68,7 +81,7 @@ func Int64(name string, value int64, usage string) *int64 { return CommandLine.Int64P(name, "", value, usage) } -// Like Int64, but accepts a shorthand letter that can be used after a single dash. +// Int64P is like Int64, but accepts a shorthand letter that can be used after a single dash. func Int64P(name, shorthand string, value int64, usage string) *int64 { return CommandLine.Int64P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/int8.go b/Godeps/_workspace/src/github.com/spf13/pflag/int8.go index aab1022f..400db21f 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/int8.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/int8.go @@ -25,13 +25,30 @@ func (i *int8Value) Type() string { func (i *int8Value) String() string { return fmt.Sprintf("%v", *i) } +func int8Conv(sval string) (interface{}, error) { + v, err := strconv.ParseInt(sval, 0, 8) + if err != nil { + return 0, err + } + return int8(v), nil +} + +// GetInt8 return the int8 value of a flag with the given name +func (f *FlagSet) GetInt8(name string) (int8, error) { + val, err := f.getFlagType(name, "int8", int8Conv) + if err != nil { + return 0, err + } + return val.(int8), nil +} + // Int8Var defines an int8 flag with specified name, default value, and usage string. // The argument p points to an int8 variable in which to store the value of the flag. func (f *FlagSet) Int8Var(p *int8, name string, value int8, usage string) { f.VarP(newInt8Value(value, p), name, "", usage) } -// Like Int8Var, but accepts a shorthand letter that can be used after a single dash. +// Int8VarP is like Int8Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Int8VarP(p *int8, name, shorthand string, value int8, usage string) { f.VarP(newInt8Value(value, p), name, shorthand, usage) } @@ -42,7 +59,7 @@ func Int8Var(p *int8, name string, value int8, usage string) { CommandLine.VarP(newInt8Value(value, p), name, "", usage) } -// Like Int8Var, but accepts a shorthand letter that can be used after a single dash. +// Int8VarP is like Int8Var, but accepts a shorthand letter that can be used after a single dash. func Int8VarP(p *int8, name, shorthand string, value int8, usage string) { CommandLine.VarP(newInt8Value(value, p), name, shorthand, usage) } @@ -55,7 +72,7 @@ func (f *FlagSet) Int8(name string, value int8, usage string) *int8 { return p } -// Like Int8, but accepts a shorthand letter that can be used after a single dash. +// Int8P is like Int8, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Int8P(name, shorthand string, value int8, usage string) *int8 { p := new(int8) f.Int8VarP(p, name, shorthand, value, usage) @@ -68,7 +85,7 @@ func Int8(name string, value int8, usage string) *int8 { return CommandLine.Int8P(name, "", value, usage) } -// Like Int8, but accepts a shorthand letter that can be used after a single dash. +// Int8P is like Int8, but accepts a shorthand letter that can be used after a single dash. func Int8P(name, shorthand string, value int8, usage string) *int8 { return CommandLine.Int8P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/int_slice.go b/Godeps/_workspace/src/github.com/spf13/pflag/int_slice.go new file mode 100644 index 00000000..1e7c9edd --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/int_slice.go @@ -0,0 +1,128 @@ +package pflag + +import ( + "fmt" + "strconv" + "strings" +) + +// -- intSlice Value +type intSliceValue struct { + value *[]int + changed bool +} + +func newIntSliceValue(val []int, p *[]int) *intSliceValue { + isv := new(intSliceValue) + isv.value = p + *isv.value = val + return isv +} + +func (s *intSliceValue) Set(val string) error { + ss := strings.Split(val, ",") + out := make([]int, len(ss)) + for i, d := range ss { + var err error + out[i], err = strconv.Atoi(d) + if err != nil { + return err + } + + } + if !s.changed { + *s.value = out + } else { + *s.value = append(*s.value, out...) + } + s.changed = true + return nil +} + +func (s *intSliceValue) Type() string { + return "intSlice" +} + +func (s *intSliceValue) String() string { + out := make([]string, len(*s.value)) + for i, d := range *s.value { + out[i] = fmt.Sprintf("%d", d) + } + return "[" + strings.Join(out, ",") + "]" +} + +func intSliceConv(val string) (interface{}, error) { + val = strings.Trim(val, "[]") + // Empty string would cause a slice with one (empty) entry + if len(val) == 0 { + return []int{}, nil + } + ss := strings.Split(val, ",") + out := make([]int, len(ss)) + for i, d := range ss { + var err error + out[i], err = strconv.Atoi(d) + if err != nil { + return nil, err + } + + } + return out, nil +} + +// GetIntSlice return the []int value of a flag with the given name +func (f *FlagSet) GetIntSlice(name string) ([]int, error) { + val, err := f.getFlagType(name, "intSlice", intSliceConv) + if err != nil { + return []int{}, err + } + return val.([]int), nil +} + +// IntSliceVar defines a intSlice flag with specified name, default value, and usage string. +// The argument p points to a []int variable in which to store the value of the flag. +func (f *FlagSet) IntSliceVar(p *[]int, name string, value []int, usage string) { + f.VarP(newIntSliceValue(value, p), name, "", usage) +} + +// IntSliceVarP is like IntSliceVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IntSliceVarP(p *[]int, name, shorthand string, value []int, usage string) { + f.VarP(newIntSliceValue(value, p), name, shorthand, usage) +} + +// IntSliceVar defines a int[] flag with specified name, default value, and usage string. +// The argument p points to a int[] variable in which to store the value of the flag. +func IntSliceVar(p *[]int, name string, value []int, usage string) { + CommandLine.VarP(newIntSliceValue(value, p), name, "", usage) +} + +// IntSliceVarP is like IntSliceVar, but accepts a shorthand letter that can be used after a single dash. +func IntSliceVarP(p *[]int, name, shorthand string, value []int, usage string) { + CommandLine.VarP(newIntSliceValue(value, p), name, shorthand, usage) +} + +// IntSlice defines a []int flag with specified name, default value, and usage string. +// The return value is the address of a []int variable that stores the value of the flag. +func (f *FlagSet) IntSlice(name string, value []int, usage string) *[]int { + p := []int{} + f.IntSliceVarP(&p, name, "", value, usage) + return &p +} + +// IntSliceP is like IntSlice, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IntSliceP(name, shorthand string, value []int, usage string) *[]int { + p := []int{} + f.IntSliceVarP(&p, name, shorthand, value, usage) + return &p +} + +// IntSlice defines a []int flag with specified name, default value, and usage string. +// The return value is the address of a []int variable that stores the value of the flag. +func IntSlice(name string, value []int, usage string) *[]int { + return CommandLine.IntSliceP(name, "", value, usage) +} + +// IntSliceP is like IntSlice, but accepts a shorthand letter that can be used after a single dash. +func IntSliceP(name, shorthand string, value []int, usage string) *[]int { + return CommandLine.IntSliceP(name, shorthand, value, usage) +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/int_slice_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/int_slice_test.go new file mode 100644 index 00000000..5f2eee66 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/int_slice_test.go @@ -0,0 +1,162 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pflag + +import ( + "fmt" + "strconv" + "strings" + "testing" +) + +func setUpISFlagSet(isp *[]int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IntSliceVar(isp, "is", []int{}, "Command separated list!") + return f +} + +func setUpISFlagSetWithDefault(isp *[]int) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IntSliceVar(isp, "is", []int{0, 1}, "Command separated list!") + return f +} + +func TestEmptyIS(t *testing.T) { + var is []int + f := setUpISFlagSet(&is) + err := f.Parse([]string{}) + if err != nil { + t.Fatal("expected no error; got", err) + } + + getIS, err := f.GetIntSlice("is") + if err != nil { + t.Fatal("got an error from GetIntSlice():", err) + } + if len(getIS) != 0 { + t.Fatalf("got is %v with len=%d but expected length=0", getIS, len(getIS)) + } +} + +func TestIS(t *testing.T) { + var is []int + f := setUpISFlagSet(&is) + + vals := []string{"1", "2", "4", "3"} + arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) + err := f.Parse([]string{arg}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range is { + d, err := strconv.Atoi(vals[i]) + if err != nil { + t.Fatalf("got error: %v", err) + } + if d != v { + t.Fatalf("expected is[%d] to be %s but got: %d", i, vals[i], v) + } + } + getIS, err := f.GetIntSlice("is") + for i, v := range getIS { + d, err := strconv.Atoi(vals[i]) + if err != nil { + t.Fatalf("got error: %v", err) + } + if d != v { + t.Fatalf("expected is[%d] to be %s but got: %d from GetIntSlice", i, vals[i], v) + } + } +} + +func TestISDefault(t *testing.T) { + var is []int + f := setUpISFlagSetWithDefault(&is) + + vals := []string{"0", "1"} + + err := f.Parse([]string{}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range is { + d, err := strconv.Atoi(vals[i]) + if err != nil { + t.Fatalf("got error: %v", err) + } + if d != v { + t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) + } + } + + getIS, err := f.GetIntSlice("is") + if err != nil { + t.Fatal("got an error from GetIntSlice():", err) + } + for i, v := range getIS { + d, err := strconv.Atoi(vals[i]) + if err != nil { + t.Fatal("got an error from GetIntSlice():", err) + } + if d != v { + t.Fatalf("expected is[%d] to be %d from GetIntSlice but got: %d", i, d, v) + } + } +} + +func TestISWithDefault(t *testing.T) { + var is []int + f := setUpISFlagSetWithDefault(&is) + + vals := []string{"1", "2"} + arg := fmt.Sprintf("--is=%s", strings.Join(vals, ",")) + err := f.Parse([]string{arg}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range is { + d, err := strconv.Atoi(vals[i]) + if err != nil { + t.Fatalf("got error: %v", err) + } + if d != v { + t.Fatalf("expected is[%d] to be %d but got: %d", i, d, v) + } + } + + getIS, err := f.GetIntSlice("is") + if err != nil { + t.Fatal("got an error from GetIntSlice():", err) + } + for i, v := range getIS { + d, err := strconv.Atoi(vals[i]) + if err != nil { + t.Fatalf("got error: %v", err) + } + if d != v { + t.Fatalf("expected is[%d] to be %d from GetIntSlice but got: %d", i, d, v) + } + } +} + +func TestISCalledTwice(t *testing.T) { + var is []int + f := setUpISFlagSet(&is) + + in := []string{"1,2", "3"} + expected := []int{1, 2, 3} + argfmt := "--is=%s" + arg1 := fmt.Sprintf(argfmt, in[0]) + arg2 := fmt.Sprintf(argfmt, in[1]) + err := f.Parse([]string{arg1, arg2}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range is { + if expected[i] != v { + t.Fatalf("expected is[%d] to be %d but got: %d", i, expected[i], v) + } + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/ip.go b/Godeps/_workspace/src/github.com/spf13/pflag/ip.go index 9d53bd3e..88a17430 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/ip.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/ip.go @@ -3,8 +3,11 @@ package pflag import ( "fmt" "net" + "strings" ) +var _ = strings.TrimSpace + // -- net.IP value type ipValue net.IP @@ -15,28 +18,42 @@ func newIPValue(val net.IP, p *net.IP) *ipValue { func (i *ipValue) String() string { return net.IP(*i).String() } func (i *ipValue) Set(s string) error { - ip := net.ParseIP(s) + ip := net.ParseIP(strings.TrimSpace(s)) if ip == nil { return fmt.Errorf("failed to parse IP: %q", s) } *i = ipValue(ip) return nil } -func (i *ipValue) Get() interface{} { - return net.IP(*i) -} func (i *ipValue) Type() string { return "ip" } +func ipConv(sval string) (interface{}, error) { + ip := net.ParseIP(sval) + if ip != nil { + return ip, nil + } + return nil, fmt.Errorf("invalid string being converted to IP address: %s", sval) +} + +// GetIP return the net.IP value of a flag with the given name +func (f *FlagSet) GetIP(name string) (net.IP, error) { + val, err := f.getFlagType(name, "ip", ipConv) + if err != nil { + return nil, err + } + return val.(net.IP), nil +} + // IPVar defines an net.IP flag with specified name, default value, and usage string. // The argument p points to an net.IP variable in which to store the value of the flag. func (f *FlagSet) IPVar(p *net.IP, name string, value net.IP, usage string) { f.VarP(newIPValue(value, p), name, "", usage) } -// Like IPVar, but accepts a shorthand letter that can be used after a single dash. +// IPVarP is like IPVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) IPVarP(p *net.IP, name, shorthand string, value net.IP, usage string) { f.VarP(newIPValue(value, p), name, shorthand, usage) } @@ -47,7 +64,7 @@ func IPVar(p *net.IP, name string, value net.IP, usage string) { CommandLine.VarP(newIPValue(value, p), name, "", usage) } -// Like IPVar, but accepts a shorthand letter that can be used after a single dash. +// IPVarP is like IPVar, but accepts a shorthand letter that can be used after a single dash. func IPVarP(p *net.IP, name, shorthand string, value net.IP, usage string) { CommandLine.VarP(newIPValue(value, p), name, shorthand, usage) } @@ -60,7 +77,7 @@ func (f *FlagSet) IP(name string, value net.IP, usage string) *net.IP { return p } -// Like IP, but accepts a shorthand letter that can be used after a single dash. +// IPP is like IP, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) IPP(name, shorthand string, value net.IP, usage string) *net.IP { p := new(net.IP) f.IPVarP(p, name, shorthand, value, usage) @@ -73,7 +90,7 @@ func IP(name string, value net.IP, usage string) *net.IP { return CommandLine.IPP(name, "", value, usage) } -// Like IP, but accepts a shorthand letter that can be used after a single dash. +// IPP is like IP, but accepts a shorthand letter that can be used after a single dash. func IPP(name, shorthand string, value net.IP, usage string) *net.IP { return CommandLine.IPP(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/ip_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/ip_test.go new file mode 100644 index 00000000..1fec50e4 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/ip_test.go @@ -0,0 +1,63 @@ +package pflag + +import ( + "fmt" + "net" + "os" + "testing" +) + +func setUpIP(ip *net.IP) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.IPVar(ip, "address", net.ParseIP("0.0.0.0"), "IP Address") + return f +} + +func TestIP(t *testing.T) { + testCases := []struct { + input string + success bool + expected string + }{ + {"0.0.0.0", true, "0.0.0.0"}, + {" 0.0.0.0 ", true, "0.0.0.0"}, + {"1.2.3.4", true, "1.2.3.4"}, + {"127.0.0.1", true, "127.0.0.1"}, + {"255.255.255.255", true, "255.255.255.255"}, + {"", false, ""}, + {"0", false, ""}, + {"localhost", false, ""}, + {"0.0.0", false, ""}, + {"0.0.0.", false, ""}, + {"0.0.0.0.", false, ""}, + {"0.0.0.256", false, ""}, + {"0 . 0 . 0 . 0", false, ""}, + } + + devnull, _ := os.Open(os.DevNull) + os.Stderr = devnull + for i := range testCases { + var addr net.IP + f := setUpIP(&addr) + + tc := &testCases[i] + + arg := fmt.Sprintf("--address=%s", tc.input) + err := f.Parse([]string{arg}) + if err != nil && tc.success == true { + t.Errorf("expected success, got %q", err) + continue + } else if err == nil && tc.success == false { + t.Errorf("expected failure") + continue + } else if tc.success { + ip, err := f.GetIP("address") + if err != nil { + t.Errorf("Got error trying to fetch the IP flag: %v", err) + } + if ip.String() != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, ip.String()) + } + } + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/ipmask.go b/Godeps/_workspace/src/github.com/spf13/pflag/ipmask.go index 6f85be9b..5bd44bd2 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/ipmask.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/ipmask.go @@ -3,6 +3,7 @@ package pflag import ( "fmt" "net" + "strconv" ) // -- net.IPMask value @@ -22,31 +23,63 @@ func (i *ipMaskValue) Set(s string) error { *i = ipMaskValue(ip) return nil } -func (i *ipMaskValue) Get() interface{} { - return net.IPMask(*i) -} func (i *ipMaskValue) Type() string { return "ipMask" } -// Parse IPv4 netmask written in IP form (e.g. 255.255.255.0). +// ParseIPv4Mask written in IP form (e.g. 255.255.255.0). // This function should really belong to the net package. func ParseIPv4Mask(s string) net.IPMask { mask := net.ParseIP(s) if mask == nil { - return nil + if len(s) != 8 { + return nil + } + // net.IPMask.String() actually outputs things like ffffff00 + // so write a horrible parser for that as well :-( + m := []int{} + for i := 0; i < 4; i++ { + b := "0x" + s[2*i:2*i+2] + d, err := strconv.ParseInt(b, 0, 0) + if err != nil { + return nil + } + m = append(m, int(d)) + } + s := fmt.Sprintf("%d.%d.%d.%d", m[0], m[1], m[2], m[3]) + mask = net.ParseIP(s) + if mask == nil { + return nil + } } return net.IPv4Mask(mask[12], mask[13], mask[14], mask[15]) } +func parseIPv4Mask(sval string) (interface{}, error) { + mask := ParseIPv4Mask(sval) + if mask == nil { + return nil, fmt.Errorf("unable to parse %s as net.IPMask", sval) + } + return mask, nil +} + +// GetIPv4Mask return the net.IPv4Mask value of a flag with the given name +func (f *FlagSet) GetIPv4Mask(name string) (net.IPMask, error) { + val, err := f.getFlagType(name, "ipMask", parseIPv4Mask) + if err != nil { + return nil, err + } + return val.(net.IPMask), nil +} + // IPMaskVar defines an net.IPMask flag with specified name, default value, and usage string. // The argument p points to an net.IPMask variable in which to store the value of the flag. func (f *FlagSet) IPMaskVar(p *net.IPMask, name string, value net.IPMask, usage string) { f.VarP(newIPMaskValue(value, p), name, "", usage) } -// Like IPMaskVar, but accepts a shorthand letter that can be used after a single dash. +// IPMaskVarP is like IPMaskVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) IPMaskVarP(p *net.IPMask, name, shorthand string, value net.IPMask, usage string) { f.VarP(newIPMaskValue(value, p), name, shorthand, usage) } @@ -57,7 +90,7 @@ func IPMaskVar(p *net.IPMask, name string, value net.IPMask, usage string) { CommandLine.VarP(newIPMaskValue(value, p), name, "", usage) } -// Like IPMaskVar, but accepts a shorthand letter that can be used after a single dash. +// IPMaskVarP is like IPMaskVar, but accepts a shorthand letter that can be used after a single dash. func IPMaskVarP(p *net.IPMask, name, shorthand string, value net.IPMask, usage string) { CommandLine.VarP(newIPMaskValue(value, p), name, shorthand, usage) } @@ -70,7 +103,7 @@ func (f *FlagSet) IPMask(name string, value net.IPMask, usage string) *net.IPMas return p } -// Like IPMask, but accepts a shorthand letter that can be used after a single dash. +// IPMaskP is like IPMask, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) IPMaskP(name, shorthand string, value net.IPMask, usage string) *net.IPMask { p := new(net.IPMask) f.IPMaskVarP(p, name, shorthand, value, usage) @@ -83,7 +116,7 @@ func IPMask(name string, value net.IPMask, usage string) *net.IPMask { return CommandLine.IPMaskP(name, "", value, usage) } -// Like IP, but accepts a shorthand letter that can be used after a single dash. +// IPMaskP is like IP, but accepts a shorthand letter that can be used after a single dash. func IPMaskP(name, shorthand string, value net.IPMask, usage string) *net.IPMask { return CommandLine.IPMaskP(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/ipnet.go b/Godeps/_workspace/src/github.com/spf13/pflag/ipnet.go new file mode 100644 index 00000000..149b764b --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/ipnet.go @@ -0,0 +1,100 @@ +package pflag + +import ( + "fmt" + "net" + "strings" +) + +// IPNet adapts net.IPNet for use as a flag. +type ipNetValue net.IPNet + +func (ipnet ipNetValue) String() string { + n := net.IPNet(ipnet) + return n.String() +} + +func (ipnet *ipNetValue) Set(value string) error { + _, n, err := net.ParseCIDR(strings.TrimSpace(value)) + if err != nil { + return err + } + *ipnet = ipNetValue(*n) + return nil +} + +func (*ipNetValue) Type() string { + return "ipNet" +} + +var _ = strings.TrimSpace + +func newIPNetValue(val net.IPNet, p *net.IPNet) *ipNetValue { + *p = val + return (*ipNetValue)(p) +} + +func ipNetConv(sval string) (interface{}, error) { + _, n, err := net.ParseCIDR(strings.TrimSpace(sval)) + if err == nil { + return *n, nil + } + return nil, fmt.Errorf("invalid string being converted to IPNet: %s", sval) +} + +// GetIPNet return the net.IPNet value of a flag with the given name +func (f *FlagSet) GetIPNet(name string) (net.IPNet, error) { + val, err := f.getFlagType(name, "ipNet", ipNetConv) + if err != nil { + return net.IPNet{}, err + } + return val.(net.IPNet), nil +} + +// IPNetVar defines an net.IPNet flag with specified name, default value, and usage string. +// The argument p points to an net.IPNet variable in which to store the value of the flag. +func (f *FlagSet) IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) { + f.VarP(newIPNetValue(value, p), name, "", usage) +} + +// IPNetVarP is like IPNetVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) { + f.VarP(newIPNetValue(value, p), name, shorthand, usage) +} + +// IPNetVar defines an net.IPNet flag with specified name, default value, and usage string. +// The argument p points to an net.IPNet variable in which to store the value of the flag. +func IPNetVar(p *net.IPNet, name string, value net.IPNet, usage string) { + CommandLine.VarP(newIPNetValue(value, p), name, "", usage) +} + +// IPNetVarP is like IPNetVar, but accepts a shorthand letter that can be used after a single dash. +func IPNetVarP(p *net.IPNet, name, shorthand string, value net.IPNet, usage string) { + CommandLine.VarP(newIPNetValue(value, p), name, shorthand, usage) +} + +// IPNet defines an net.IPNet flag with specified name, default value, and usage string. +// The return value is the address of an net.IPNet variable that stores the value of the flag. +func (f *FlagSet) IPNet(name string, value net.IPNet, usage string) *net.IPNet { + p := new(net.IPNet) + f.IPNetVarP(p, name, "", value, usage) + return p +} + +// IPNetP is like IPNet, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet { + p := new(net.IPNet) + f.IPNetVarP(p, name, shorthand, value, usage) + return p +} + +// IPNet defines an net.IPNet flag with specified name, default value, and usage string. +// The return value is the address of an net.IPNet variable that stores the value of the flag. +func IPNet(name string, value net.IPNet, usage string) *net.IPNet { + return CommandLine.IPNetP(name, "", value, usage) +} + +// IPNetP is like IPNet, but accepts a shorthand letter that can be used after a single dash. +func IPNetP(name, shorthand string, value net.IPNet, usage string) *net.IPNet { + return CommandLine.IPNetP(name, shorthand, value, usage) +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/ipnet_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/ipnet_test.go new file mode 100644 index 00000000..335b6fa1 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/ipnet_test.go @@ -0,0 +1,70 @@ +package pflag + +import ( + "fmt" + "net" + "os" + "testing" +) + +func setUpIPNet(ip *net.IPNet) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + _, def, _ := net.ParseCIDR("0.0.0.0/0") + f.IPNetVar(ip, "address", *def, "IP Address") + return f +} + +func TestIPNet(t *testing.T) { + testCases := []struct { + input string + success bool + expected string + }{ + {"0.0.0.0/0", true, "0.0.0.0/0"}, + {" 0.0.0.0/0 ", true, "0.0.0.0/0"}, + {"1.2.3.4/8", true, "1.0.0.0/8"}, + {"127.0.0.1/16", true, "127.0.0.0/16"}, + {"255.255.255.255/19", true, "255.255.224.0/19"}, + {"255.255.255.255/32", true, "255.255.255.255/32"}, + {"", false, ""}, + {"/0", false, ""}, + {"0", false, ""}, + {"0/0", false, ""}, + {"localhost/0", false, ""}, + {"0.0.0/4", false, ""}, + {"0.0.0./8", false, ""}, + {"0.0.0.0./12", false, ""}, + {"0.0.0.256/16", false, ""}, + {"0.0.0.0 /20", false, ""}, + {"0.0.0.0/ 24", false, ""}, + {"0 . 0 . 0 . 0 / 28", false, ""}, + {"0.0.0.0/33", false, ""}, + } + + devnull, _ := os.Open(os.DevNull) + os.Stderr = devnull + for i := range testCases { + var addr net.IPNet + f := setUpIPNet(&addr) + + tc := &testCases[i] + + arg := fmt.Sprintf("--address=%s", tc.input) + err := f.Parse([]string{arg}) + if err != nil && tc.success == true { + t.Errorf("expected success, got %q", err) + continue + } else if err == nil && tc.success == false { + t.Errorf("expected failure") + continue + } else if tc.success { + ip, err := f.GetIPNet("address") + if err != nil { + t.Errorf("Got error trying to fetch the IP flag: %v", err) + } + if ip.String() != tc.expected { + t.Errorf("expected %q, got %q", tc.expected, ip.String()) + } + } + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/string.go b/Godeps/_workspace/src/github.com/spf13/pflag/string.go index 362fbf8a..e296136e 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/string.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/string.go @@ -20,13 +20,26 @@ func (s *stringValue) Type() string { func (s *stringValue) String() string { return fmt.Sprintf("%s", *s) } +func stringConv(sval string) (interface{}, error) { + return sval, nil +} + +// GetString return the string value of a flag with the given name +func (f *FlagSet) GetString(name string) (string, error) { + val, err := f.getFlagType(name, "string", stringConv) + if err != nil { + return "", err + } + return val.(string), nil +} + // StringVar defines a string flag with specified name, default value, and usage string. // The argument p points to a string variable in which to store the value of the flag. func (f *FlagSet) StringVar(p *string, name string, value string, usage string) { f.VarP(newStringValue(value, p), name, "", usage) } -// Like StringVar, but accepts a shorthand letter that can be used after a single dash. +// StringVarP is like StringVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) StringVarP(p *string, name, shorthand string, value string, usage string) { f.VarP(newStringValue(value, p), name, shorthand, usage) } @@ -37,7 +50,7 @@ func StringVar(p *string, name string, value string, usage string) { CommandLine.VarP(newStringValue(value, p), name, "", usage) } -// Like StringVar, but accepts a shorthand letter that can be used after a single dash. +// StringVarP is like StringVar, but accepts a shorthand letter that can be used after a single dash. func StringVarP(p *string, name, shorthand string, value string, usage string) { CommandLine.VarP(newStringValue(value, p), name, shorthand, usage) } @@ -50,7 +63,7 @@ func (f *FlagSet) String(name string, value string, usage string) *string { return p } -// Like String, but accepts a shorthand letter that can be used after a single dash. +// StringP is like String, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) StringP(name, shorthand string, value string, usage string) *string { p := new(string) f.StringVarP(p, name, shorthand, value, usage) @@ -63,7 +76,7 @@ func String(name string, value string, usage string) *string { return CommandLine.StringP(name, "", value, usage) } -// Like String, but accepts a shorthand letter that can be used after a single dash. +// StringP is like String, but accepts a shorthand letter that can be used after a single dash. func StringP(name, shorthand string, value string, usage string) *string { return CommandLine.StringP(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/string_slice.go b/Godeps/_workspace/src/github.com/spf13/pflag/string_slice.go new file mode 100644 index 00000000..b53648b2 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/string_slice.go @@ -0,0 +1,111 @@ +package pflag + +import ( + "encoding/csv" + "fmt" + "strings" +) + +var _ = fmt.Fprint + +// -- stringSlice Value +type stringSliceValue struct { + value *[]string + changed bool +} + +func newStringSliceValue(val []string, p *[]string) *stringSliceValue { + ssv := new(stringSliceValue) + ssv.value = p + *ssv.value = val + return ssv +} + +func (s *stringSliceValue) Set(val string) error { + stringReader := strings.NewReader(val) + csvReader := csv.NewReader(stringReader) + v, err := csvReader.Read() + if err != nil { + return err + } + if !s.changed { + *s.value = v + } else { + *s.value = append(*s.value, v...) + } + s.changed = true + return nil +} + +func (s *stringSliceValue) Type() string { + return "stringSlice" +} + +func (s *stringSliceValue) String() string { return "[" + strings.Join(*s.value, ",") + "]" } + +func stringSliceConv(sval string) (interface{}, error) { + sval = strings.Trim(sval, "[]") + // An empty string would cause a slice with one (empty) string + if len(sval) == 0 { + return []string{}, nil + } + v := strings.Split(sval, ",") + return v, nil +} + +// GetStringSlice return the []string value of a flag with the given name +func (f *FlagSet) GetStringSlice(name string) ([]string, error) { + val, err := f.getFlagType(name, "stringSlice", stringSliceConv) + if err != nil { + return []string{}, err + } + return val.([]string), nil +} + +// StringSliceVar defines a string flag with specified name, default value, and usage string. +// The argument p points to a []string variable in which to store the value of the flag. +func (f *FlagSet) StringSliceVar(p *[]string, name string, value []string, usage string) { + f.VarP(newStringSliceValue(value, p), name, "", usage) +} + +// StringSliceVarP is like StringSliceVar, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) StringSliceVarP(p *[]string, name, shorthand string, value []string, usage string) { + f.VarP(newStringSliceValue(value, p), name, shorthand, usage) +} + +// StringSliceVar defines a string flag with specified name, default value, and usage string. +// The argument p points to a []string variable in which to store the value of the flag. +func StringSliceVar(p *[]string, name string, value []string, usage string) { + CommandLine.VarP(newStringSliceValue(value, p), name, "", usage) +} + +// StringSliceVarP is like StringSliceVar, but accepts a shorthand letter that can be used after a single dash. +func StringSliceVarP(p *[]string, name, shorthand string, value []string, usage string) { + CommandLine.VarP(newStringSliceValue(value, p), name, shorthand, usage) +} + +// StringSlice defines a string flag with specified name, default value, and usage string. +// The return value is the address of a []string variable that stores the value of the flag. +func (f *FlagSet) StringSlice(name string, value []string, usage string) *[]string { + p := []string{} + f.StringSliceVarP(&p, name, "", value, usage) + return &p +} + +// StringSliceP is like StringSlice, but accepts a shorthand letter that can be used after a single dash. +func (f *FlagSet) StringSliceP(name, shorthand string, value []string, usage string) *[]string { + p := []string{} + f.StringSliceVarP(&p, name, shorthand, value, usage) + return &p +} + +// StringSlice defines a string flag with specified name, default value, and usage string. +// The return value is the address of a []string variable that stores the value of the flag. +func StringSlice(name string, value []string, usage string) *[]string { + return CommandLine.StringSliceP(name, "", value, usage) +} + +// StringSliceP is like StringSlice, but accepts a shorthand letter that can be used after a single dash. +func StringSliceP(name, shorthand string, value []string, usage string) *[]string { + return CommandLine.StringSliceP(name, shorthand, value, usage) +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/string_slice_test.go b/Godeps/_workspace/src/github.com/spf13/pflag/string_slice_test.go new file mode 100644 index 00000000..c7fdc70b --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/string_slice_test.go @@ -0,0 +1,161 @@ +// Copyright 2009 The Go Authors. All rights reserved. +// Use of this source code is governed by a BSD-style +// license that can be found in the LICENSE file. + +package pflag + +import ( + "fmt" + "strings" + "testing" +) + +func setUpSSFlagSet(ssp *[]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringSliceVar(ssp, "ss", []string{}, "Command separated list!") + return f +} + +func setUpSSFlagSetWithDefault(ssp *[]string) *FlagSet { + f := NewFlagSet("test", ContinueOnError) + f.StringSliceVar(ssp, "ss", []string{"default", "values"}, "Command separated list!") + return f +} + +func TestEmptySS(t *testing.T) { + var ss []string + f := setUpSSFlagSet(&ss) + err := f.Parse([]string{}) + if err != nil { + t.Fatal("expected no error; got", err) + } + + getSS, err := f.GetStringSlice("ss") + if err != nil { + t.Fatal("got an error from GetStringSlice():", err) + } + if len(getSS) != 0 { + t.Fatalf("got ss %v with len=%d but expected length=0", getSS, len(getSS)) + } +} + +func TestSS(t *testing.T) { + var ss []string + f := setUpSSFlagSet(&ss) + + vals := []string{"one", "two", "4", "3"} + arg := fmt.Sprintf("--ss=%s", strings.Join(vals, ",")) + err := f.Parse([]string{arg}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range ss { + if vals[i] != v { + t.Fatalf("expected ss[%d] to be %s but got: %s", i, vals[i], v) + } + } + + getSS, err := f.GetStringSlice("ss") + if err != nil { + t.Fatal("got an error from GetStringSlice():", err) + } + for i, v := range getSS { + if vals[i] != v { + t.Fatalf("expected ss[%d] to be %s from GetStringSlice but got: %s", i, vals[i], v) + } + } +} + +func TestSSDefault(t *testing.T) { + var ss []string + f := setUpSSFlagSetWithDefault(&ss) + + vals := []string{"default", "values"} + + err := f.Parse([]string{}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range ss { + if vals[i] != v { + t.Fatalf("expected ss[%d] to be %s but got: %s", i, vals[i], v) + } + } + + getSS, err := f.GetStringSlice("ss") + if err != nil { + t.Fatal("got an error from GetStringSlice():", err) + } + for i, v := range getSS { + if vals[i] != v { + t.Fatalf("expected ss[%d] to be %s from GetStringSlice but got: %s", i, vals[i], v) + } + } +} + +func TestSSWithDefault(t *testing.T) { + var ss []string + f := setUpSSFlagSetWithDefault(&ss) + + vals := []string{"one", "two", "4", "3"} + arg := fmt.Sprintf("--ss=%s", strings.Join(vals, ",")) + err := f.Parse([]string{arg}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range ss { + if vals[i] != v { + t.Fatalf("expected ss[%d] to be %s but got: %s", i, vals[i], v) + } + } + + getSS, err := f.GetStringSlice("ss") + if err != nil { + t.Fatal("got an error from GetStringSlice():", err) + } + for i, v := range getSS { + if vals[i] != v { + t.Fatalf("expected ss[%d] to be %s from GetStringSlice but got: %s", i, vals[i], v) + } + } +} + +func TestSSCalledTwice(t *testing.T) { + var ss []string + f := setUpSSFlagSet(&ss) + + in := []string{"one,two", "three"} + expected := []string{"one", "two", "three"} + argfmt := "--ss=%s" + arg1 := fmt.Sprintf(argfmt, in[0]) + arg2 := fmt.Sprintf(argfmt, in[1]) + err := f.Parse([]string{arg1, arg2}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range ss { + if expected[i] != v { + t.Fatalf("expected ss[%d] to be %s but got: %s", i, expected[i], v) + } + } +} + +func TestSSWithComma(t *testing.T) { + var ss []string + f := setUpSSFlagSet(&ss) + + in := []string{`"one,two"`, `"three"`} + expected := []string{"one,two", "three"} + argfmt := "--ss=%s" + arg1 := fmt.Sprintf(argfmt, in[0]) + arg2 := fmt.Sprintf(argfmt, in[1]) + err := f.Parse([]string{arg1, arg2}) + if err != nil { + t.Fatal("expected no error; got", err) + } + for i, v := range ss { + if expected[i] != v { + t.Fatalf("expected ss[%d] to be %s but got: %s", i, expected[i], v) + } + } +} diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/uint.go b/Godeps/_workspace/src/github.com/spf13/pflag/uint.go index c063fe7c..e142b499 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/uint.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/uint.go @@ -25,13 +25,30 @@ func (i *uintValue) Type() string { func (i *uintValue) String() string { return fmt.Sprintf("%v", *i) } +func uintConv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 0) + if err != nil { + return 0, err + } + return uint(v), nil +} + +// GetUint return the uint value of a flag with the given name +func (f *FlagSet) GetUint(name string) (uint, error) { + val, err := f.getFlagType(name, "uint", uintConv) + if err != nil { + return 0, err + } + return val.(uint), nil +} + // UintVar defines a uint flag with specified name, default value, and usage string. // The argument p points to a uint variable in which to store the value of the flag. func (f *FlagSet) UintVar(p *uint, name string, value uint, usage string) { f.VarP(newUintValue(value, p), name, "", usage) } -// Like UintVar, but accepts a shorthand letter that can be used after a single dash. +// UintVarP is like UintVar, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) UintVarP(p *uint, name, shorthand string, value uint, usage string) { f.VarP(newUintValue(value, p), name, shorthand, usage) } @@ -42,7 +59,7 @@ func UintVar(p *uint, name string, value uint, usage string) { CommandLine.VarP(newUintValue(value, p), name, "", usage) } -// Like UintVar, but accepts a shorthand letter that can be used after a single dash. +// UintVarP is like UintVar, but accepts a shorthand letter that can be used after a single dash. func UintVarP(p *uint, name, shorthand string, value uint, usage string) { CommandLine.VarP(newUintValue(value, p), name, shorthand, usage) } @@ -55,7 +72,7 @@ func (f *FlagSet) Uint(name string, value uint, usage string) *uint { return p } -// Like Uint, but accepts a shorthand letter that can be used after a single dash. +// UintP is like Uint, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) UintP(name, shorthand string, value uint, usage string) *uint { p := new(uint) f.UintVarP(p, name, shorthand, value, usage) @@ -68,7 +85,7 @@ func Uint(name string, value uint, usage string) *uint { return CommandLine.UintP(name, "", value, usage) } -// Like Uint, but accepts a shorthand letter that can be used after a single dash. +// UintP is like Uint, but accepts a shorthand letter that can be used after a single dash. func UintP(name, shorthand string, value uint, usage string) *uint { return CommandLine.UintP(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/uint16.go b/Godeps/_workspace/src/github.com/spf13/pflag/uint16.go index ec14ab0c..5c96c19d 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/uint16.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/uint16.go @@ -19,21 +19,34 @@ func (i *uint16Value) Set(s string) error { return err } -func (i *uint16Value) Get() interface{} { - return uint16(*i) -} - func (i *uint16Value) Type() string { return "uint16" } +func uint16Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 16) + if err != nil { + return 0, err + } + return uint16(v), nil +} + +// GetUint16 return the uint16 value of a flag with the given name +func (f *FlagSet) GetUint16(name string) (uint16, error) { + val, err := f.getFlagType(name, "uint16", uint16Conv) + if err != nil { + return 0, err + } + return val.(uint16), nil +} + // Uint16Var defines a uint flag with specified name, default value, and usage string. // The argument p points to a uint variable in which to store the value of the flag. func (f *FlagSet) Uint16Var(p *uint16, name string, value uint16, usage string) { f.VarP(newUint16Value(value, p), name, "", usage) } -// Like Uint16Var, but accepts a shorthand letter that can be used after a single dash. +// Uint16VarP is like Uint16Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint16VarP(p *uint16, name, shorthand string, value uint16, usage string) { f.VarP(newUint16Value(value, p), name, shorthand, usage) } @@ -44,7 +57,7 @@ func Uint16Var(p *uint16, name string, value uint16, usage string) { CommandLine.VarP(newUint16Value(value, p), name, "", usage) } -// Like Uint16Var, but accepts a shorthand letter that can be used after a single dash. +// Uint16VarP is like Uint16Var, but accepts a shorthand letter that can be used after a single dash. func Uint16VarP(p *uint16, name, shorthand string, value uint16, usage string) { CommandLine.VarP(newUint16Value(value, p), name, shorthand, usage) } @@ -57,7 +70,7 @@ func (f *FlagSet) Uint16(name string, value uint16, usage string) *uint16 { return p } -// Like Uint16, but accepts a shorthand letter that can be used after a single dash. +// Uint16P is like Uint16, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint16P(name, shorthand string, value uint16, usage string) *uint16 { p := new(uint16) f.Uint16VarP(p, name, shorthand, value, usage) @@ -70,7 +83,7 @@ func Uint16(name string, value uint16, usage string) *uint16 { return CommandLine.Uint16P(name, "", value, usage) } -// Like Uint16, but accepts a shorthand letter that can be used after a single dash. +// Uint16P is like Uint16, but accepts a shorthand letter that can be used after a single dash. func Uint16P(name, shorthand string, value uint16, usage string) *uint16 { return CommandLine.Uint16P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/uint32.go b/Godeps/_workspace/src/github.com/spf13/pflag/uint32.go index 05bc3bd0..294fcaa3 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/uint32.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/uint32.go @@ -18,21 +18,35 @@ func (i *uint32Value) Set(s string) error { *i = uint32Value(v) return err } -func (i *uint32Value) Get() interface{} { - return uint32(*i) -} func (i *uint32Value) Type() string { return "uint32" } +func uint32Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 32) + if err != nil { + return 0, err + } + return uint32(v), nil +} + +// GetUint32 return the uint32 value of a flag with the given name +func (f *FlagSet) GetUint32(name string) (uint32, error) { + val, err := f.getFlagType(name, "uint32", uint32Conv) + if err != nil { + return 0, err + } + return val.(uint32), nil +} + // Uint32Var defines a uint32 flag with specified name, default value, and usage string. // The argument p points to a uint32 variable in which to store the value of the flag. func (f *FlagSet) Uint32Var(p *uint32, name string, value uint32, usage string) { f.VarP(newUint32Value(value, p), name, "", usage) } -// Like Uint32Var, but accepts a shorthand letter that can be used after a single dash. +// Uint32VarP is like Uint32Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint32VarP(p *uint32, name, shorthand string, value uint32, usage string) { f.VarP(newUint32Value(value, p), name, shorthand, usage) } @@ -43,7 +57,7 @@ func Uint32Var(p *uint32, name string, value uint32, usage string) { CommandLine.VarP(newUint32Value(value, p), name, "", usage) } -// Like Uint32Var, but accepts a shorthand letter that can be used after a single dash. +// Uint32VarP is like Uint32Var, but accepts a shorthand letter that can be used after a single dash. func Uint32VarP(p *uint32, name, shorthand string, value uint32, usage string) { CommandLine.VarP(newUint32Value(value, p), name, shorthand, usage) } @@ -56,7 +70,7 @@ func (f *FlagSet) Uint32(name string, value uint32, usage string) *uint32 { return p } -// Like Uint32, but accepts a shorthand letter that can be used after a single dash. +// Uint32P is like Uint32, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint32P(name, shorthand string, value uint32, usage string) *uint32 { p := new(uint32) f.Uint32VarP(p, name, shorthand, value, usage) @@ -69,7 +83,7 @@ func Uint32(name string, value uint32, usage string) *uint32 { return CommandLine.Uint32P(name, "", value, usage) } -// Like Uint32, but accepts a shorthand letter that can be used after a single dash. +// Uint32P is like Uint32, but accepts a shorthand letter that can be used after a single dash. func Uint32P(name, shorthand string, value uint32, usage string) *uint32 { return CommandLine.Uint32P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/uint64.go b/Godeps/_workspace/src/github.com/spf13/pflag/uint64.go index 99c7e805..c6818850 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/uint64.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/uint64.go @@ -25,13 +25,30 @@ func (i *uint64Value) Type() string { func (i *uint64Value) String() string { return fmt.Sprintf("%v", *i) } +func uint64Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 64) + if err != nil { + return 0, err + } + return uint64(v), nil +} + +// GetUint64 return the uint64 value of a flag with the given name +func (f *FlagSet) GetUint64(name string) (uint64, error) { + val, err := f.getFlagType(name, "uint64", uint64Conv) + if err != nil { + return 0, err + } + return val.(uint64), nil +} + // Uint64Var defines a uint64 flag with specified name, default value, and usage string. // The argument p points to a uint64 variable in which to store the value of the flag. func (f *FlagSet) Uint64Var(p *uint64, name string, value uint64, usage string) { f.VarP(newUint64Value(value, p), name, "", usage) } -// Like Uint64Var, but accepts a shorthand letter that can be used after a single dash. +// Uint64VarP is like Uint64Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint64VarP(p *uint64, name, shorthand string, value uint64, usage string) { f.VarP(newUint64Value(value, p), name, shorthand, usage) } @@ -42,7 +59,7 @@ func Uint64Var(p *uint64, name string, value uint64, usage string) { CommandLine.VarP(newUint64Value(value, p), name, "", usage) } -// Like Uint64Var, but accepts a shorthand letter that can be used after a single dash. +// Uint64VarP is like Uint64Var, but accepts a shorthand letter that can be used after a single dash. func Uint64VarP(p *uint64, name, shorthand string, value uint64, usage string) { CommandLine.VarP(newUint64Value(value, p), name, shorthand, usage) } @@ -55,7 +72,7 @@ func (f *FlagSet) Uint64(name string, value uint64, usage string) *uint64 { return p } -// Like Uint64, but accepts a shorthand letter that can be used after a single dash. +// Uint64P is like Uint64, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint64P(name, shorthand string, value uint64, usage string) *uint64 { p := new(uint64) f.Uint64VarP(p, name, shorthand, value, usage) @@ -68,7 +85,7 @@ func Uint64(name string, value uint64, usage string) *uint64 { return CommandLine.Uint64P(name, "", value, usage) } -// Like Uint64, but accepts a shorthand letter that can be used after a single dash. +// Uint64P is like Uint64, but accepts a shorthand letter that can be used after a single dash. func Uint64P(name, shorthand string, value uint64, usage string) *uint64 { return CommandLine.Uint64P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/uint8.go b/Godeps/_workspace/src/github.com/spf13/pflag/uint8.go index 6fef508d..26db418a 100644 --- a/Godeps/_workspace/src/github.com/spf13/pflag/uint8.go +++ b/Godeps/_workspace/src/github.com/spf13/pflag/uint8.go @@ -25,13 +25,30 @@ func (i *uint8Value) Type() string { func (i *uint8Value) String() string { return fmt.Sprintf("%v", *i) } +func uint8Conv(sval string) (interface{}, error) { + v, err := strconv.ParseUint(sval, 0, 8) + if err != nil { + return 0, err + } + return uint8(v), nil +} + +// GetUint8 return the uint8 value of a flag with the given name +func (f *FlagSet) GetUint8(name string) (uint8, error) { + val, err := f.getFlagType(name, "uint8", uint8Conv) + if err != nil { + return 0, err + } + return val.(uint8), nil +} + // Uint8Var defines a uint8 flag with specified name, default value, and usage string. // The argument p points to a uint8 variable in which to store the value of the flag. func (f *FlagSet) Uint8Var(p *uint8, name string, value uint8, usage string) { f.VarP(newUint8Value(value, p), name, "", usage) } -// Like Uint8Var, but accepts a shorthand letter that can be used after a single dash. +// Uint8VarP is like Uint8Var, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint8VarP(p *uint8, name, shorthand string, value uint8, usage string) { f.VarP(newUint8Value(value, p), name, shorthand, usage) } @@ -42,7 +59,7 @@ func Uint8Var(p *uint8, name string, value uint8, usage string) { CommandLine.VarP(newUint8Value(value, p), name, "", usage) } -// Like Uint8Var, but accepts a shorthand letter that can be used after a single dash. +// Uint8VarP is like Uint8Var, but accepts a shorthand letter that can be used after a single dash. func Uint8VarP(p *uint8, name, shorthand string, value uint8, usage string) { CommandLine.VarP(newUint8Value(value, p), name, shorthand, usage) } @@ -55,7 +72,7 @@ func (f *FlagSet) Uint8(name string, value uint8, usage string) *uint8 { return p } -// Like Uint8, but accepts a shorthand letter that can be used after a single dash. +// Uint8P is like Uint8, but accepts a shorthand letter that can be used after a single dash. func (f *FlagSet) Uint8P(name, shorthand string, value uint8, usage string) *uint8 { p := new(uint8) f.Uint8VarP(p, name, shorthand, value, usage) @@ -68,7 +85,7 @@ func Uint8(name string, value uint8, usage string) *uint8 { return CommandLine.Uint8P(name, "", value, usage) } -// Like Uint8, but accepts a shorthand letter that can be used after a single dash. +// Uint8P is like Uint8, but accepts a shorthand letter that can be used after a single dash. func Uint8P(name, shorthand string, value uint8, usage string) *uint8 { return CommandLine.Uint8P(name, shorthand, value, usage) } diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/verify/all.sh b/Godeps/_workspace/src/github.com/spf13/pflag/verify/all.sh new file mode 100644 index 00000000..739f89c0 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/verify/all.sh @@ -0,0 +1,69 @@ +#!/bin/bash + +set -o errexit +set -o nounset +set -o pipefail + +ROOT=$(dirname "${BASH_SOURCE}")/.. + +# Some useful colors. +if [[ -z "${color_start-}" ]]; then + declare -r color_start="\033[" + declare -r color_red="${color_start}0;31m" + declare -r color_yellow="${color_start}0;33m" + declare -r color_green="${color_start}0;32m" + declare -r color_norm="${color_start}0m" +fi + +SILENT=true + +function is-excluded { + for e in $EXCLUDE; do + if [[ $1 -ef ${BASH_SOURCE} ]]; then + return + fi + if [[ $1 -ef "$ROOT/hack/$e" ]]; then + return + fi + done + return 1 +} + +while getopts ":v" opt; do + case $opt in + v) + SILENT=false + ;; + \?) + echo "Invalid flag: -$OPTARG" >&2 + exit 1 + ;; + esac +done + +if $SILENT ; then + echo "Running in the silent mode, run with -v if you want to see script logs." +fi + +EXCLUDE="all.sh" + +ret=0 +for t in `ls $ROOT/verify/*.sh` +do + if is-excluded $t ; then + echo "Skipping $t" + continue + fi + if $SILENT ; then + echo -e "Verifying $t" + if bash "$t" &> /dev/null; then + echo -e "${color_green}SUCCESS${color_norm}" + else + echo -e "${color_red}FAILED${color_norm}" + ret=1 + fi + else + bash "$t" || ret=1 + fi +done +exit $ret diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/verify/gofmt.sh b/Godeps/_workspace/src/github.com/spf13/pflag/verify/gofmt.sh new file mode 100644 index 00000000..f66acf80 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/verify/gofmt.sh @@ -0,0 +1,19 @@ +#!/bin/bash + +set -o errexit +set -o nounset +set -o pipefail + +ROOT=$(dirname "${BASH_SOURCE}")/.. + +pushd "${ROOT}" > /dev/null + +GOFMT=${GOFMT:-"gofmt"} +bad_files=$(find . -name '*.go' | xargs $GOFMT -s -l) +if [[ -n "${bad_files}" ]]; then + echo "!!! '$GOFMT' needs to be run on the following files: " + echo "${bad_files}" + exit 1 +fi + +# ex: ts=2 sw=2 et filetype=sh diff --git a/Godeps/_workspace/src/github.com/spf13/pflag/verify/golint.sh b/Godeps/_workspace/src/github.com/spf13/pflag/verify/golint.sh new file mode 100644 index 00000000..685c1778 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/pflag/verify/golint.sh @@ -0,0 +1,15 @@ +#!/bin/bash + +ROOT=$(dirname "${BASH_SOURCE}")/.. +GOLINT=${GOLINT:-"golint"} + +pushd "${ROOT}" > /dev/null + bad_files=$($GOLINT -min_confidence=0.9 ./...) + if [[ -n "${bad_files}" ]]; then + echo "!!! '$GOLINT' problems: " + echo "${bad_files}" + exit 1 + fi +popd > /dev/null + +# ex: ts=2 sw=2 et filetype=sh diff --git a/Godeps/_workspace/src/github.com/spf13/viper/.travis.yml b/Godeps/_workspace/src/github.com/spf13/viper/.travis.yml index a578cbf6..9e6d6211 100644 --- a/Godeps/_workspace/src/github.com/spf13/viper/.travis.yml +++ b/Godeps/_workspace/src/github.com/spf13/viper/.travis.yml @@ -1,9 +1,9 @@ language: go go: - - 1.2 - 1.3 - release - tip script: - go test -v ./... +sudo: false diff --git a/Godeps/_workspace/src/github.com/spf13/viper/README.md b/Godeps/_workspace/src/github.com/spf13/viper/README.md index 16d017ba..8fdcd379 100644 --- a/Godeps/_workspace/src/github.com/spf13/viper/README.md +++ b/Godeps/_workspace/src/github.com/spf13/viper/README.md @@ -1,47 +1,48 @@ -viper [![Build Status](https://travis-ci.org/spf13/viper.svg)](https://travis-ci.org/spf13/viper) -===== +![viper logo](https://cloud.githubusercontent.com/assets/173412/10886745/998df88a-8151-11e5-9448-4736db51020d.png) -[![Join the chat at https://gitter.im/spf13/viper](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/spf13/viper?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) +Go configuration with fangs! + + [![Build Status](https://travis-ci.org/spf13/viper.svg)](https://travis-ci.org/spf13/viper) [![Join the chat at https://gitter.im/spf13/viper](https://badges.gitter.im/Join%20Chat.svg)](https://gitter.im/spf13/viper?utm_source=badge&utm_medium=badge&utm_campaign=pr-badge&utm_content=badge) -Go configuration with fangs ## What is Viper? -Viper is a complete configuration solution for go applications. It has -been designed to work within an application to handle all types of -configuration. It supports +Viper is a complete configuration solution for go applications including 12 factor apps. It is designed +to work within an application, and can handle all types of configuration needs +and formats. It supports: * setting defaults -* reading from json, toml and yaml config files +* reading from JSON, TOML, YAML and HCL config files +* live watching and re-reading of config files (optional) * reading from environment variables -* reading from remote config systems (Etcd or Consul), watching changes +* reading from remote config systems (etcd or Consul), and watching changes * reading from command line flags * reading from buffer * setting explicit values -It can be thought of as a registry for all of your applications +Viper can be thought of as a registry for all of your applications configuration needs. ## Why Viper? -When building a modern application, you don’t want to have to worry about +When building a modern application, you don’t want to worry about configuration file formats; you want to focus on building awesome software. Viper is here to help with that. Viper does the following for you: -1. Find, load and marshal a configuration file in JSON, TOML or YAML. +1. Find, load, and unmarshal a configuration file in JSON, TOML, YAML or HCL. 2. Provide a mechanism to set default values for your different configuration options. -3. Provide a mechanism to set override values for options specified - through command line flags. -4. Provide an alias system to easily rename parameters without breaking - existing code. -5. Make it easy to tell the difference between when a user has provided - a command line or config file which is the same as the default. +3. Provide a mechanism to set override values for options specified through + command line flags. +4. Provide an alias system to easily rename parameters without breaking existing + code. +5. Make it easy to tell the difference between when a user has provided a + command line or config file which is the same as the default. -Viper uses the following precedence order. Each item takes precedence -over the item below it: +Viper uses the following precedence order. Each item takes precedence over the +item below it: * explicit call to Set * flag @@ -56,42 +57,71 @@ Viper configuration keys are case insensitive. ### Establishing Defaults -A good configuration system will support default values. A default value -is not required for a key, but can establish a default to be used in the -event that the key hasn’t be set via config file, environment variable, -remote configuration or flag. +A good configuration system will support default values. A default value is not +required for a key, but it's useful in the event that a key hasn’t been set via +config file, environment variable, remote configuration or flag. Examples: - viper.SetDefault("ContentDir", "content") - viper.SetDefault("LayoutDir", "layouts") - viper.SetDefault("Taxonomies", map[string]string{"tag": "tags", "category": "categories"}) +```go +viper.SetDefault("ContentDir", "content") +viper.SetDefault("LayoutDir", "layouts") +viper.SetDefault("Taxonomies", map[string]string{"tag": "tags", "category": "categories"}) +``` ### Reading Config Files -If you want to support a config file, Viper requires a minimal -configuration so it knows where to look for the config file. Viper -supports json, toml and yaml files. Viper can search multiple paths, but -currently a single viper only supports a single config file. +Viper requires minimal configuration so it knows where to look for config files. +Viper supports JSON, TOML, YAML and HCL files. Viper can search multiple paths, but +currently a single Viper instance only supports a single configuration file. +Viper does not default to any configuration search paths leaving defaults decision +to an application. - viper.SetConfigName("config") // name of config file (without extension) - viper.AddConfigPath("/etc/appname/") // path to look for the config file in - viper.AddConfigPath("$HOME/.appname") // call multiple times to add many search paths - err := viper.ReadInConfig() // Find and read the config file - if err != nil { // Handle errors reading the config file - panic(fmt.Errorf("Fatal error config file: %s \n", err)) - } +Here is an example of how to use Viper to search for and read a configuration file. +None of the specific paths are required, but at least one path should be provided +where a configuration file is expected. + +```go +viper.SetConfigName("config") // name of config file (without extension) +viper.AddConfigPath("/etc/appname/") // path to look for the config file in +viper.AddConfigPath("$HOME/.appname") // call multiple times to add many search paths +viper.AddConfigPath(".") // optionally look for config in the working directory +err := viper.ReadInConfig() // Find and read the config file +if err != nil { // Handle errors reading the config file + panic(fmt.Errorf("Fatal error config file: %s \n", err)) +} +``` + +### Watching and re-reading config files + +Viper supports the ability to have your application live read a config file while running. + +Gone are the days of needing to restart a server to have a config take effect, +viper powered applications can read an update to a config file while running and +not miss a beat. + +Simply tell the viper instance to watchConfig. +Optionally you can provide a function for Viper to run each time a change occurs. + +**Make sure you add all of the configPaths prior to calling `WatchConfig()`** + +```go + viper.WatchConfig() + viper.OnConfigChange(func(e fsnotify.Event) { + fmt.Println("Config file changed:", e.Name) + }) +``` ### Reading Config from io.Reader -Viper predefined many configuration sources, such as files, environment variables, flags and -remote K/V store. But you are not bound to them. You can also implement your own way to -require configuration and feed it to viper. +Viper predefines many configuration sources such as files, environment +variables, flags, and remote K/V store, but you are not bound to them. You can +also implement your own required configuration source and feed it to viper. -````go +```go viper.SetConfigType("yaml") // or viper.SetConfigType("YAML") -// any approach to require this configuration into your program. +// any approach to require this configuration into your program. var yamlExample = []byte(` Hacker: true name: steve @@ -110,202 +140,289 @@ beard: true viper.ReadConfig(bytes.NewBuffer(yamlExample)) viper.Get("name") // this would be "steve" -```` +``` ### Setting Overrides These could be from a command line flag, or from your own application logic. - viper.Set("Verbose", true) - viper.Set("LogFile", LogFile) +```go +viper.Set("Verbose", true) +viper.Set("LogFile", LogFile) +``` ### Registering and Using Aliases Aliases permit a single value to be referenced by multiple keys - viper.RegisterAlias("loud", "Verbose") +```go +viper.RegisterAlias("loud", "Verbose") - viper.Set("verbose", true) // same result as next line - viper.Set("loud", true) // same result as prior line +viper.Set("verbose", true) // same result as next line +viper.Set("loud", true) // same result as prior line - viper.GetBool("loud") // true - viper.GetBool("verbose") // true +viper.GetBool("loud") // true +viper.GetBool("verbose") // true +``` ### Working with Environment Variables Viper has full support for environment variables. This enables 12 factor -applications out of the box. There are four methods that exist to aid -with working with ENV: +applications out of the box. There are four methods that exist to aid working +with ENV: - * AutomaticEnv() - * BindEnv(string...) : error - * SetEnvPrefix(string) - * SetEnvReplacer(string...) *strings.Replacer + * `AutomaticEnv()` + * `BindEnv(string...) : error` + * `SetEnvPrefix(string)` + * `SetEnvReplacer(string...) *strings.Replacer` _When working with ENV variables, it’s important to recognize that Viper treats ENV variables as case sensitive._ -Viper provides a mechanism to try to ensure that ENV variables are -unique. By using SetEnvPrefix, you can tell Viper to use add a prefix -while reading from the environment variables. Both BindEnv and -AutomaticEnv will use this prefix. +Viper provides a mechanism to try to ensure that ENV variables are unique. By +using `SetEnvPrefix`, you can tell Viper to use add a prefix while reading from +the environment variables. Both `BindEnv` and `AutomaticEnv` will use this +prefix. -BindEnv takes one or two parameters. The first parameter is the key -name, the second is the name of the environment variable. The name of -the environment variable is case sensitive. If the ENV variable name is -not provided, then Viper will automatically assume that the key name -matches the ENV variable name but the ENV variable is IN ALL CAPS. When -you explicitly provide the ENV variable name, it **does not** -automatically add the prefix. +`BindEnv` takes one or two parameters. The first parameter is the key name, the +second is the name of the environment variable. The name of the environment +variable is case sensitive. If the ENV variable name is not provided, then +Viper will automatically assume that the key name matches the ENV variable name, +but the ENV variable is IN ALL CAPS. When you explicitly provide the ENV +variable name, it **does not** automatically add the prefix. -One important thing to recognize when working with ENV variables is that -the value will be read each time it is accessed. It does not fix the -value when the BindEnv is called. +One important thing to recognize when working with ENV variables is that the +value will be read each time it is accessed. Viper does not fix the value when +the `BindEnv` is called. -AutomaticEnv is a powerful helper especially when combined with -SetEnvPrefix. When called, Viper will check for an environment variable -any time a viper.Get request is made. It will apply the following rules. -It will check for a environment variable with a name matching the key -uppercased and prefixed with the EnvPrefix if set. +`AutomaticEnv` is a powerful helper especially when combined with +`SetEnvPrefix`. When called, Viper will check for an environment variable any +time a `viper.Get` request is made. It will apply the following rules. It will +check for a environment variable with a name matching the key uppercased and +prefixed with the `EnvPrefix` if set. -SetEnvReplacer allows you to use a `strings.Replacer` object to rewrite Env keys -to an extent. This is useful if you want to use `-` or something in your Get() -calls, but want your environmental variables to use `_` delimiters. An example -of using it can be found in `viper_test.go`. +`SetEnvReplacer` allows you to use a `strings.Replacer` object to rewrite Env +keys to an extent. This is useful if you want to use `-` or something in your +`Get()` calls, but want your environmental variables to use `_` delimiters. An +example of using it can be found in `viper_test.go`. #### Env example - SetEnvPrefix("spf") // will be uppercased automatically - BindEnv("id") +```go +SetEnvPrefix("spf") // will be uppercased automatically +BindEnv("id") - os.Setenv("SPF_ID", "13") // typically done outside of the app - - id := Get("id") // 13 +os.Setenv("SPF_ID", "13") // typically done outside of the app +id := Get("id") // 13 +``` ### Working with Flags -Viper has the ability to bind to flags. Specifically, Viper supports -Pflags as used in the [Cobra](https://github.com/spf13/cobra) library. +Viper has the ability to bind to flags. Specifically, Viper supports `Pflags` +as used in the [Cobra](https://github.com/spf13/cobra) library. -Like BindEnv, the value is not set when the binding method is called, but -when it is accessed. This means you can bind as early as you want, even -in an init() function. +Like `BindEnv`, the value is not set when the binding method is called, but when +it is accessed. This means you can bind as early as you want, even in an +`init()` function. -The BindPFlag() method provides this functionality. +The `BindPFlag()` method provides this functionality. Example: - serverCmd.Flags().Int("port", 1138, "Port to run Application server on") - viper.BindPFlag("port", serverCmd.Flags().Lookup("port")) +```go +serverCmd.Flags().Int("port", 1138, "Port to run Application server on") +viper.BindPFlag("port", serverCmd.Flags().Lookup("port")) +``` +The use of [pflag](https://github.com/spf13/pflag/) in Viper does not preclude +the use of other packages that use the [flag](https://golang.org/pkg/flag/) +package from the standard library. The pflag package can handle the flags +defined for the flag package by importing these flags. This is accomplished +by a calling a convenience function provided by the pflag package called +AddGoFlagSet(). + +Example: + +```go +package main + +import ( + "flag" + "github.com/spf13/pflag" +) + +func main() { + pflag.CommandLine.AddGoFlagSet(flag.CommandLine) + pflag.Parse() + ... +} +``` + +#### Flag interfaces + +Viper provides two Go interfaces to bind other flag systems if you don't use `Pflags`. + +`FlagValue` represents a single flag. This is a very simple example on how to implement this interface: + +```go +type myFlag struct {} +func (f myFlag) IsChanged() { return false } +func (f myFlag) Name() { return "my-flag-name" } +func (f myFlag) ValueString() { return "my-flag-value" } +func (f myFlag) ValueType() { return "string" } +``` + +Once your flag implements this interface, you can simply tell Viper to bind it: + +```go +viper.BindFlagValue("my-flag-name", myFlag{}) +``` + +`FlagValueSet` represents a group of flags. This is a very simple example on how to implement this interface: + +```go +type myFlagSet struct { + flags []myFlag +} + +func (f myFlagSet) VisitAll(fn func(FlagValue)) { + for _, flag := range flags { + fn(flag) + } +} +``` + +Once your flag set implements this interface, you can simply tell Viper to bind it: + +```go +fSet := myFlagSet{ + flags: []myFlag{myFlag{}, myFlag{}}, +} +viper.BindFlagValues("my-flags", fSet) +``` ### Remote Key/Value Store Support -Viper will read a config string (as JSON, TOML, or YAML) retrieved from a -path in a Key/Value store such as Etcd or Consul. These values take precedence -over default values, but are overriden by configuration values retrieved from disk, + +To enable remote support in Viper, do a blank import of the `viper/remote` +package: + +`import _ "github.com/spf13/viper/remote"` + +Viper will read a config string (as JSON, TOML, YAML or HCL) retrieved from a path +in a Key/Value store such as etcd or Consul. These values take precedence over +default values, but are overridden by configuration values retrieved from disk, flags, or environment variables. -Viper uses [crypt](https://github.com/xordataexchange/crypt) to retrieve configuration -from the K/V store, which means that you can store your configuration values -encrypted and have them automatically decrypted if you have the correct -gpg keyring. Encryption is optional. +Viper uses [crypt](https://github.com/xordataexchange/crypt) to retrieve +configuration from the K/V store, which means that you can store your +configuration values encrypted and have them automatically decrypted if you have +the correct gpg keyring. Encryption is optional. You can use remote configuration in conjunction with local configuration, or independently of it. -`crypt` has a command-line helper that you can use to put configurations -in your K/V store. `crypt` defaults to etcd on http://127.0.0.1:4001. +`crypt` has a command-line helper that you can use to put configurations in your +K/V store. `crypt` defaults to etcd on http://127.0.0.1:4001. - go get github.com/xordataexchange/crypt/bin/crypt - crypt set -plaintext /config/hugo.json /Users/hugo/settings/config.json +```bash +$ go get github.com/xordataexchange/crypt/bin/crypt +$ crypt set -plaintext /config/hugo.json /Users/hugo/settings/config.json +``` Confirm that your value was set: - crypt get -plaintext /config/hugo.json +```bash +$ crypt get -plaintext /config/hugo.json +``` -See the `crypt` documentation for examples of how to set encrypted values, or how -to use Consul. +See the `crypt` documentation for examples of how to set encrypted values, or +how to use Consul. ### Remote Key/Value Store Example - Unencrypted - viper.AddRemoteProvider("etcd", "http://127.0.0.1:4001","/config/hugo.json") - viper.SetConfigType("json") // because there is no file extension in a stream of bytes - err := viper.ReadRemoteConfig() +```go +viper.AddRemoteProvider("etcd", "http://127.0.0.1:4001","/config/hugo.json") +viper.SetConfigType("json") // because there is no file extension in a stream of bytes +err := viper.ReadRemoteConfig() +``` ### Remote Key/Value Store Example - Encrypted - viper.AddSecureRemoteProvider("etcd","http://127.0.0.1:4001","/config/hugo.json","/etc/secrets/mykeyring.gpg") - viper.SetConfigType("json") // because there is no file extension in a stream of bytes - err := viper.ReadRemoteConfig() +```go +viper.AddSecureRemoteProvider("etcd","http://127.0.0.1:4001","/config/hugo.json","/etc/secrets/mykeyring.gpg") +viper.SetConfigType("json") // because there is no file extension in a stream of bytes +err := viper.ReadRemoteConfig() +``` -### Watching Changes in Etcd - Unencrypted +### Watching Changes in etcd - Unencrypted - // alternatively, you can create a new viper instance. - var runtime_viper = viper.New() +```go +// alternatively, you can create a new viper instance. +var runtime_viper = viper.New() - runtime_viper.AddRemoteProvider("etcd", "http://127.0.0.1:4001", "/config/hugo.yml") - runtime_viper.SetConfigType("yaml") // because there is no file extension in a stream of bytes +runtime_viper.AddRemoteProvider("etcd", "http://127.0.0.1:4001", "/config/hugo.yml") +runtime_viper.SetConfigType("yaml") // because there is no file extension in a stream of bytes - // read from remote config the first time. - err := runtime_viper.ReadRemoteConfig() +// read from remote config the first time. +err := runtime_viper.ReadRemoteConfig() - // marshal config - runtime_viper.Marshal(&runtime_conf) +// unmarshal config +runtime_viper.Unmarshal(&runtime_conf) - // open a goroutine to wath remote changes forever - go func(){ - for { - time.Sleep(time.Second * 5) // delay after each request +// open a goroutine to watch remote changes forever +go func(){ + for { + time.Sleep(time.Second * 5) // delay after each request - // currenlty, only tested with etcd support - err := runtime_viper.WatchRemoteConfig() - if err != nil { - log.Errorf("unable to read remote config: %v", err) - continue - } - - // marshal new config into our runtime config struct. you can also use channel - // to implement a signal to notify the system of the changes - runtime_viper.Marshal(&runtime_conf) - } - }() + // currently, only tested with etcd support + err := runtime_viper.WatchRemoteConfig() + if err != nil { + log.Errorf("unable to read remote config: %v", err) + continue + } + // unmarshal new config into our runtime config struct. you can also use channel + // to implement a signal to notify the system of the changes + runtime_viper.Unmarshal(&runtime_conf) + } +}() +``` ## Getting Values From Viper -In Viper, there are a few ways to get a value depending on what type of value you want to retrieved. +In Viper, there are a few ways to get a value depending on the value's type. The following functions and methods exist: - * Get(key string) : interface{} - * GetBool(key string) : bool - * GetFloat64(key string) : float64 - * GetInt(key string) : int - * GetString(key string) : string - * GetStringMap(key string) : map[string]interface{} - * GetStringMapString(key string) : map[string]string - * GetStringSlice(key string) : []string - * GetTime(key string) : time.Time - * GetDuration(key string) : time.Duration - * IsSet(key string) : bool + * `Get(key string) : interface{}` + * `GetBool(key string) : bool` + * `GetFloat64(key string) : float64` + * `GetInt(key string) : int` + * `GetString(key string) : string` + * `GetStringMap(key string) : map[string]interface{}` + * `GetStringMapString(key string) : map[string]string` + * `GetStringSlice(key string) : []string` + * `GetTime(key string) : time.Time` + * `GetDuration(key string) : time.Duration` + * `IsSet(key string) : bool` -One important thing to recognize is that each Get function will return -its zero value if it’s not found. To check if a given key exists, the IsSet() -method has been provided. +One important thing to recognize is that each Get function will return a zero +value if it’s not found. To check if a given key exists, the `IsSet()` method +has been provided. Example: - - viper.GetString("logfile") // case-insensitive Setting & Getting - if viper.GetBool("verbose") { - fmt.Println("verbose enabled") - } - +```go +viper.GetString("logfile") // case-insensitive Setting & Getting +if viper.GetBool("verbose") { + fmt.Println("verbose enabled") +} +``` ### Accessing nested keys -The accessor methods also accept formatted paths to deeply nested keys. -For example, if the following JSON file is loaded: +The accessor methods also accept formatted paths to deeply nested keys. For +example, if the following JSON file is loaded: -``` +```json { "host": { "address": "localhost", @@ -326,24 +443,26 @@ For example, if the following JSON file is loaded: ``` Viper can access a nested field by passing a `.` delimited path of keys: -``` + +```go GetString("datastore.metric.host") // (returns "127.0.0.1") ``` -This obeys the precendense rules established above; the search for the root key -(in this examole, `datastore`) will cascade through the remaining configuration registries -until found. The search for the subkeys (`metric` and `host`), however, will not. +This obeys the precedence rules established above; the search for the root key +(in this example, `datastore`) will cascade through the remaining configuration +registries until found. The search for the sub-keys (`metric` and `host`), +however, will not. For example, if the `metric` key was not defined in the configuration loaded from file, but was defined in the defaults, Viper would return the zero value. -On the other hand, if the primary key was not defined, Viper would go through the -remaining registries looking for it. +On the other hand, if the primary key was not defined, Viper would go through +the remaining registries looking for it. -Lastly, if there exists a key that matches the delimited key path, its value will -be returned instead. E.g. +Lastly, if there exists a key that matches the delimited key path, its value +will be returned instead. E.g. -``` +```json { "datastore.metric.host": "0.0.0.0", "host": { @@ -365,59 +484,110 @@ be returned instead. E.g. GetString("datastore.metric.host") //returns "0.0.0.0" ``` -### Marshaling +### Extract sub-tree -You also have the option of Marshaling all or a specific value to a struct, map, etc. +Extract sub-tree from Viper. + +For example, `viper` represents: + +```json +app: + cache1: + max-items: 100 + item-size: 64 + cache2: + max-items: 200 + item-size: 80 +``` + +After executing: + +```go +subv := viper.Sub("app.cache1") +``` + +`subv` represents: + +```json +max-items: 100 +item-size: 64 +``` + +Suppose we have: + +```go +func NewCache(cfg *Viper) *Cache {...} +``` + +which creates a cache based on config information formatted as `subv`. +Now it's easy to create these 2 caches separately as: + +```go +cfg1 := viper.Sub("app.cache1") +cache1 := NewCache(cfg1) + +cfg2 := viper.Sub("app.cache2") +cache2 := NewCache(cfg2) +``` + +### Unmarshaling + +You also have the option of Unmarshaling all or a specific value to a struct, map, +etc. There are two methods to do this: - * Marshal(rawVal interface{}) : error - * MarshalKey(key string, rawVal interface{}) : error + * `Unmarshal(rawVal interface{}) : error` + * `UnmarshalKey(key string, rawVal interface{}) : error` Example: - type config struct { - Port int - Name string - } +```go +type config struct { + Port int + Name string + PathMap string `mapstructure:"path_map"` +} - var C config - - err := Marshal(&C) - if err != nil { - t.Fatalf("unable to decode into struct, %v", err) - } +var C config +err := Unmarshal(&C) +if err != nil { + t.Fatalf("unable to decode into struct, %v", err) +} +``` ## Viper or Vipers? Viper comes ready to use out of the box. There is no configuration or -initialization needed to begin using Viper. Since most applications will -want to use a single central repository for their configuration, the -viper package provides this. It is similar to a singleton. +initialization needed to begin using Viper. Since most applications will want +to use a single central repository for their configuration, the viper package +provides this. It is similar to a singleton. -In all of the examples above, they demonstrate using viper in its -singleton style approach. +In all of the examples above, they demonstrate using viper in it's singleton +style approach. ### Working with multiple vipers -You can also create many different vipers for use in your application. -Each will have it’s own unique set of configurations and values. Each -can read from a different config file, key value store, etc. All of the -functions that viper package supports are mirrored as methods on a viper. +You can also create many different vipers for use in your application. Each will +have it’s own unique set of configurations and values. Each can read from a +different config file, key value store, etc. All of the functions that viper +package supports are mirrored as methods on a viper. Example: - x := viper.New() - y := viper.New() +```go +x := viper.New() +y := viper.New() - x.SetDefault("ContentDir", "content") - y.SetDefault("ContentDir", "foobar") +x.SetDefault("ContentDir", "content") +y.SetDefault("ContentDir", "foobar") - ... +//... +``` -When working with multiple vipers, it is up to the user to keep track of -the different vipers. +When working with multiple vipers, it is up to the user to keep track of the +different vipers. ## Q & A @@ -425,13 +595,13 @@ Q: Why not INI files? A: Ini files are pretty awful. There’s no standard format, and they are hard to validate. Viper is designed to work with JSON, TOML or YAML files. If someone -really wants to add this feature, I’d be happy to merge it. It’s easy to -specify which formats your application will permit. +really wants to add this feature, I’d be happy to merge it. It’s easy to specify +which formats your application will permit. Q: Why is it called “Viper”? -A: Viper is designed to be a [companion](http://en.wikipedia.org/wiki/Viper_(G.I._Joe)) to -[Cobra](https://github.com/spf13/cobra). While both can operate completely +A: Viper is designed to be a [companion](http://en.wikipedia.org/wiki/Viper_(G.I._Joe)) +to [Cobra](https://github.com/spf13/cobra). While both can operate completely independently, together they make a powerful pair to handle much of your application foundation needs. diff --git a/Godeps/_workspace/src/github.com/spf13/viper/flags.go b/Godeps/_workspace/src/github.com/spf13/viper/flags.go new file mode 100644 index 00000000..e433f30f --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/viper/flags.go @@ -0,0 +1,57 @@ +package viper + +import "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" + +// FlagValueSet is an interface that users can implement +// to bind a set of flags to viper. +type FlagValueSet interface { + VisitAll(fn func(FlagValue)) +} + +// FlagValue is an interface that users can implement +// to bind different flags to viper. +type FlagValue interface { + HasChanged() bool + Name() string + ValueString() string + ValueType() string +} + +// pflagValueSet is a wrapper around *pflag.ValueSet +// that implements FlagValueSet. +type pflagValueSet struct { + flags *pflag.FlagSet +} + +// VisitAll iterates over all *pflag.Flag inside the *pflag.FlagSet. +func (p pflagValueSet) VisitAll(fn func(flag FlagValue)) { + p.flags.VisitAll(func(flag *pflag.Flag) { + fn(pflagValue{flag}) + }) +} + +// pflagValue is a wrapper aroung *pflag.flag +// that implements FlagValue +type pflagValue struct { + flag *pflag.Flag +} + +// HasChanges returns whether the flag has changes or not. +func (p pflagValue) HasChanged() bool { + return p.flag.Changed +} + +// Name returns the name of the flag. +func (p pflagValue) Name() string { + return p.flag.Name +} + +// ValueString returns the value of the flag as a string. +func (p pflagValue) ValueString() string { + return p.flag.Value.String() +} + +// ValueType returns the type of the flag as a string. +func (p pflagValue) ValueType() string { + return p.flag.Value.Type() +} diff --git a/Godeps/_workspace/src/github.com/spf13/viper/flags_test.go b/Godeps/_workspace/src/github.com/spf13/viper/flags_test.go new file mode 100644 index 00000000..0fcddf8f --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/viper/flags_test.go @@ -0,0 +1,66 @@ +package viper + +import ( + "testing" + + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/stretchr/testify/assert" +) + +func TestBindFlagValueSet(t *testing.T) { + flagSet := pflag.NewFlagSet("test", pflag.ContinueOnError) + + var testValues = map[string]*string{ + "host": nil, + "port": nil, + "endpoint": nil, + } + + var mutatedTestValues = map[string]string{ + "host": "localhost", + "port": "6060", + "endpoint": "/public", + } + + for name, _ := range testValues { + testValues[name] = flagSet.String(name, "", "test") + } + + flagValueSet := pflagValueSet{flagSet} + + err := BindFlagValues(flagValueSet) + if err != nil { + t.Fatalf("error binding flag set, %v", err) + } + + flagSet.VisitAll(func(flag *pflag.Flag) { + flag.Value.Set(mutatedTestValues[flag.Name]) + flag.Changed = true + }) + + for name, expected := range mutatedTestValues { + assert.Equal(t, Get(name), expected) + } +} + +func TestBindFlagValue(t *testing.T) { + var testString = "testing" + var testValue = newStringValue(testString, &testString) + + flag := &pflag.Flag{ + Name: "testflag", + Value: testValue, + Changed: false, + } + + flagValue := pflagValue{flag} + BindFlagValue("testvalue", flagValue) + + assert.Equal(t, testString, Get("testvalue")) + + flag.Value.Set("testing_mutate") + flag.Changed = true //hack for pflag usage + + assert.Equal(t, "testing_mutate", Get("testvalue")) + +} diff --git a/Godeps/_workspace/src/github.com/spf13/viper/remote/remote.go b/Godeps/_workspace/src/github.com/spf13/viper/remote/remote.go new file mode 100644 index 00000000..af25f968 --- /dev/null +++ b/Godeps/_workspace/src/github.com/spf13/viper/remote/remote.go @@ -0,0 +1,77 @@ +// Copyright © 2015 Steve Francia . +// +// Use of this source code is governed by an MIT-style +// license that can be found in the LICENSE file. + +// Package remote integrates the remote features of Viper. +package remote + +import ( + "bytes" + "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/viper" + crypt "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/xordataexchange/crypt/config" + "io" + "os" +) + +type remoteConfigProvider struct{} + +func (rc remoteConfigProvider) Get(rp viper.RemoteProvider) (io.Reader, error) { + cm, err := getConfigManager(rp) + if err != nil { + return nil, err + } + b, err := cm.Get(rp.Path()) + if err != nil { + return nil, err + } + return bytes.NewReader(b), nil +} + +func (rc remoteConfigProvider) Watch(rp viper.RemoteProvider) (io.Reader, error) { + cm, err := getConfigManager(rp) + if err != nil { + return nil, err + } + resp := <-cm.Watch(rp.Path(), nil) + err = resp.Error + if err != nil { + return nil, err + } + + return bytes.NewReader(resp.Value), nil +} + +func getConfigManager(rp viper.RemoteProvider) (crypt.ConfigManager, error) { + + var cm crypt.ConfigManager + var err error + + if rp.SecretKeyring() != "" { + kr, err := os.Open(rp.SecretKeyring()) + defer kr.Close() + if err != nil { + return nil, err + } + if rp.Provider() == "etcd" { + cm, err = crypt.NewEtcdConfigManager([]string{rp.Endpoint()}, kr) + } else { + cm, err = crypt.NewConsulConfigManager([]string{rp.Endpoint()}, kr) + } + } else { + if rp.Provider() == "etcd" { + cm, err = crypt.NewStandardEtcdConfigManager([]string{rp.Endpoint()}) + } else { + cm, err = crypt.NewStandardConsulConfigManager([]string{rp.Endpoint()}) + } + } + if err != nil { + return nil, err + } + return cm, nil + +} + +func init() { + viper.RemoteConfig = &remoteConfigProvider{} +} diff --git a/Godeps/_workspace/src/github.com/spf13/viper/util.go b/Godeps/_workspace/src/github.com/spf13/viper/util.go index 3493774a..15dd834b 100644 --- a/Godeps/_workspace/src/github.com/spf13/viper/util.go +++ b/Godeps/_workspace/src/github.com/spf13/viper/util.go @@ -21,6 +21,7 @@ import ( "strings" "unicode" + "github.com/hashicorp/hcl" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/BurntSushi/toml" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/magiconair/properties" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cast" @@ -28,6 +29,16 @@ import ( "github.com/khlieng/dispatch/Godeps/_workspace/src/gopkg.in/yaml.v2" ) +// Denotes failing to parse configuration file. +type ConfigParseError struct { + err error +} + +// Returns the formatted configuration error. +func (pe ConfigParseError) Error() string { + return fmt.Sprintf("While parsing config: %s", pe.err.Error()) +} + func insensitiviseMap(m map[string]interface{}) { for key, val := range m { lower := strings.ToLower(key) @@ -119,31 +130,40 @@ func findCWD() (string, error) { return path, nil } -func marshallConfigReader(in io.Reader, c map[string]interface{}, configType string) { +func unmarshallConfigReader(in io.Reader, c map[string]interface{}, configType string) error { buf := new(bytes.Buffer) buf.ReadFrom(in) switch strings.ToLower(configType) { case "yaml", "yml": if err := yaml.Unmarshal(buf.Bytes(), &c); err != nil { - jww.ERROR.Fatalf("Error parsing config: %s", err) + return ConfigParseError{err} } case "json": if err := json.Unmarshal(buf.Bytes(), &c); err != nil { - jww.ERROR.Fatalf("Error parsing config: %s", err) + return ConfigParseError{err} + } + + case "hcl": + obj, err := hcl.Parse(string(buf.Bytes())) + if err != nil { + return ConfigParseError{err} + } + if err = hcl.DecodeObject(&c, obj); err != nil { + return ConfigParseError{err} } case "toml": if _, err := toml.Decode(buf.String(), &c); err != nil { - jww.ERROR.Fatalf("Error parsing config: %s", err) + return ConfigParseError{err} } case "properties", "props", "prop": var p *properties.Properties var err error if p, err = properties.Load(buf.Bytes(), properties.UTF8); err != nil { - jww.ERROR.Fatalf("Error parsing config: %s", err) + return ConfigParseError{err} } for _, key := range p.Keys() { value, _ := p.Get(key) @@ -152,6 +172,7 @@ func marshallConfigReader(in io.Reader, c map[string]interface{}, configType str } insensitiviseMap(c) + return nil } func safeMul(a, b uint) uint { diff --git a/Godeps/_workspace/src/github.com/spf13/viper/viper.go b/Godeps/_workspace/src/github.com/spf13/viper/viper.go index e902086d..611b9bf4 100644 --- a/Godeps/_workspace/src/github.com/spf13/viper/viper.go +++ b/Godeps/_workspace/src/github.com/spf13/viper/viper.go @@ -24,6 +24,7 @@ import ( "fmt" "io" "io/ioutil" + "log" "os" "path/filepath" "reflect" @@ -35,7 +36,7 @@ import ( "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/cast" jww "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/jwalterweatherman" "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/spf13/pflag" - crypt "github.com/khlieng/dispatch/Godeps/_workspace/src/github.com/xordataexchange/crypt/config" + "gopkg.in/fsnotify.v1" ) var v *Viper @@ -44,6 +45,14 @@ func init() { v = New() } +type remoteConfigFactory interface { + Get(rp RemoteProvider) (io.Reader, error) + Watch(rp RemoteProvider) (io.Reader, error) +} + +// RemoteConfig is optional, see the remote package +var RemoteConfig remoteConfigFactory + // Denotes encountering an unsupported // configuration filetype. type UnsupportedConfigError string @@ -54,7 +63,7 @@ func (str UnsupportedConfigError) Error() string { } // Denotes encountering an unsupported remote -// provider. Currently only Etcd and Consul are +// provider. Currently only etcd and Consul are // supported. type UnsupportedRemoteProviderError string @@ -72,6 +81,16 @@ func (rce RemoteConfigError) Error() string { return fmt.Sprintf("Remote Configurations Error: %s", string(rce)) } +// Denotes failing to find configuration file. +type ConfigFileNotFoundError struct { + name, locations string +} + +// Returns the formatted configuration error. +func (fnfe ConfigFileNotFoundError) Error() string { + return fmt.Sprintf("Config File %q Not Found in %q", fnfe.name, fnfe.locations) +} + // Viper is a prioritized configuration registry. It // maintains a set of configuration sources, fetches // values to populate those, and provides them according @@ -89,11 +108,11 @@ func (rce RemoteConfigError) Error() string { // Defaults : { // "secret": "", // "user": "default", -// "endpoint": "https://localhost" +// "endpoint": "https://localhost" // } // Config : { // "user": "root" -// "secret": "defaultsecret" +// "secret": "defaultsecret" // } // Env : { // "secret": "somesecretkey" @@ -115,7 +134,7 @@ type Viper struct { configPaths []string // A set of remote providers to search for the configuration - remoteProviders []*remoteProvider + remoteProviders []*defaultRemoteProvider // Name of file to look for inside the path configName string @@ -126,13 +145,16 @@ type Viper struct { automaticEnvApplied bool envKeyReplacer *strings.Replacer - config map[string]interface{} - override map[string]interface{} - defaults map[string]interface{} - kvstore map[string]interface{} - pflags map[string]*pflag.Flag - env map[string]string - aliases map[string]string + config map[string]interface{} + override map[string]interface{} + defaults map[string]interface{} + kvstore map[string]interface{} + pflags map[string]FlagValue + env map[string]string + aliases map[string]string + typeByDefValue bool + + onConfigChange func(fsnotify.Event) } // Returns an initialized Viper instance. @@ -144,9 +166,10 @@ func New() *Viper { v.override = make(map[string]interface{}) v.defaults = make(map[string]interface{}) v.kvstore = make(map[string]interface{}) - v.pflags = make(map[string]*pflag.Flag) + v.pflags = make(map[string]FlagValue) v.env = make(map[string]string) v.aliases = make(map[string]string) + v.typeByDefValue = false return v } @@ -156,27 +179,94 @@ func New() *Viper { // can use it in their testing as well. func Reset() { v = New() - SupportedExts = []string{"json", "toml", "yaml", "yml"} + SupportedExts = []string{"json", "toml", "yaml", "yml", "hcl"} SupportedRemoteProviders = []string{"etcd", "consul"} } -// remoteProvider stores the configuration necessary -// to connect to a remote key/value store. -// Optional secretKeyring to unencrypt encrypted values -// can be provided. -type remoteProvider struct { +type defaultRemoteProvider struct { provider string endpoint string path string secretKeyring string } +func (rp defaultRemoteProvider) Provider() string { + return rp.provider +} + +func (rp defaultRemoteProvider) Endpoint() string { + return rp.endpoint +} + +func (rp defaultRemoteProvider) Path() string { + return rp.path +} + +func (rp defaultRemoteProvider) SecretKeyring() string { + return rp.secretKeyring +} + +// RemoteProvider stores the configuration necessary +// to connect to a remote key/value store. +// Optional secretKeyring to unencrypt encrypted values +// can be provided. +type RemoteProvider interface { + Provider() string + Endpoint() string + Path() string + SecretKeyring() string +} + // Universally supported extensions. -var SupportedExts []string = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop"} +var SupportedExts []string = []string{"json", "toml", "yaml", "yml", "properties", "props", "prop", "hcl"} // Universally supported remote providers. var SupportedRemoteProviders []string = []string{"etcd", "consul"} +func OnConfigChange(run func(in fsnotify.Event)) { v.OnConfigChange(run) } +func (v *Viper) OnConfigChange(run func(in fsnotify.Event)) { + v.onConfigChange = run +} + +func WatchConfig() { v.WatchConfig() } +func (v *Viper) WatchConfig() { + go func() { + watcher, err := fsnotify.NewWatcher() + if err != nil { + log.Fatal(err) + } + defer watcher.Close() + + // we have to watch the entire directory to pick up renames/atomic saves in a cross-platform way + configFile := filepath.Clean(v.getConfigFile()) + configDir, _ := filepath.Split(configFile) + + done := make(chan bool) + go func() { + for { + select { + case event := <-watcher.Events: + // we only care about the config file + if filepath.Clean(event.Name) == configFile { + if event.Op&fsnotify.Write == fsnotify.Write || event.Op&fsnotify.Create == fsnotify.Create { + err := v.ReadInConfig() + if err != nil { + log.Println("error:", err) + } + v.onConfigChange(event) + } + } + case err := <-watcher.Errors: + log.Println("error:", err) + } + } + }() + + watcher.Add(configDir) + <-done + }() +} + // Explicitly define the path, name and extension of the config file // Viper will use this and not check any of the config paths func SetConfigFile(in string) { v.SetConfigFile(in) } @@ -252,7 +342,7 @@ func (v *Viper) AddRemoteProvider(provider, endpoint, path string) error { } if provider != "" && endpoint != "" { jww.INFO.Printf("adding %s:%s to remote provider list", provider, endpoint) - rp := &remoteProvider{ + rp := &defaultRemoteProvider{ endpoint: endpoint, provider: provider, path: path, @@ -284,10 +374,11 @@ func (v *Viper) AddSecureRemoteProvider(provider, endpoint, path, secretkeyring } if provider != "" && endpoint != "" { jww.INFO.Printf("adding %s:%s to remote provider list", provider, endpoint) - rp := &remoteProvider{ - endpoint: endpoint, - provider: provider, - path: path, + rp := &defaultRemoteProvider{ + endpoint: endpoint, + provider: provider, + path: path, + secretKeyring: secretkeyring, } if !v.providerPathExists(rp) { v.remoteProviders = append(v.remoteProviders, rp) @@ -296,7 +387,7 @@ func (v *Viper) AddSecureRemoteProvider(provider, endpoint, path, secretkeyring return nil } -func (v *Viper) providerPathExists(p *remoteProvider) bool { +func (v *Viper) providerPathExists(p *defaultRemoteProvider) bool { for _, y := range v.remoteProviders { if reflect.DeepEqual(y, p) { return true @@ -311,8 +402,20 @@ func (v *Viper) searchMap(source map[string]interface{}, path []string) interfac return source } - if next, ok := source[path[0]]; ok { + var ok bool + var next interface{} + for k, v := range source { + if strings.ToLower(k) == strings.ToLower(path[0]) { + ok = true + next = v + break + } + } + + if ok { switch next.(type) { + case map[interface{}]interface{}: + return v.searchMap(cast.ToStringMap(next), path[1:]) case map[string]interface{}: // Type assertion is safe here since it is only reached // if the type of `next` is the same as the type being asserted @@ -325,6 +428,25 @@ func (v *Viper) searchMap(source map[string]interface{}, path []string) interfac } } +// SetTypeByDefaultValue enables or disables the inference of a key value's +// type when the Get function is used based upon a key's default value as +// opposed to the value returned based on the normal fetch logic. +// +// For example, if a key has a default value of []string{} and the same key +// is set via an environment variable to "a b c", a call to the Get function +// would return a string slice for the key if the key's type is inferred by +// the default value and the Get function would return: +// +// []string {"a", "b", "c"} +// +// Otherwise the Get function would return: +// +// "a b c" +func SetTypeByDefaultValue(enable bool) { v.SetTypeByDefaultValue(enable) } +func (v *Viper) SetTypeByDefaultValue(enable bool) { + v.typeByDefValue = enable +} + // Viper is essentially repository for configurations // Get can retrieve any value given the key to use // Get has the behavior of returning the value associated with the first @@ -336,20 +458,51 @@ func Get(key string) interface{} { return v.Get(key) } func (v *Viper) Get(key string) interface{} { path := strings.Split(key, v.keyDelim) - val := v.find(strings.ToLower(key)) + lcaseKey := strings.ToLower(key) + val := v.find(lcaseKey) if val == nil { - source := v.find(path[0]) - if source == nil { - return nil - } - - if reflect.TypeOf(source).Kind() == reflect.Map { - val = v.searchMap(cast.ToStringMap(source), path[1:]) + source := v.find(strings.ToLower(path[0])) + if source != nil { + if reflect.TypeOf(source).Kind() == reflect.Map { + val = v.searchMap(cast.ToStringMap(source), path[1:]) + } } } - switch val.(type) { + // if no other value is returned and a flag does exist for the value, + // get the flag's value even if the flag's value has not changed + if val == nil { + if flag, exists := v.pflags[lcaseKey]; exists { + jww.TRACE.Println(key, "get pflag default", val) + switch flag.ValueType() { + case "int", "int8", "int16", "int32", "int64": + val = cast.ToInt(flag.ValueString()) + case "bool": + val = cast.ToBool(flag.ValueString()) + default: + val = flag.ValueString() + } + } + } + + if val == nil { + return nil + } + + var valType interface{} + if !v.typeByDefValue { + valType = val + } else { + defVal, defExists := v.defaults[lcaseKey] + if defExists { + valType = defVal + } else { + valType = val + } + } + + switch valType.(type) { case bool: return cast.ToBool(val) case string: @@ -363,11 +516,24 @@ func (v *Viper) Get(key string) interface{} { case time.Duration: return cast.ToDuration(val) case []string: - return val + return cast.ToStringSlice(val) } return val } +// Returns new Viper instance representing a sub tree of this instance +func Sub(key string) *Viper { return v.Sub(key) } +func (v *Viper) Sub(key string) *Viper { + subv := New() + data := v.Get(key) + if reflect.TypeOf(data).Kind() == reflect.Map { + subv.config = cast.ToStringMap(data) + return subv + } else { + return nil + } +} + // Returns the value associated with the key as a string func GetString(key string) string { return v.GetString(key) } func (v *Viper) GetString(key string) string { @@ -422,6 +588,12 @@ func (v *Viper) GetStringMapString(key string) map[string]string { return cast.ToStringMapString(v.Get(key)) } +// Returns the value associated with the key as a map to a slice of strings. +func GetStringMapStringSlice(key string) map[string][]string { return v.GetStringMapStringSlice(key) } +func (v *Viper) GetStringMapStringSlice(key string) map[string][]string { + return cast.ToStringMapStringSlice(v.Get(key)) +} + // Returns the size of the value associated with the given key // in bytes. func GetSizeInBytes(key string) uint { return v.GetSizeInBytes(key) } @@ -430,16 +602,16 @@ func (v *Viper) GetSizeInBytes(key string) uint { return parseSizeInBytes(sizeStr) } -// Takes a single key and marshals it into a Struct -func MarshalKey(key string, rawVal interface{}) error { return v.MarshalKey(key, rawVal) } -func (v *Viper) MarshalKey(key string, rawVal interface{}) error { +// Takes a single key and unmarshals it into a Struct +func UnmarshalKey(key string, rawVal interface{}) error { return v.UnmarshalKey(key, rawVal) } +func (v *Viper) UnmarshalKey(key string, rawVal interface{}) error { return mapstructure.Decode(v.Get(key), rawVal) } -// Marshals the config into a Struct. Make sure that the tags +// Unmarshals the config into a Struct. Make sure that the tags // on the fields of the structure are properly set. -func Marshal(rawVal interface{}) error { return v.Marshal(rawVal) } -func (v *Viper) Marshal(rawVal interface{}) error { +func Unmarshal(rawVal interface{}) error { return v.Unmarshal(rawVal) } +func (v *Viper) Unmarshal(rawVal interface{}) error { err := mapstructure.WeakDecode(v.AllSettings(), rawVal) if err != nil { @@ -455,26 +627,10 @@ func (v *Viper) Marshal(rawVal interface{}) error { // name as the config key. func BindPFlags(flags *pflag.FlagSet) (err error) { return v.BindPFlags(flags) } func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) { - flags.VisitAll(func(flag *pflag.Flag) { - if err != nil { - // an error has been encountered in one of the previous flags - return - } - - err = v.BindPFlag(flag.Name, flag) - switch flag.Value.Type() { - case "int", "int8", "int16", "int32", "int64": - v.SetDefault(flag.Name, cast.ToInt(flag.Value.String())) - case "bool": - v.SetDefault(flag.Name, cast.ToBool(flag.Value.String())) - default: - v.SetDefault(flag.Name, flag.Value.String()) - } - }) - return + return v.BindFlagValues(pflagValueSet{flags}) } -// Bind a specific key to a flag (as used by cobra) +// Bind a specific key to a pflag (as used by cobra) // Example(where serverCmd is a Cobra instance): // // serverCmd.Flags().Int("port", 1138, "Port to run Application server on") @@ -482,19 +638,33 @@ func (v *Viper) BindPFlags(flags *pflag.FlagSet) (err error) { // func BindPFlag(key string, flag *pflag.Flag) (err error) { return v.BindPFlag(key, flag) } func (v *Viper) BindPFlag(key string, flag *pflag.Flag) (err error) { + return v.BindFlagValue(key, pflagValue{flag}) +} + +// Bind a full FlagValue set to the configuration, using each flag's long +// name as the config key. +func BindFlagValues(flags FlagValueSet) (err error) { return v.BindFlagValues(flags) } +func (v *Viper) BindFlagValues(flags FlagValueSet) (err error) { + flags.VisitAll(func(flag FlagValue) { + if err = v.BindFlagValue(flag.Name(), flag); err != nil { + return + } + }) + return nil +} + +// Bind a specific key to a FlagValue. +// Example(where serverCmd is a Cobra instance): +// +// serverCmd.Flags().Int("port", 1138, "Port to run Application server on") +// Viper.BindFlagValue("port", serverCmd.Flags().Lookup("port")) +// +func BindFlagValue(key string, flag FlagValue) (err error) { return v.BindFlagValue(key, flag) } +func (v *Viper) BindFlagValue(key string, flag FlagValue) (err error) { if flag == nil { return fmt.Errorf("flag for %q is nil", key) } v.pflags[strings.ToLower(key)] = flag - - switch flag.Value.Type() { - case "int", "int8", "int16", "int32", "int64": - SetDefault(key, cast.ToInt(flag.Value.String())) - case "bool": - SetDefault(key, cast.ToBool(flag.Value.String())) - default: - SetDefault(key, flag.Value.String()) - } return nil } @@ -535,10 +705,15 @@ func (v *Viper) find(key string) interface{} { // PFlag Override first flag, exists := v.pflags[key] - if exists { - if flag.Changed { - jww.TRACE.Println(key, "found in override (via pflag):", val) - return flag.Value.String() + if exists && flag.HasChanged() { + jww.TRACE.Println(key, "found in override (via pflag):", flag.ValueString()) + switch flag.ValueType() { + case "int", "int8", "int16", "int32", "int64": + return cast.ToInt(flag.ValueString()) + case "bool": + return cast.ToBool(flag.ValueString()) + default: + return flag.ValueString() } } @@ -574,6 +749,20 @@ func (v *Viper) find(key string) interface{} { return val } + // Test for nested config parameter + if strings.Contains(key, v.keyDelim) { + path := strings.Split(key, v.keyDelim) + + source := v.find(path[0]) + if source != nil { + if reflect.TypeOf(source).Kind() == reflect.Map { + val := v.searchMap(cast.ToStringMap(source), path[1:]) + jww.TRACE.Println(key, "found in nested config:", val) + return val + } + } + } + val, exists = v.kvstore[key] if exists { jww.TRACE.Println(key, "found in key/value store:", val) @@ -592,8 +781,21 @@ func (v *Viper) find(key string) interface{} { // Check to see if the key has been set in any of the data locations func IsSet(key string) bool { return v.IsSet(key) } func (v *Viper) IsSet(key string) bool { - t := v.Get(key) - return t != nil + path := strings.Split(key, v.keyDelim) + + lcaseKey := strings.ToLower(key) + val := v.find(lcaseKey) + + if val == nil { + source := v.find(strings.ToLower(path[0])) + if source != nil { + if reflect.TypeOf(source).Kind() == reflect.Map { + val = v.searchMap(cast.ToStringMap(source), path[1:]) + } + } + } + + return val != nil } // Have Viper check ENV variables for all @@ -705,22 +907,131 @@ func (v *Viper) ReadInConfig() error { v.config = make(map[string]interface{}) - v.marshalReader(bytes.NewReader(file), v.config) - return nil + return v.unmarshalReader(bytes.NewReader(file), v.config) } +// MergeInConfig merges a new configuration with an existing config. +func MergeInConfig() error { return v.MergeInConfig() } +func (v *Viper) MergeInConfig() error { + jww.INFO.Println("Attempting to merge in config file") + if !stringInSlice(v.getConfigType(), SupportedExts) { + return UnsupportedConfigError(v.getConfigType()) + } + + file, err := ioutil.ReadFile(v.getConfigFile()) + if err != nil { + return err + } + + return v.MergeConfig(bytes.NewReader(file)) +} + +// Viper will read a configuration file, setting existing keys to nil if the +// key does not exist in the file. func ReadConfig(in io.Reader) error { return v.ReadConfig(in) } func (v *Viper) ReadConfig(in io.Reader) error { v.config = make(map[string]interface{}) - v.marshalReader(in, v.config) + return v.unmarshalReader(in, v.config) +} + +// MergeConfig merges a new configuration with an existing config. +func MergeConfig(in io.Reader) error { return v.MergeConfig(in) } +func (v *Viper) MergeConfig(in io.Reader) error { + if v.config == nil { + v.config = make(map[string]interface{}) + } + cfg := make(map[string]interface{}) + if err := v.unmarshalReader(in, cfg); err != nil { + return err + } + mergeMaps(cfg, v.config, nil) return nil } +func keyExists(k string, m map[string]interface{}) string { + lk := strings.ToLower(k) + for mk := range m { + lmk := strings.ToLower(mk) + if lmk == lk { + return mk + } + } + return "" +} + +func castToMapStringInterface( + src map[interface{}]interface{}) map[string]interface{} { + tgt := map[string]interface{}{} + for k, v := range src { + tgt[fmt.Sprintf("%v", k)] = v + } + return tgt +} + +// mergeMaps merges two maps. The `itgt` parameter is for handling go-yaml's +// insistence on parsing nested structures as `map[interface{}]interface{}` +// instead of using a `string` as the key for nest structures beyond one level +// deep. Both map types are supported as there is a go-yaml fork that uses +// `map[string]interface{}` instead. +func mergeMaps( + src, tgt map[string]interface{}, itgt map[interface{}]interface{}) { + for sk, sv := range src { + tk := keyExists(sk, tgt) + if tk == "" { + jww.TRACE.Printf("tk=\"\", tgt[%s]=%v", sk, sv) + tgt[sk] = sv + if itgt != nil { + itgt[sk] = sv + } + continue + } + + tv, ok := tgt[tk] + if !ok { + jww.TRACE.Printf("tgt[%s] != ok, tgt[%s]=%v", tk, sk, sv) + tgt[sk] = sv + if itgt != nil { + itgt[sk] = sv + } + continue + } + + svType := reflect.TypeOf(sv) + tvType := reflect.TypeOf(tv) + if svType != tvType { + jww.ERROR.Printf( + "svType != tvType; key=%s, st=%v, tt=%v, sv=%v, tv=%v", + sk, svType, tvType, sv, tv) + continue + } + + jww.TRACE.Printf("processing key=%s, st=%v, tt=%v, sv=%v, tv=%v", + sk, svType, tvType, sv, tv) + + switch ttv := tv.(type) { + case map[interface{}]interface{}: + jww.TRACE.Printf("merging maps (must convert)") + tsv := sv.(map[interface{}]interface{}) + ssv := castToMapStringInterface(tsv) + stv := castToMapStringInterface(ttv) + mergeMaps(ssv, stv, ttv) + case map[string]interface{}: + jww.TRACE.Printf("merging maps") + mergeMaps(sv.(map[string]interface{}), ttv, nil) + default: + jww.TRACE.Printf("setting value") + tgt[tk] = sv + if itgt != nil { + itgt[tk] = sv + } + } + } +} + // func ReadBufConfig(buf *bytes.Buffer) error { return v.ReadBufConfig(buf) } // func (v *Viper) ReadBufConfig(buf *bytes.Buffer) error { // v.config = make(map[string]interface{}) -// v.marshalReader(buf, v.config) -// return nil +// return v.unmarshalReader(buf, v.config) // } // Attempts to get configuration from a remote source @@ -743,11 +1054,14 @@ func (v *Viper) WatchRemoteConfig() error { return nil } -// Marshall a Reader into a map +// Unmarshall a Reader into a map // Should probably be an unexported function -func marshalReader(in io.Reader, c map[string]interface{}) { v.marshalReader(in, c) } -func (v *Viper) marshalReader(in io.Reader, c map[string]interface{}) { - marshallConfigReader(in, c, v.getConfigType()) +func unmarshalReader(in io.Reader, c map[string]interface{}) error { + return v.unmarshalReader(in, c) +} + +func (v *Viper) unmarshalReader(in io.Reader, c map[string]interface{}) error { + return unmarshallConfigReader(in, c, v.getConfigType()) } func (v *Viper) insensitiviseMaps() { @@ -759,6 +1073,10 @@ func (v *Viper) insensitiviseMaps() { // retrieve the first found remote configuration func (v *Viper) getKeyValueConfig() error { + if RemoteConfig == nil { + return RemoteConfigError("Enable the remote features by doing a blank import of the viper/remote package: '_ github.com/spf13/viper/remote'") + } + for _, rp := range v.remoteProviders { val, err := v.getRemoteConfig(rp) if err != nil { @@ -770,37 +1088,13 @@ func (v *Viper) getKeyValueConfig() error { return RemoteConfigError("No Files Found") } -func (v *Viper) getRemoteConfig(provider *remoteProvider) (map[string]interface{}, error) { - var cm crypt.ConfigManager - var err error +func (v *Viper) getRemoteConfig(provider *defaultRemoteProvider) (map[string]interface{}, error) { - if provider.secretKeyring != "" { - kr, err := os.Open(provider.secretKeyring) - defer kr.Close() - if err != nil { - return nil, err - } - if provider.provider == "etcd" { - cm, err = crypt.NewEtcdConfigManager([]string{provider.endpoint}, kr) - } else { - cm, err = crypt.NewConsulConfigManager([]string{provider.endpoint}, kr) - } - } else { - if provider.provider == "etcd" { - cm, err = crypt.NewStandardEtcdConfigManager([]string{provider.endpoint}) - } else { - cm, err = crypt.NewStandardConsulConfigManager([]string{provider.endpoint}) - } - } + reader, err := RemoteConfig.Get(provider) if err != nil { return nil, err } - b, err := cm.Get(provider.path) - if err != nil { - return nil, err - } - reader := bytes.NewReader(b) - v.marshalReader(reader, v.kvstore) + err = v.unmarshalReader(reader, v.kvstore) return v.kvstore, err } @@ -817,40 +1111,12 @@ func (v *Viper) watchKeyValueConfig() error { return RemoteConfigError("No Files Found") } -func (v *Viper) watchRemoteConfig(provider *remoteProvider) (map[string]interface{}, error) { - var cm crypt.ConfigManager - var err error - - if provider.secretKeyring != "" { - kr, err := os.Open(provider.secretKeyring) - defer kr.Close() - if err != nil { - return nil, err - } - if provider.provider == "etcd" { - cm, err = crypt.NewEtcdConfigManager([]string{provider.endpoint}, kr) - } else { - cm, err = crypt.NewConsulConfigManager([]string{provider.endpoint}, kr) - } - } else { - if provider.provider == "etcd" { - cm, err = crypt.NewStandardEtcdConfigManager([]string{provider.endpoint}) - } else { - cm, err = crypt.NewStandardConsulConfigManager([]string{provider.endpoint}) - } - } +func (v *Viper) watchRemoteConfig(provider *defaultRemoteProvider) (map[string]interface{}, error) { + reader, err := RemoteConfig.Watch(provider) if err != nil { return nil, err } - resp := <-cm.Watch(provider.path, nil) - // b, err := cm.Watch(provider.path, nil) - err = resp.Error - if err != nil { - return nil, err - } - - reader := bytes.NewReader(resp.Value) - v.marshalReader(reader, v.kvstore) + err = v.unmarshalReader(reader, v.kvstore) return v.kvstore, err } @@ -860,19 +1126,27 @@ func (v *Viper) AllKeys() []string { m := map[string]struct{}{} for key, _ := range v.defaults { - m[key] = struct{}{} + m[strings.ToLower(key)] = struct{}{} + } + + for key, _ := range v.pflags { + m[strings.ToLower(key)] = struct{}{} + } + + for key, _ := range v.env { + m[strings.ToLower(key)] = struct{}{} } for key, _ := range v.config { - m[key] = struct{}{} + m[strings.ToLower(key)] = struct{}{} } for key, _ := range v.kvstore { - m[key] = struct{}{} + m[strings.ToLower(key)] = struct{}{} } for key, _ := range v.override { - m[key] = struct{}{} + m[strings.ToLower(key)] = struct{}{} } a := []string{} @@ -958,6 +1232,7 @@ func (v *Viper) searchInPath(in string) (filename string) { // search all configPaths for any config file. // Returns the first path that exists (and is a config file) func (v *Viper) findConfigFile() (string, error) { + jww.INFO.Println("Searching for config in ", v.configPaths) for _, cp := range v.configPaths { @@ -966,14 +1241,7 @@ func (v *Viper) findConfigFile() (string, error) { return file, nil } } - - // try the current working directory - wd, _ := os.Getwd() - file := v.searchInPath(wd) - if file != "" { - return file, nil - } - return "", fmt.Errorf("config file not found in: %s", v.configPaths) + return "", ConfigFileNotFoundError{v.configName, fmt.Sprintf("%s", v.configPaths)} } // Prints all configuration registries for debugging diff --git a/Godeps/_workspace/src/github.com/spf13/viper/viper_test.go b/Godeps/_workspace/src/github.com/spf13/viper/viper_test.go index c1b56c7b..07847118 100644 --- a/Godeps/_workspace/src/github.com/spf13/viper/viper_test.go +++ b/Godeps/_workspace/src/github.com/spf13/viper/viper_test.go @@ -8,7 +8,10 @@ package viper import ( "bytes" "fmt" + "io/ioutil" "os" + "path" + "reflect" "sort" "strings" "testing" @@ -27,6 +30,8 @@ hobbies: clothing: jacket: leather trousers: denim + pants: + size: large age: 35 eyes : brown beard: true @@ -55,6 +60,26 @@ var jsonExample = []byte(`{ } }`) +var hclExample = []byte(` +id = "0001" +type = "donut" +name = "Cake" +ppu = 0.55 +foos { + foo { + key = 1 + } + foo { + key = 2 + } + foo { + key = 3 + } + foo { + key = 4 + } +}`) + var propertiesExample = []byte(` p_id: 0001 p_type: donut @@ -73,23 +98,27 @@ func initConfigs() { Reset() SetConfigType("yaml") r := bytes.NewReader(yamlExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) SetConfigType("json") r = bytes.NewReader(jsonExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) + + SetConfigType("hcl") + r = bytes.NewReader(hclExample) + unmarshalReader(r, v.config) SetConfigType("properties") r = bytes.NewReader(propertiesExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) SetConfigType("toml") r = bytes.NewReader(tomlExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) SetConfigType("json") remote := bytes.NewReader(remoteExample) - marshalReader(remote, v.kvstore) + unmarshalReader(remote, v.kvstore) } func initYAML() { @@ -97,7 +126,7 @@ func initYAML() { SetConfigType("yaml") r := bytes.NewReader(yamlExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) } func initJSON() { @@ -105,7 +134,7 @@ func initJSON() { SetConfigType("json") r := bytes.NewReader(jsonExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) } func initProperties() { @@ -113,7 +142,7 @@ func initProperties() { SetConfigType("properties") r := bytes.NewReader(propertiesExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) } func initTOML() { @@ -121,7 +150,56 @@ func initTOML() { SetConfigType("toml") r := bytes.NewReader(tomlExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) +} + +func initHcl() { + Reset() + SetConfigType("hcl") + r := bytes.NewReader(hclExample) + + unmarshalReader(r, v.config) +} + +// make directories for testing +func initDirs(t *testing.T) (string, string, func()) { + + var ( + testDirs = []string{`a a`, `b`, `c\c`, `D_`} + config = `improbable` + ) + + root, err := ioutil.TempDir("", "") + + cleanup := true + defer func() { + if cleanup { + os.Chdir("..") + os.RemoveAll(root) + } + }() + + assert.Nil(t, err) + + err = os.Chdir(root) + assert.Nil(t, err) + + for _, dir := range testDirs { + err = os.Mkdir(dir, 0750) + assert.Nil(t, err) + + err = ioutil.WriteFile( + path.Join(dir, config+".toml"), + []byte("key = \"value is "+dir+"\"\n"), + 0640) + assert.Nil(t, err) + } + + cleanup = false + return root, config, func() { + os.Chdir("..") + os.RemoveAll(root) + } } //stubs for PFlag Values @@ -153,18 +231,27 @@ func TestBasics(t *testing.T) { func TestDefault(t *testing.T) { SetDefault("age", 45) assert.Equal(t, 45, Get("age")) + + SetDefault("clothing.jacket", "slacks") + assert.Equal(t, "slacks", Get("clothing.jacket")) + + SetConfigType("yaml") + err := ReadConfig(bytes.NewBuffer(yamlExample)) + + assert.NoError(t, err) + assert.Equal(t, "leather", Get("clothing.jacket")) } -func TestMarshalling(t *testing.T) { +func TestUnmarshalling(t *testing.T) { SetConfigType("yaml") r := bytes.NewReader(yamlExample) - marshalReader(r, v.config) + unmarshalReader(r, v.config) assert.True(t, InConfig("name")) assert.False(t, InConfig("state")) assert.Equal(t, "steve", Get("name")) assert.Equal(t, []interface{}{"skateboarding", "snowboarding", "go"}, Get("hobbies")) - assert.Equal(t, map[interface{}]interface{}{"jacket": "leather", "trousers": "denim"}, Get("clothing")) + assert.Equal(t, map[interface{}]interface{}{"jacket": "leather", "trousers": "denim", "pants": map[interface{}]interface{}{"size": "large"}}, Get("clothing")) assert.Equal(t, 35, Get("age")) } @@ -215,12 +302,23 @@ func TestTOML(t *testing.T) { assert.Equal(t, "TOML Example", Get("title")) } +func TestHCL(t *testing.T) { + initHcl() + assert.Equal(t, "0001", Get("id")) + assert.Equal(t, 0.55, Get("ppu")) + assert.Equal(t, "donut", Get("type")) + assert.Equal(t, "Cake", Get("name")) + Set("id", "0002") + assert.Equal(t, "0002", Get("id")) + assert.NotEqual(t, "cronut", Get("type")) +} + func TestRemotePrecedence(t *testing.T) { initJSON() remote := bytes.NewReader(remoteExample) assert.Equal(t, "0001", Get("id")) - marshalReader(remote, v.kvstore) + unmarshalReader(remote, v.kvstore) assert.Equal(t, "0001", Get("id")) assert.NotEqual(t, "cronut", Get("type")) assert.Equal(t, "remote", Get("newkey")) @@ -302,9 +400,9 @@ func TestSetEnvReplacer(t *testing.T) { func TestAllKeys(t *testing.T) { initConfigs() - ks := sort.StringSlice{"title", "newkey", "owner", "name", "beard", "ppu", "batters", "hobbies", "clothing", "age", "hacker", "id", "type", "eyes", "p_id", "p_ppu", "p_batters.batter.type", "p_type", "p_name"} + ks := sort.StringSlice{"title", "newkey", "owner", "name", "beard", "ppu", "batters", "hobbies", "clothing", "age", "hacker", "id", "type", "eyes", "p_id", "p_ppu", "p_batters.batter.type", "p_type", "p_name", "foos"} dob, _ := time.Parse(time.RFC3339, "1979-05-27T07:32:00Z") - all := map[string]interface{}{"owner": map[string]interface{}{"organization": "MongoDB", "Bio": "MongoDB Chief Developer Advocate & Hacker at Large", "dob": dob}, "title": "TOML Example", "ppu": 0.55, "eyes": "brown", "clothing": map[interface{}]interface{}{"trousers": "denim", "jacket": "leather"}, "id": "0001", "batters": map[string]interface{}{"batter": []interface{}{map[string]interface{}{"type": "Regular"}, map[string]interface{}{"type": "Chocolate"}, map[string]interface{}{"type": "Blueberry"}, map[string]interface{}{"type": "Devil's Food"}}}, "hacker": true, "beard": true, "hobbies": []interface{}{"skateboarding", "snowboarding", "go"}, "age": 35, "type": "donut", "newkey": "remote", "name": "Cake", "p_id": "0001", "p_ppu": "0.55", "p_name": "Cake", "p_batters.batter.type": "Regular", "p_type": "donut"} + all := map[string]interface{}{"owner": map[string]interface{}{"organization": "MongoDB", "Bio": "MongoDB Chief Developer Advocate & Hacker at Large", "dob": dob}, "title": "TOML Example", "ppu": 0.55, "eyes": "brown", "clothing": map[interface{}]interface{}{"trousers": "denim", "jacket": "leather", "pants": map[interface{}]interface{}{"size": "large"}}, "id": "0001", "batters": map[string]interface{}{"batter": []interface{}{map[string]interface{}{"type": "Regular"}, map[string]interface{}{"type": "Chocolate"}, map[string]interface{}{"type": "Blueberry"}, map[string]interface{}{"type": "Devil's Food"}}}, "hacker": true, "beard": true, "hobbies": []interface{}{"skateboarding", "snowboarding", "go"}, "age": 35, "type": "donut", "newkey": "remote", "name": "Cake", "p_id": "0001", "p_ppu": "0.55", "p_name": "Cake", "p_batters.batter.type": "Regular", "p_type": "donut", "foos": []map[string]interface{}{map[string]interface{}{"foo": []map[string]interface{}{map[string]interface{}{"key": 1}, map[string]interface{}{"key": 2}, map[string]interface{}{"key": 3}, map[string]interface{}{"key": 4}}}}} var allkeys sort.StringSlice allkeys = AllKeys() @@ -332,7 +430,7 @@ func TestRecursiveAliases(t *testing.T) { RegisterAlias("Roo", "baz") } -func TestMarshal(t *testing.T) { +func TestUnmarshal(t *testing.T) { SetDefault("port", 1313) Set("name", "Steve") @@ -343,7 +441,7 @@ func TestMarshal(t *testing.T) { var C config - err := Marshal(&C) + err := Unmarshal(&C) if err != nil { t.Fatalf("unable to decode into struct, %v", err) } @@ -351,7 +449,7 @@ func TestMarshal(t *testing.T) { assert.Equal(t, &C, &config{Name: "Steve", Port: 1313}) Set("port", 1234) - err = Marshal(&C) + err = Unmarshal(&C) if err != nil { t.Fatalf("unable to decode into struct, %v", err) } @@ -524,11 +622,33 @@ func TestFindsNestedKeys(t *testing.T) { "clothing": map[interface{}]interface{}{ "jacket": "leather", "trousers": "denim", + "pants": map[interface{}]interface{}{ + "size": "large", + }, + }, + "clothing.jacket": "leather", + "clothing.pants.size": "large", + "clothing.trousers": "denim", + "owner.dob": dob, + "beard": true, + "foos": []map[string]interface{}{ + map[string]interface{}{ + "foo": []map[string]interface{}{ + map[string]interface{}{ + "key": 1, + }, + map[string]interface{}{ + "key": 2, + }, + map[string]interface{}{ + "key": 3, + }, + map[string]interface{}{ + "key": 4, + }, + }, + }, }, - "clothing.jacket": "leather", - "clothing.trousers": "denim", - "owner.dob": dob, - "beard": true, } for key, expectedValue := range expected { @@ -548,6 +668,173 @@ func TestReadBufConfig(t *testing.T) { assert.False(t, v.InConfig("state")) assert.Equal(t, "steve", v.Get("name")) assert.Equal(t, []interface{}{"skateboarding", "snowboarding", "go"}, v.Get("hobbies")) - assert.Equal(t, map[interface{}]interface{}{"jacket": "leather", "trousers": "denim"}, v.Get("clothing")) + assert.Equal(t, map[interface{}]interface{}{"jacket": "leather", "trousers": "denim", "pants": map[interface{}]interface{}{"size": "large"}}, v.Get("clothing")) assert.Equal(t, 35, v.Get("age")) } + +func TestIsSet(t *testing.T) { + v := New() + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(yamlExample)) + assert.True(t, v.IsSet("clothing.jacket")) + assert.False(t, v.IsSet("clothing.jackets")) + assert.False(t, v.IsSet("helloworld")) + v.Set("helloworld", "fubar") + assert.True(t, v.IsSet("helloworld")) +} + +func TestDirsSearch(t *testing.T) { + + root, config, cleanup := initDirs(t) + defer cleanup() + + v := New() + v.SetConfigName(config) + v.SetDefault(`key`, `default`) + + entries, err := ioutil.ReadDir(root) + for _, e := range entries { + if e.IsDir() { + v.AddConfigPath(e.Name()) + } + } + + err = v.ReadInConfig() + assert.Nil(t, err) + + assert.Equal(t, `value is `+path.Base(v.configPaths[0]), v.GetString(`key`)) +} + +func TestWrongDirsSearchNotFound(t *testing.T) { + + _, config, cleanup := initDirs(t) + defer cleanup() + + v := New() + v.SetConfigName(config) + v.SetDefault(`key`, `default`) + + v.AddConfigPath(`whattayoutalkingbout`) + v.AddConfigPath(`thispathaintthere`) + + err := v.ReadInConfig() + assert.Equal(t, reflect.TypeOf(UnsupportedConfigError("")), reflect.TypeOf(err)) + + // Even though config did not load and the error might have + // been ignored by the client, the default still loads + assert.Equal(t, `default`, v.GetString(`key`)) +} + +func TestSub(t *testing.T) { + v := New() + v.SetConfigType("yaml") + v.ReadConfig(bytes.NewBuffer(yamlExample)) + + subv := v.Sub("clothing") + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("pants.size")) + + subv = v.Sub("clothing.pants") + assert.Equal(t, v.Get("clothing.pants.size"), subv.Get("size")) + + subv = v.Sub("clothing.pants.size") + assert.Equal(t, subv, (*Viper)(nil)) +} + +var yamlMergeExampleTgt = []byte(` +hello: + pop: 37890 + world: + - us + - uk + - fr + - de +`) + +var yamlMergeExampleSrc = []byte(` +hello: + pop: 45000 + universe: + - mw + - ad +fu: bar +`) + +func TestMergeConfig(t *testing.T) { + v := New() + v.SetConfigType("yml") + if err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleTgt)); err != nil { + t.Fatal(err) + } + + if pop := v.GetInt("hello.pop"); pop != 37890 { + t.Fatalf("pop != 37890, = %d", pop) + } + + if world := v.GetStringSlice("hello.world"); len(world) != 4 { + t.Fatalf("len(world) != 4, = %d", len(world)) + } + + if fu := v.GetString("fu"); fu != "" { + t.Fatalf("fu != \"\", = %s", fu) + } + + if err := v.MergeConfig(bytes.NewBuffer(yamlMergeExampleSrc)); err != nil { + t.Fatal(err) + } + + if pop := v.GetInt("hello.pop"); pop != 45000 { + t.Fatalf("pop != 45000, = %d", pop) + } + + if world := v.GetStringSlice("hello.world"); len(world) != 4 { + t.Fatalf("len(world) != 4, = %d", len(world)) + } + + if universe := v.GetStringSlice("hello.universe"); len(universe) != 2 { + t.Fatalf("len(universe) != 2, = %d", len(universe)) + } + + if fu := v.GetString("fu"); fu != "bar" { + t.Fatalf("fu != \"bar\", = %s", fu) + } +} + +func TestMergeConfigNoMerge(t *testing.T) { + v := New() + v.SetConfigType("yml") + if err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleTgt)); err != nil { + t.Fatal(err) + } + + if pop := v.GetInt("hello.pop"); pop != 37890 { + t.Fatalf("pop != 37890, = %d", pop) + } + + if world := v.GetStringSlice("hello.world"); len(world) != 4 { + t.Fatalf("len(world) != 4, = %d", len(world)) + } + + if fu := v.GetString("fu"); fu != "" { + t.Fatalf("fu != \"\", = %s", fu) + } + + if err := v.ReadConfig(bytes.NewBuffer(yamlMergeExampleSrc)); err != nil { + t.Fatal(err) + } + + if pop := v.GetInt("hello.pop"); pop != 45000 { + t.Fatalf("pop != 45000, = %d", pop) + } + + if world := v.GetStringSlice("hello.world"); len(world) != 0 { + t.Fatalf("len(world) != 0, = %d", len(world)) + } + + if universe := v.GetStringSlice("hello.universe"); len(universe) != 2 { + t.Fatalf("len(universe) != 2, = %d", len(universe)) + } + + if fu := v.GetString("fu"); fu != "bar" { + t.Fatalf("fu != \"bar\", = %s", fu) + } +} diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/client.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/client.go index a9438cbd..b8682956 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/client.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/client.go @@ -1,6 +1,7 @@ package acme import ( + "crypto" "crypto/rsa" "crypto/x509" "encoding/base64" @@ -9,7 +10,7 @@ import ( "fmt" "io/ioutil" "log" - "net/http" + "net" "regexp" "strconv" "strings" @@ -99,20 +100,38 @@ func NewClient(caDirURL string, user User, keyBits int) (*Client, error) { return &Client{directory: dir, user: user, jws: jws, keyBits: keyBits, solvers: solvers}, nil } -// SetHTTPPort specifies a custom port to be used for HTTP based challenges. -// If this option is not used, the default port 80 will be used. -func (c *Client) SetHTTPPort(port string) { - if chlng, ok := c.solvers["http-01"]; ok { - chlng.(*httpChallenge).optPort = port +// SetHTTPAddress specifies a custom interface:port to be used for HTTP based challenges. +// If this option is not used, the default port 80 and all interfaces will be used. +// To only specify a port and no interface use the ":port" notation. +func (c *Client) SetHTTPAddress(iface string) error { + host, port, err := net.SplitHostPort(iface) + if err != nil { + return err } + + if chlng, ok := c.solvers["http-01"]; ok { + chlng.(*httpChallenge).iface = host + chlng.(*httpChallenge).port = port + } + + return nil } -// SetTLSPort specifies a custom port to be used for TLS based challenges. -// If this option is not used, the default port 443 will be used. -func (c *Client) SetTLSPort(port string) { - if chlng, ok := c.solvers["tls-sni-01"]; ok { - chlng.(*tlsSNIChallenge).optPort = port +// SetTLSAddress specifies a custom interface:port to be used for TLS based challenges. +// If this option is not used, the default port 443 and all interfaces will be used. +// To only specify a port and no interface use the ":port" notation. +func (c *Client) SetTLSAddress(iface string) error { + host, port, err := net.SplitHostPort(iface) + if err != nil { + return err } + + if chlng, ok := c.solvers["tls-sni-01"]; ok { + chlng.(*tlsSNIChallenge).iface = host + chlng.(*tlsSNIChallenge).port = port + } + + return nil } // ExcludeChallenges explicitly removes challenges from the pool for solving. @@ -175,12 +194,14 @@ func (c *Client) AgreeToTOS() error { // ObtainCertificate tries to obtain a single certificate using all domains passed into it. // The first domain in domains is used for the CommonName field of the certificate, all other -// domains are added using the Subject Alternate Names extension. +// domains are added using the Subject Alternate Names extension. A new private key is generated +// for every invocation of this function. If you do not want that you can supply your own private key +// in the privKey parameter. If this parameter is non-nil it will be used instead of generating a new one. // If bundle is true, the []byte contains both the issuer certificate and // your issued certificate as a bundle. // This function will never return a partial certificate. If one domain in the list fails, // the whole certificate will fail. -func (c *Client) ObtainCertificate(domains []string, bundle bool) (CertificateResource, map[string]error) { +func (c *Client) ObtainCertificate(domains []string, bundle bool, privKey crypto.PrivateKey) (CertificateResource, map[string]error) { if bundle { logf("[INFO][%s] acme: Obtaining bundled SAN certificate", strings.Join(domains, ", ")) } else { @@ -201,7 +222,7 @@ func (c *Client) ObtainCertificate(domains []string, bundle bool) (CertificateRe logf("[INFO][%s] acme: Validations succeeded; requesting certificates", strings.Join(domains, ", ")) - cert, err := c.requestCertificate(challenges, bundle) + cert, err := c.requestCertificate(challenges, bundle, privKey) if err != nil { for _, chln := range challenges { failures[chln.Domain] = err @@ -236,6 +257,7 @@ func (c *Client) RevokeCertificate(certificate []byte) error { // this function will start a new-cert flow where a new certificate gets generated. // If bundle is true, the []byte contains both the issuer certificate and // your issued certificate as a bundle. +// For private key reuse the PrivateKey property of the passed in CertificateResource should be non-nil. func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (CertificateResource, error) { // Input certificate is PEM encoded. Decode it here as we may need the decoded // cert later on in the renewal process. The input may be a bundle or a single certificate. @@ -255,7 +277,7 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (Certif // The first step of renewal is to check if we get a renewed cert // directly from the cert URL. - resp, err := http.Get(cert.CertURL) + resp, err := httpGet(cert.CertURL) if err != nil { return CertificateResource{}, err } @@ -297,7 +319,15 @@ func (c *Client) RenewCertificate(cert CertificateResource, bundle bool) (Certif return cert, nil } - newCert, failures := c.ObtainCertificate([]string{cert.Domain}, bundle) + var privKey crypto.PrivateKey + if cert.PrivateKey != nil { + privKey, err = parsePEMPrivateKey(cert.PrivateKey) + if err != nil { + return CertificateResource{}, err + } + } + + newCert, failures := c.ObtainCertificate([]string{cert.Domain}, bundle, privKey) return newCert, failures[cert.Domain] } @@ -393,15 +423,18 @@ func (c *Client) getChallenges(domains []string) ([]authorizationResource, map[s return challenges, failures } -func (c *Client) requestCertificate(authz []authorizationResource, bundle bool) (CertificateResource, error) { +func (c *Client) requestCertificate(authz []authorizationResource, bundle bool, privKey crypto.PrivateKey) (CertificateResource, error) { if len(authz) == 0 { return CertificateResource{}, errors.New("Passed no authorizations to requestCertificate!") } commonName := authz[0] - privKey, err := generatePrivateKey(rsakey, c.keyBits) - if err != nil { - return CertificateResource{}, err + var err error + if privKey == nil { + privKey, err = generatePrivateKey(rsakey, c.keyBits) + if err != nil { + return CertificateResource{}, err + } } var san []string @@ -435,11 +468,8 @@ func (c *Client) requestCertificate(authz []authorizationResource, bundle bool) PrivateKey: privateKeyPem} for { - switch resp.StatusCode { - case 202: - case 201: - + case 201, 202: cert, err := ioutil.ReadAll(limitReader(resp.Body, 1024*1024)) resp.Body.Close() if err != nil { @@ -492,7 +522,7 @@ func (c *Client) requestCertificate(authz []authorizationResource, bundle bool) return CertificateResource{}, handleHTTPError(resp) } - resp, err = http.Get(cerRes.CertURL) + resp, err = httpGet(cerRes.CertURL) if err != nil { return CertificateResource{}, err } @@ -507,7 +537,7 @@ func (c *Client) getIssuerCertificate(url string) ([]byte, error) { return c.issuerCert, nil } - resp, err := http.Get(url) + resp, err := httpGet(url) if err != nil { return nil, err } @@ -585,44 +615,3 @@ func validate(j *jws, domain, uri string, chlng challenge) error { } } } - -// getJSON performs an HTTP GET request and parses the response body -// as JSON, into the provided respBody object. -func getJSON(uri string, respBody interface{}) (http.Header, error) { - resp, err := http.Get(uri) - if err != nil { - return nil, fmt.Errorf("failed to get %q: %v", uri, err) - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return resp.Header, handleHTTPError(resp) - } - - return resp.Header, json.NewDecoder(resp.Body).Decode(respBody) -} - -// postJSON performs an HTTP POST request and parses the response body -// as JSON, into the provided respBody object. -func postJSON(j *jws, uri string, reqBody, respBody interface{}) (http.Header, error) { - jsonBytes, err := json.Marshal(reqBody) - if err != nil { - return nil, errors.New("Failed to marshal network message...") - } - - resp, err := j.post(uri, jsonBytes) - if err != nil { - return nil, fmt.Errorf("Failed to post JWS message. -> %v", err) - } - defer resp.Body.Close() - - if resp.StatusCode >= http.StatusBadRequest { - return resp.Header, handleHTTPError(resp) - } - - if respBody == nil { - return resp.Header, nil - } - - return resp.Header, json.NewDecoder(resp.Body).Decode(respBody) -} diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/client_test.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/client_test.go index cd89b257..3daaed5f 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/client_test.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/client_test.go @@ -4,6 +4,7 @@ import ( "crypto/rand" "crypto/rsa" "encoding/json" + "net" "net/http" "net/http/httptest" "strings" @@ -66,12 +67,13 @@ func TestClientOptPort(t *testing.T) { })) optPort := "1234" + optHost := "" client, err := NewClient(ts.URL, user, keyBits) if err != nil { t.Fatalf("Could not create client: %v", err) } - client.SetHTTPPort(optPort) - client.SetTLSPort(optPort) + client.SetHTTPAddress(net.JoinHostPort(optHost, optPort)) + client.SetTLSAddress(net.JoinHostPort(optHost, optPort)) httpSolver, ok := client.solvers["http-01"].(*httpChallenge) if !ok { @@ -80,8 +82,11 @@ func TestClientOptPort(t *testing.T) { if httpSolver.jws != client.jws { t.Error("Expected http-01 to have same jws as client") } - if httpSolver.optPort != optPort { - t.Errorf("Expected http-01 to have optPort %s but was %s", optPort, httpSolver.optPort) + if httpSolver.port != optPort { + t.Errorf("Expected http-01 to have port %s but was %s", optPort, httpSolver.port) + } + if httpSolver.iface != optHost { + t.Errorf("Expected http-01 to have iface %s but was %s", optHost, httpSolver.iface) } httpsSolver, ok := client.solvers["tls-sni-01"].(*tlsSNIChallenge) @@ -91,8 +96,23 @@ func TestClientOptPort(t *testing.T) { if httpsSolver.jws != client.jws { t.Error("Expected tls-sni-01 to have same jws as client") } - if httpsSolver.optPort != optPort { - t.Errorf("Expected tls-sni-01 to have optPort %s but was %s", optPort, httpSolver.optPort) + if httpsSolver.port != optPort { + t.Errorf("Expected tls-sni-01 to have port %s but was %s", optPort, httpSolver.port) + } + if httpsSolver.port != optPort { + t.Errorf("Expected tls-sni-01 to have port %s but was %s", optHost, httpSolver.iface) + } + + // test setting different host + optHost = "127.0.0.1" + client.SetHTTPAddress(net.JoinHostPort(optHost, optPort)) + client.SetTLSAddress(net.JoinHostPort(optHost, optPort)) + + if httpSolver.iface != optHost { + t.Errorf("Expected http-01 to have iface %s but was %s", optHost, httpSolver.iface) + } + if httpsSolver.port != optPort { + t.Errorf("Expected tls-sni-01 to have port %s but was %s", optHost, httpSolver.iface) } } diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/crypto.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/crypto.go index ee0fcdd6..c1dbea21 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/crypto.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/crypto.go @@ -63,7 +63,7 @@ func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) { return nil, nil, errors.New("no issuing certificate URL") } - resp, err := http.Get(certificates[0].IssuingCertificateURL[0]) + resp, err := httpGet(certificates[0].IssuingCertificateURL[0]) if err != nil { return nil, nil, err } @@ -97,7 +97,7 @@ func GetOCSPForCert(bundle []byte) ([]byte, *ocsp.Response, error) { } reader := bytes.NewReader(ocspReq) - req, err := http.Post(issuedCert.OCSPServer[0], "application/ocsp-request", reader) + req, err := httpPost(issuedCert.OCSPServer[0], "application/ocsp-request", reader) if err != nil { return nil, nil, err } @@ -177,22 +177,21 @@ func performECDH(priv *ecdsa.PrivateKey, pub *ecdsa.PublicKey, outLen int, label // a slice of x509 certificates. This function will error if no certificates are found. func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) { var certificates []*x509.Certificate + var certDERBlock *pem.Block - remaining := bundle - for len(remaining) != 0 { - certBlock, rem := pem.Decode(remaining) - // Thanks golang for having me do this :[ - remaining = rem - if certBlock == nil { - return nil, errors.New("Could not decode certificate.") + for { + certDERBlock, bundle = pem.Decode(bundle) + if certDERBlock == nil { + break } - cert, err := x509.ParseCertificate(certBlock.Bytes) - if err != nil { - return nil, err + if certDERBlock.Type == "CERTIFICATE" { + cert, err := x509.ParseCertificate(certDERBlock.Bytes) + if err != nil { + return nil, err + } + certificates = append(certificates, cert) } - - certificates = append(certificates, cert) } if len(certificates) == 0 { @@ -202,6 +201,19 @@ func parsePEMBundle(bundle []byte) ([]*x509.Certificate, error) { return certificates, nil } +func parsePEMPrivateKey(key []byte) (crypto.PrivateKey, error) { + keyBlock, _ := pem.Decode(key) + + switch keyBlock.Type { + case "RSA PRIVATE KEY": + return x509.ParsePKCS1PrivateKey(keyBlock.Bytes) + case "EC PRIVATE KEY": + return x509.ParseECPrivateKey(keyBlock.Bytes) + default: + return nil, errors.New("Unknown PEM header value") + } +} + func generatePrivateKey(t keyType, keyLength int) (crypto.PrivateKey, error) { switch t { case eckey: diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/http.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http.go new file mode 100644 index 00000000..661a0588 --- /dev/null +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http.go @@ -0,0 +1,115 @@ +package acme + +import ( + "encoding/json" + "errors" + "fmt" + "io" + "net/http" + "runtime" + "strings" +) + +// UserAgent, if non-empty, will be tacked onto the User-Agent string in requests. +var UserAgent string + +const ( + // defaultGoUserAgent is the Go HTTP package user agent string. Too + // bad it isn't exported. If it changes, we should update it here, too. + defaultGoUserAgent = "Go-http-client/1.1" + + // ourUserAgent is the User-Agent of this underlying library package. + ourUserAgent = "xenolf-acme" +) + +// httpHead performs a HEAD request with a proper User-Agent string. +// The response body (resp.Body) is already closed when this function returns. +func httpHead(url string) (resp *http.Response, err error) { + req, err := http.NewRequest("HEAD", url, nil) + if err != nil { + return nil, err + } + + req.Header.Set("User-Agent", userAgent()) + + client := http.Client{} + resp, err = client.Do(req) + if resp.Body != nil { + resp.Body.Close() + } + return resp, err +} + +// httpPost performs a POST request with a proper User-Agent string. +// Callers should close resp.Body when done reading from it. +func httpPost(url string, bodyType string, body io.Reader) (resp *http.Response, err error) { + req, err := http.NewRequest("POST", url, body) + if err != nil { + return nil, err + } + req.Header.Set("Content-Type", bodyType) + req.Header.Set("User-Agent", userAgent()) + + client := http.Client{} + return client.Do(req) +} + +// httpGet performs a GET request with a proper User-Agent string. +// Callers should close resp.Body when done reading from it. +func httpGet(url string) (resp *http.Response, err error) { + req, err := http.NewRequest("GET", url, nil) + if err != nil { + return nil, err + } + req.Header.Set("User-Agent", userAgent()) + + client := http.Client{} + return client.Do(req) +} + +// getJSON performs an HTTP GET request and parses the response body +// as JSON, into the provided respBody object. +func getJSON(uri string, respBody interface{}) (http.Header, error) { + resp, err := httpGet(uri) + if err != nil { + return nil, fmt.Errorf("failed to get %q: %v", uri, err) + } + defer resp.Body.Close() + + if resp.StatusCode >= http.StatusBadRequest { + return resp.Header, handleHTTPError(resp) + } + + return resp.Header, json.NewDecoder(resp.Body).Decode(respBody) +} + +// postJSON performs an HTTP POST request and parses the response body +// as JSON, into the provided respBody object. +func postJSON(j *jws, uri string, reqBody, respBody interface{}) (http.Header, error) { + jsonBytes, err := json.Marshal(reqBody) + if err != nil { + return nil, errors.New("Failed to marshal network message...") + } + + resp, err := j.post(uri, jsonBytes) + if err != nil { + return nil, fmt.Errorf("Failed to post JWS message. -> %v", err) + } + defer resp.Body.Close() + + if resp.StatusCode >= http.StatusBadRequest { + return resp.Header, handleHTTPError(resp) + } + + if respBody == nil { + return resp.Header, nil + } + + return resp.Header, json.NewDecoder(resp.Body).Decode(respBody) +} + +// userAgent builds and returns the User-Agent string to use in requests. +func userAgent() string { + ua := fmt.Sprintf("%s (%s; %s) %s %s", defaultGoUserAgent, runtime.GOOS, runtime.GOARCH, ourUserAgent, UserAgent) + return strings.TrimSpace(ua) +} diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge.go index 00ad5895..c3481fe1 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge.go @@ -10,13 +10,17 @@ import ( type httpChallenge struct { jws *jws validate validateFunc - optPort string + iface string + port string + done chan bool } func (s *httpChallenge) Solve(chlng challenge, domain string) error { logf("[INFO][%s] acme: Trying to solve HTTP-01", domain) + s.done = make(chan bool) + // Generate the Key Authorization for the challenge keyAuth, err := getKeyAuthorization(chlng.Token, &s.jws.privKey.PublicKey) if err != nil { @@ -24,23 +28,33 @@ func (s *httpChallenge) Solve(chlng challenge, domain string) error { } // Allow for CLI port override - port := ":80" - if s.optPort != "" { - port = ":" + s.optPort + port := "80" + if s.port != "" { + port = s.port } - listener, err := net.Listen("tcp", domain+port) - if err != nil { - // if the domain:port bind failed, fall back to :port bind and try that instead. - listener, err = net.Listen("tcp", port) - if err != nil { - return fmt.Errorf("Could not start HTTP server for challenge -> %v", err) - } + iface := "" + if s.iface != "" { + iface = s.iface + } + + listener, err := net.Listen("tcp", net.JoinHostPort(iface, port)) + if err != nil { + return fmt.Errorf("Could not start HTTP server for challenge -> %v", err) } - defer listener.Close() path := "/.well-known/acme-challenge/" + chlng.Token + go s.serve(listener, path, keyAuth, domain) + + err = s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) + listener.Close() + <-s.done + + return err +} + +func (s *httpChallenge) serve(listener net.Listener, path, keyAuth, domain string) { // The handler validates the HOST header and request type. // For validation it then writes the token the server returned with the challenge mux := http.NewServeMux() @@ -55,7 +69,6 @@ func (s *httpChallenge) Solve(chlng challenge, domain string) error { } }) - go http.Serve(listener, mux) - - return s.validate(s.jws, domain, chlng.URI, challenge{Resource: "challenge", Type: chlng.Type, Token: chlng.Token, KeyAuthorization: keyAuth}) + http.Serve(listener, mux) + s.done <- true } diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge_test.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge_test.go index 97a9d979..9ffb27f8 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge_test.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_challenge_test.go @@ -3,7 +3,6 @@ package acme import ( "crypto/rsa" "io/ioutil" - "net/http" "strings" "testing" ) @@ -14,7 +13,7 @@ func TestHTTPChallenge(t *testing.T) { clientChallenge := challenge{Type: "http-01", Token: "http1"} mockValidate := func(_ *jws, _, _ string, chlng challenge) error { uri := "http://localhost:23457/.well-known/acme-challenge/" + chlng.Token - resp, err := http.Get(uri) + resp, err := httpGet(uri) if err != nil { return err } @@ -36,7 +35,7 @@ func TestHTTPChallenge(t *testing.T) { return nil } - solver := &httpChallenge{jws: j, validate: mockValidate, optPort: "23457"} + solver := &httpChallenge{jws: j, validate: mockValidate, port: "23457"} if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil { t.Errorf("Solve error: got %v, want nil", err) @@ -47,10 +46,10 @@ func TestHTTPChallengeInvalidPort(t *testing.T) { privKey, _ := generatePrivateKey(rsakey, 128) j := &jws{privKey: privKey.(*rsa.PrivateKey)} clientChallenge := challenge{Type: "http-01", Token: "http2"} - solver := &httpChallenge{jws: j, validate: stubValidate, optPort: "123456"} + solver := &httpChallenge{jws: j, validate: stubValidate, port: "123456"} if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { - t.Error("Solve error: got %v, want error", err) + t.Errorf("Solve error: got %v, want error", err) } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) } diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_test.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_test.go new file mode 100644 index 00000000..33a48a33 --- /dev/null +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/http_test.go @@ -0,0 +1,100 @@ +package acme + +import ( + "net/http" + "net/http/httptest" + "strings" + "testing" +) + +func TestHTTPHeadUserAgent(t *testing.T) { + var ua, method string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ua = r.Header.Get("User-Agent") + method = r.Method + })) + defer ts.Close() + + _, err := httpHead(ts.URL) + if err != nil { + t.Fatal(err) + } + + if method != "HEAD" { + t.Errorf("Expected method to be HEAD, got %s", method) + } + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) + } +} + +func TestHTTPGetUserAgent(t *testing.T) { + var ua, method string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ua = r.Header.Get("User-Agent") + method = r.Method + })) + defer ts.Close() + + res, err := httpGet(ts.URL) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + if method != "GET" { + t.Errorf("Expected method to be GET, got %s", method) + } + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) + } +} + +func TestHTTPPostUserAgent(t *testing.T) { + var ua, method string + ts := httptest.NewServer(http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ua = r.Header.Get("User-Agent") + method = r.Method + })) + defer ts.Close() + + res, err := httpPost(ts.URL, "text/plain", strings.NewReader("falalalala")) + if err != nil { + t.Fatal(err) + } + res.Body.Close() + + if method != "POST" { + t.Errorf("Expected method to be POST, got %s", method) + } + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected User-Agent to contain '%s', got: '%s'", ourUserAgent, ua) + } +} + +func TestUserAgent(t *testing.T) { + ua := userAgent() + + if !strings.Contains(ua, defaultGoUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua) + } + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua) + } + if strings.HasSuffix(ua, " ") { + t.Errorf("UA should not have trailing spaces; got '%s'", ua) + } + + // customize the UA by appending a value + UserAgent = "MyApp/1.2.3" + ua = userAgent() + if !strings.Contains(ua, defaultGoUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", defaultGoUserAgent, ua) + } + if !strings.Contains(ua, ourUserAgent) { + t.Errorf("Expected UA to contain %s, got '%s'", ourUserAgent, ua) + } + if !strings.Contains(ua, UserAgent) { + t.Errorf("Expected custom UA to contain %s, got '%s'", UserAgent, ua) + } +} diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/jws.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/jws.go index 10ddb1c3..01f6f0f5 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/jws.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/jws.go @@ -35,7 +35,7 @@ func (j *jws) post(url string, content []byte) (*http.Response, error) { return nil, err } - resp, err := http.Post(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize()))) + resp, err := httpPost(url, "application/jose+json", bytes.NewBuffer([]byte(signedContent.FullSerialize()))) if err != nil { return nil, err } @@ -71,7 +71,7 @@ func (j *jws) getNonceFromResponse(resp *http.Response) error { } func (j *jws) getNonce() error { - resp, err := http.Head(j.directoryURL) + resp, err := httpHead(j.directoryURL) if err != nil { return err } diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge.go index ad099d54..e2511ad3 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge.go @@ -6,13 +6,15 @@ import ( "crypto/tls" "encoding/hex" "fmt" + "net" "net/http" ) type tlsSNIChallenge struct { jws *jws validate validateFunc - optPort string + iface string + port string } func (t *tlsSNIChallenge) Solve(chlng challenge, domain string) error { @@ -33,15 +35,20 @@ func (t *tlsSNIChallenge) Solve(chlng challenge, domain string) error { } // Allow for CLI port override - port := ":443" - if t.optPort != "" { - port = ":" + t.optPort + port := "443" + if t.port != "" { + port = t.port + } + + iface := "" + if t.iface != "" { + iface = t.iface } tlsConf := new(tls.Config) tlsConf.Certificates = []tls.Certificate{cert} - listener, err := tls.Listen("tcp", port, tlsConf) + listener, err := tls.Listen("tcp", net.JoinHostPort(iface, port), tlsConf) if err != nil { return fmt.Errorf("Could not start HTTPS server for challenge -> %v", err) } diff --git a/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge_test.go b/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge_test.go index 8f3ccbe1..f2350303 100644 --- a/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge_test.go +++ b/Godeps/_workspace/src/github.com/xenolf/lego/acme/tls_sni_challenge_test.go @@ -43,7 +43,7 @@ func TestTLSSNIChallenge(t *testing.T) { return nil } - solver := &tlsSNIChallenge{jws: j, validate: mockValidate, optPort: "23457"} + solver := &tlsSNIChallenge{jws: j, validate: mockValidate, port: "23457"} if err := solver.Solve(clientChallenge, "localhost:23457"); err != nil { t.Errorf("Solve error: got %v, want nil", err) @@ -54,10 +54,10 @@ func TestTLSSNIChallengeInvalidPort(t *testing.T) { privKey, _ := generatePrivateKey(rsakey, 128) j := &jws{privKey: privKey.(*rsa.PrivateKey)} clientChallenge := challenge{Type: "tls-sni-01", Token: "tlssni2"} - solver := &tlsSNIChallenge{jws: j, validate: stubValidate, optPort: "123456"} + solver := &tlsSNIChallenge{jws: j, validate: stubValidate, port: "123456"} if err := solver.Solve(clientChallenge, "localhost:123456"); err == nil { - t.Error("Solve error: got %v, want error", err) + t.Errorf("Solve error: got %v, want error", err) } else if want := "invalid port 123456"; !strings.HasSuffix(err.Error(), want) { t.Errorf("Solve error: got %q, want suffix %q", err.Error(), want) } diff --git a/letsencrypt/letsencrypt.go b/letsencrypt/letsencrypt.go index fb1152b7..ecd141c7 100644 --- a/letsencrypt/letsencrypt.go +++ b/letsencrypt/letsencrypt.go @@ -26,7 +26,7 @@ func Run(dir, domain, email, port string) (*state, error) { client, err := acme.NewClient(URL, &user, KeySize) client.ExcludeChallenges([]string{"tls-sni-01"}) - client.SetHTTPPort(port) + client.SetHTTPAddress(port) if user.Registration == nil { user.Registration, err = client.Register() @@ -123,7 +123,7 @@ func (s *state) setOCSP(ocsp []byte) { } func (s *state) obtain() error { - cert, errors := s.client.ObtainCertificate([]string{s.domain}, true) + cert, errors := s.client.ObtainCertificate([]string{s.domain}, true, nil) if err := errors[s.domain]; err != nil { if _, ok := err.(acme.TOSError); ok { err := s.client.AgreeToTOS() diff --git a/server/server.go b/server/server.go index 07734e05..1adf235b 100644 --- a/server/server.go +++ b/server/server.go @@ -87,7 +87,7 @@ func startHTTP() { go http.ListenAndServe(":80", http.HandlerFunc(letsEncryptProxy)) } - letsEncrypt, err := letsencrypt.Run(dir, domain, email, lePort) + letsEncrypt, err := letsencrypt.Run(dir, domain, email, ":"+lePort) if err != nil { log.Fatal(err) }