diff --git a/Makefile b/Makefile index 574970109..69efa7e8d 100644 --- a/Makefile +++ b/Makefile @@ -34,7 +34,8 @@ lint: cyclo: @echo "Running $@:" - @GO15VENDOREXPERIMENT=1 gocyclo -over 25 . + @GO15VENDOREXPERIMENT=1 gocyclo -over 25 *.go + @GO15VENDOREXPERIMENT=1 gocyclo -over 25 pkg build: getdeps verifiers @echo "Installing minio:" @@ -42,6 +43,7 @@ build: getdeps verifiers test: build @echo "Running all testing:" + @GO15VENDOREXPERIMENT=1 go test $(GOFLAGS) . @GO15VENDOREXPERIMENT=1 go test $(GOFLAGS) github.com/minio/minio/pkg... gomake-all: build diff --git a/buildscripts/checkdeps.sh b/buildscripts/checkdeps.sh index 7d0e4edfb..7fc543304 100644 --- a/buildscripts/checkdeps.sh +++ b/buildscripts/checkdeps.sh @@ -21,7 +21,7 @@ _init() { ## Minimum required versions for build dependencies GCC_VERSION="4.0" - CLANG_VERSION="3.5" + LLVM_VERSION="7.0.0" YASM_VERSION="1.2.0" GIT_VERSION="1.0" GO_VERSION="1.5.1" @@ -173,7 +173,7 @@ is_supported_arch() { check_deps() { check_version "$(env go version 2>/dev/null | sed 's/^.* go\([0-9.]*\).*$/\1/')" "${GO_VERSION}" if [ $? -ge 2 ]; then - MISSING="${MISSING} golang(1.5)" + MISSING="${MISSING} golang(${GO_VERSION})" fi check_version "$(env git --version 2>/dev/null | sed -e 's/^.* \([0-9.\].*\).*$/\1/' -e 's/^\([0-9.\]*\).*/\1/g')" "${GIT_VERSION}" @@ -185,13 +185,13 @@ check_deps() { "Linux") check_version "$(env gcc --version 2>/dev/null | sed 's/^.* \([0-9.]*\).*$/\1/' | head -1)" "${GCC_VERSION}" if [ $? -ge 2 ]; then - MISSING="${MISSING} build-essential" + MISSING="${MISSING} build-essential(${GCC_VERSION})" fi ;; "Darwin") - check_version "$(env gcc --version 2>/dev/null | sed 's/^.* \([0-9.]*\).*$/\1/' | head -1)" "${CLANG_VERSION}" + check_version "$(env gcc --version 2>/dev/null | awk '{print $4}' | head -1)" "${LLVM_VERSION}" if [ $? 
-ge 2 ]; then - MISSING="${MISSING} xcode-cli" + MISSING="${MISSING} xcode-cli(${LLVM_VERSION})" fi ;; "*") @@ -200,7 +200,7 @@ check_deps() { check_version "$(env yasm --version 2>/dev/null | sed 's/^.* \([0-9.]*\).*$/\1/' | head -1)" "${YASM_VERSION}" if [ $? -ge 2 ]; then - MISSING="${MISSING} yasm(1.2.0)" + MISSING="${MISSING} yasm(${YASM_VERSION})" fi } diff --git a/controller-main.go b/controller-main.go index ee2fad55c..bd458bacc 100644 --- a/controller-main.go +++ b/controller-main.go @@ -16,10 +16,7 @@ package main -import ( - "github.com/minio/cli" - "github.com/minio/minio/pkg/controller" -) +import "github.com/minio/cli" var controllerCmd = cli.Command{ Name: "controller", @@ -43,6 +40,6 @@ func controllerMain(c *cli.Context) { cli.ShowCommandHelpAndExit(c, "controller", 1) } - err := controller.Start() + err := StartController() errorIf(err.Trace(), "Failed to start minio controller.", nil) } diff --git a/pkg/controller/router.go b/controller-router.go similarity index 80% rename from pkg/controller/router.go rename to controller-router.go index c50b958ab..b3a6e13b9 100644 --- a/pkg/controller/router.go +++ b/controller-router.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package controller +package main import ( "net/http" @@ -22,17 +22,16 @@ import ( router "github.com/gorilla/mux" jsonrpc "github.com/gorilla/rpc/v2" "github.com/gorilla/rpc/v2/json" - "github.com/minio/minio/pkg/controller/rpc" ) // getRPCHandler rpc handler func getRPCHandler() http.Handler { s := jsonrpc.NewServer() s.RegisterCodec(json.NewCodec(), "application/json") - s.RegisterService(new(rpc.VersionService), "Version") - s.RegisterService(new(rpc.DonutService), "Donut") - s.RegisterService(new(rpc.AuthService), "Auth") - s.RegisterService(new(rpc.ServerService), "Server") + s.RegisterService(new(VersionService), "Version") + s.RegisterService(new(DonutService), "Donut") + s.RegisterService(new(AuthService), "Auth") + s.RegisterService(new(ServerService), "Server") // Add new RPC services here return registerRPC(router.NewRouter(), s) } diff --git a/pkg/controller/rpc/auth.go b/controller-rpc-auth.go similarity index 99% rename from pkg/controller/rpc/auth.go rename to controller-rpc-auth.go index 730370f91..fbe3be2f0 100644 --- a/pkg/controller/rpc/auth.go +++ b/controller-rpc-auth.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package rpc +package main import ( "errors" diff --git a/pkg/controller/rpc/donut.go b/controller-rpc-donut.go similarity index 99% rename from pkg/controller/rpc/donut.go rename to controller-rpc-donut.go index 1ace43007..bb0eac123 100644 --- a/pkg/controller/rpc/donut.go +++ b/controller-rpc-donut.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package rpc +package main import ( "net/http" diff --git a/pkg/controller/rpc/server.go b/controller-rpc-server.go similarity index 99% rename from pkg/controller/rpc/server.go rename to controller-rpc-server.go index b0ede6273..0d618c774 100644 --- a/pkg/controller/rpc/server.go +++ b/controller-rpc-server.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package rpc +package main import ( "net/http" diff --git a/pkg/controller/rpc/version.go b/controller-rpc-version.go similarity index 90% rename from pkg/controller/rpc/version.go rename to controller-rpc-version.go index d617c5591..0a070ff00 100644 --- a/pkg/controller/rpc/version.go +++ b/controller-rpc-version.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package rpc +package main import ( "net/http" @@ -38,8 +38,7 @@ type VersionReply struct { // Get version func (v *VersionService) Get(r *http.Request, args *VersionArgs, reply *VersionReply) error { reply.Version = "0.0.1" - //TODO: Better approach needed here to pass global states like version. --ab. - // reply.BuildDate = version.Version + reply.BuildDate = minioVersion reply.Architecture = runtime.GOARCH reply.OperatingSystem = runtime.GOOS return nil diff --git a/pkg/controller/server.go b/controller.go similarity index 94% rename from pkg/controller/server.go rename to controller.go index 65d922430..39dda56d9 100644 --- a/pkg/controller/server.go +++ b/controller.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package controller +package main import ( "fmt" @@ -54,8 +54,8 @@ func getRPCServer(rpcHandler http.Handler) (*http.Server, *probe.Error) { return httpServer, nil } -// Start starts a controller -func Start() *probe.Error { +// StartController starts a minio controller +func StartController() *probe.Error { rpcServer, err := getRPCServer(getRPCHandler()) if err != nil { return err.Trace() diff --git a/pkg/controller/rpc_test.go b/controller_rpc_test.go similarity index 56% rename from pkg/controller/rpc_test.go rename to controller_rpc_test.go index b6a47b1f1..ffb412852 100644 --- a/pkg/controller/rpc_test.go +++ b/controller_rpc_test.go @@ -14,31 +14,26 @@ * limitations under the License. 
*/ -package controller +package main import ( "io/ioutil" "net/http" "net/http/httptest" "os" - "testing" - jsonrpc "github.com/gorilla/rpc/v2/json" + "github.com/gorilla/rpc/v2/json" "github.com/minio/minio/pkg/auth" - "github.com/minio/minio/pkg/controller/rpc" . "gopkg.in/check.v1" ) -// Hook up gocheck into the "go test" runner. -func Test(t *testing.T) { TestingT(t) } +type ControllerRPCSuite struct{} -type MySuite struct{} - -var _ = Suite(&MySuite{}) +var _ = Suite(&ControllerRPCSuite{}) var testRPCServer *httptest.Server -func (s *MySuite) SetUpSuite(c *C) { +func (s *ControllerRPCSuite) SetUpSuite(c *C) { root, err := ioutil.TempDir(os.TempDir(), "api-") c.Assert(err, IsNil) auth.SetAuthConfigPath(root) @@ -46,136 +41,136 @@ func (s *MySuite) SetUpSuite(c *C) { testRPCServer = httptest.NewServer(getRPCHandler()) } -func (s *MySuite) TearDownSuite(c *C) { +func (s *ControllerRPCSuite) TearDownSuite(c *C) { testRPCServer.Close() } -func (s *MySuite) TestMemStats(c *C) { - op := rpc.Operation{ +func (s *ControllerRPCSuite) TestMemStats(c *C) { + op := rpcOperation{ Method: "Server.MemStats", - Request: rpc.ServerArgs{}, + Request: ServerArgs{}, } - req, err := rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err := newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err := req.Do() c.Assert(err, IsNil) c.Assert(resp.StatusCode, Equals, http.StatusOK) - var reply rpc.MemStatsReply - c.Assert(jsonrpc.DecodeClientResponse(resp.Body, &reply), IsNil) + var reply MemStatsReply + c.Assert(json.DecodeClientResponse(resp.Body, &reply), IsNil) resp.Body.Close() - c.Assert(reply, Not(DeepEquals), rpc.MemStatsReply{}) + c.Assert(reply, Not(DeepEquals), MemStatsReply{}) } -func (s *MySuite) TestSysInfo(c *C) { - op := rpc.Operation{ +func (s *ControllerRPCSuite) TestSysInfo(c *C) { + op := rpcOperation{ Method: "Server.SysInfo", - Request: 
rpc.ServerArgs{}, + Request: ServerArgs{}, } - req, err := rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err := newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err := req.Do() c.Assert(err, IsNil) c.Assert(resp.StatusCode, Equals, http.StatusOK) - var reply rpc.SysInfoReply - c.Assert(jsonrpc.DecodeClientResponse(resp.Body, &reply), IsNil) + var reply SysInfoReply + c.Assert(json.DecodeClientResponse(resp.Body, &reply), IsNil) resp.Body.Close() - c.Assert(reply, Not(DeepEquals), rpc.SysInfoReply{}) + c.Assert(reply, Not(DeepEquals), SysInfoReply{}) } -func (s *MySuite) TestServerList(c *C) { - op := rpc.Operation{ +func (s *ControllerRPCSuite) TestServerList(c *C) { + op := rpcOperation{ Method: "Server.List", - Request: rpc.ServerArgs{}, + Request: ServerArgs{}, } - req, err := rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err := newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err := req.Do() c.Assert(err, IsNil) c.Assert(resp.StatusCode, Equals, http.StatusOK) - var reply rpc.ServerListReply - c.Assert(jsonrpc.DecodeClientResponse(resp.Body, &reply), IsNil) + var reply ServerListReply + c.Assert(json.DecodeClientResponse(resp.Body, &reply), IsNil) resp.Body.Close() - c.Assert(reply, Not(DeepEquals), rpc.ServerListReply{}) + c.Assert(reply, Not(DeepEquals), ServerListReply{}) } -func (s *MySuite) TestServerAdd(c *C) { - op := rpc.Operation{ +func (s *ControllerRPCSuite) TestServerAdd(c *C) { + op := rpcOperation{ Method: "Server.Add", - Request: rpc.ServerArgs{MinioServers: []rpc.MinioServer{}}, + Request: ServerArgs{MinioServers: []MinioServer{}}, } - req, err := rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err := newRPCRequest(testRPCServer.URL+"/rpc", op, 
http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err := req.Do() c.Assert(err, IsNil) c.Assert(resp.StatusCode, Equals, http.StatusOK) - var reply rpc.ServerAddReply - c.Assert(jsonrpc.DecodeClientResponse(resp.Body, &reply), IsNil) + var reply ServerAddReply + c.Assert(json.DecodeClientResponse(resp.Body, &reply), IsNil) resp.Body.Close() - c.Assert(reply, Not(DeepEquals), rpc.ServerAddReply{ServersAdded: []rpc.MinioServer{}}) + c.Assert(reply, Not(DeepEquals), ServerAddReply{ServersAdded: []MinioServer{}}) } -func (s *MySuite) TestAuth(c *C) { - op := rpc.Operation{ +func (s *ControllerRPCSuite) TestAuth(c *C) { + op := rpcOperation{ Method: "Auth.Generate", - Request: rpc.AuthArgs{User: "newuser"}, + Request: AuthArgs{User: "newuser"}, } - req, err := rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err := newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err := req.Do() c.Assert(err, IsNil) c.Assert(resp.StatusCode, Equals, http.StatusOK) - var reply rpc.AuthReply - c.Assert(jsonrpc.DecodeClientResponse(resp.Body, &reply), IsNil) + var reply AuthReply + c.Assert(json.DecodeClientResponse(resp.Body, &reply), IsNil) resp.Body.Close() - c.Assert(reply, Not(DeepEquals), rpc.AuthReply{}) + c.Assert(reply, Not(DeepEquals), AuthReply{}) c.Assert(len(reply.AccessKeyID), Equals, 20) c.Assert(len(reply.SecretAccessKey), Equals, 40) c.Assert(len(reply.Name), Not(Equals), 0) - op = rpc.Operation{ + op = rpcOperation{ Method: "Auth.Fetch", - Request: rpc.AuthArgs{User: "newuser"}, + Request: AuthArgs{User: "newuser"}, } - req, err = rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err = newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err = req.Do() 
c.Assert(err, IsNil) c.Assert(resp.StatusCode, Equals, http.StatusOK) - var newReply rpc.AuthReply - c.Assert(jsonrpc.DecodeClientResponse(resp.Body, &newReply), IsNil) + var newReply AuthReply + c.Assert(json.DecodeClientResponse(resp.Body, &newReply), IsNil) resp.Body.Close() - c.Assert(newReply, Not(DeepEquals), rpc.AuthReply{}) + c.Assert(newReply, Not(DeepEquals), AuthReply{}) c.Assert(reply.AccessKeyID, Equals, newReply.AccessKeyID) c.Assert(reply.SecretAccessKey, Equals, newReply.SecretAccessKey) c.Assert(len(reply.Name), Not(Equals), 0) - op = rpc.Operation{ + op = rpcOperation{ Method: "Auth.Reset", - Request: rpc.AuthArgs{User: "newuser"}, + Request: AuthArgs{User: "newuser"}, } - req, err = rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err = newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err = req.Do() c.Assert(err, IsNil) c.Assert(resp.StatusCode, Equals, http.StatusOK) - var resetReply rpc.AuthReply - c.Assert(jsonrpc.DecodeClientResponse(resp.Body, &resetReply), IsNil) + var resetReply AuthReply + c.Assert(json.DecodeClientResponse(resp.Body, &resetReply), IsNil) resp.Body.Close() - c.Assert(newReply, Not(DeepEquals), rpc.AuthReply{}) + c.Assert(newReply, Not(DeepEquals), AuthReply{}) c.Assert(reply.AccessKeyID, Not(Equals), resetReply.AccessKeyID) c.Assert(reply.SecretAccessKey, Not(Equals), resetReply.SecretAccessKey) c.Assert(len(reply.Name), Not(Equals), 0) @@ -183,11 +178,11 @@ func (s *MySuite) TestAuth(c *C) { // these operations should fail /// generating access for existing user fails - op = rpc.Operation{ + op = rpcOperation{ Method: "Auth.Generate", - Request: rpc.AuthArgs{User: "newuser"}, + Request: AuthArgs{User: "newuser"}, } - req, err = rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err = newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, 
IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err = req.Do() @@ -195,11 +190,11 @@ func (s *MySuite) TestAuth(c *C) { c.Assert(resp.StatusCode, Equals, http.StatusBadRequest) /// null user provided invalid - op = rpc.Operation{ + op = rpcOperation{ Method: "Auth.Generate", - Request: rpc.AuthArgs{User: ""}, + Request: AuthArgs{User: ""}, } - req, err = rpc.NewRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) + req, err = newRPCRequest(testRPCServer.URL+"/rpc", op, http.DefaultTransport) c.Assert(err, IsNil) c.Assert(req.Get("Content-Type"), Equals, "application/json") resp, err = req.Do() diff --git a/logger_test.go b/logger_test.go index 36600ec89..a9dc7f3c5 100644 --- a/logger_test.go +++ b/logger_test.go @@ -27,7 +27,11 @@ import ( . "gopkg.in/check.v1" ) -func (s *TestSuite) TestLogger(c *C) { +type LoggerSuite struct{} + +var _ = Suite(&LoggerSuite{}) + +func (s *LoggerSuite) TestLogger(c *C) { var buffer bytes.Buffer var fields logrus.Fields log.Out = &buffer diff --git a/minio_test.go b/minio_test.go index 11fdcf121..c13d2fa67 100644 --- a/minio_test.go +++ b/minio_test.go @@ -1,5 +1,5 @@ /* - * Minio Cloud Storage (C) 2015 Minio, Inc. + * Minio Cloud Storage, (C) 2015 Minio, Inc. * * Licensed under the Apache License, Version 2.0 (the "License"); * you may not use this file except in compliance with the License. @@ -22,14 +22,5 @@ import ( . "gopkg.in/check.v1" ) +// Hook up gocheck into the "go test" runner. func Test(t *testing.T) { TestingT(t) } - -type TestSuite struct{} - -var _ = Suite(&TestSuite{}) - -func (s *TestSuite) SetUpSuite(c *C) { -} - -func (s *TestSuite) TearDownSuite(c *C) { -} diff --git a/pkg/controller/rpc/request.go b/rpc-request.go similarity index 77% rename from pkg/controller/rpc/request.go rename to rpc-request.go index 224b664d3..c8a3113d2 100644 --- a/pkg/controller/rpc/request.go +++ b/rpc-request.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package rpc +package main import ( "bytes" @@ -24,20 +24,20 @@ import ( "github.com/minio/minio/pkg/probe" ) -// Operation RPC operation -type Operation struct { +// rpcOperation RPC operation +type rpcOperation struct { Method string Request interface{} } -// Request rpc client request -type Request struct { +// rpcRequest rpc client request +type rpcRequest struct { req *http.Request transport http.RoundTripper } -// NewRequest initiate a new client RPC request -func NewRequest(url string, op Operation, transport http.RoundTripper) (*Request, *probe.Error) { +// newRPCRequest initiate a new client RPC request +func newRPCRequest(url string, op rpcOperation, transport http.RoundTripper) (*rpcRequest, *probe.Error) { params, err := json.EncodeClientRequest(op.Method, op.Request) if err != nil { return nil, probe.NewError(err) @@ -46,7 +46,7 @@ func NewRequest(url string, op Operation, transport http.RoundTripper) (*Request if err != nil { return nil, probe.NewError(err) } - rpcReq := &Request{} + rpcReq := &rpcRequest{} rpcReq.req = req rpcReq.req.Header.Set("Content-Type", "application/json") if transport == nil { @@ -57,7 +57,7 @@ func NewRequest(url string, op Operation, transport http.RoundTripper) (*Request } // Do - make a http connection -func (r Request) Do() (*http.Response, *probe.Error) { +func (r rpcRequest) Do() (*http.Response, *probe.Error) { resp, err := r.transport.RoundTrip(r.req) if err != nil { if err, ok := probe.UnwrapError(err); ok { @@ -69,11 +69,11 @@ func (r Request) Do() (*http.Response, *probe.Error) { } // Get - get value of requested header -func (r Request) Get(key string) string { +func (r rpcRequest) Get(key string) string { return r.req.Header.Get(key) } // Set - set value of a header key -func (r *Request) Set(key, value string) { +func (r *rpcRequest) Set(key, value string) { r.req.Header.Set(key, value) } diff --git a/pkg/server/api/acl.go b/server-api-acl.go similarity index 99% rename from pkg/server/api/acl.go rename to 
server-api-acl.go index 3e215638e..fcc3b37ce 100644 --- a/pkg/server/api/acl.go +++ b/server-api-acl.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import "net/http" diff --git a/pkg/server/api/bucket-handlers.go b/server-api-bucket-handlers.go similarity index 94% rename from pkg/server/api/bucket-handlers.go rename to server-api-bucket-handlers.go index c5bb45877..4843f70a9 100644 --- a/pkg/server/api/bucket-handlers.go +++ b/server-api-bucket-handlers.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import ( "net/http" @@ -24,7 +24,7 @@ import ( "github.com/minio/minio/pkg/probe" ) -func (api Minio) isValidOp(w http.ResponseWriter, req *http.Request, acceptsContentType contentType) bool { +func (api MinioAPI) isValidOp(w http.ResponseWriter, req *http.Request, acceptsContentType contentType) bool { vars := mux.Vars(req) bucket := vars["bucket"] @@ -67,10 +67,10 @@ func (api Minio) isValidOp(w http.ResponseWriter, req *http.Request, acceptsCont // using the Initiate Multipart Upload request, but has not yet been completed or aborted. // This operation returns at most 1,000 multipart uploads in the response. // -func (api Minio) ListMultipartUploadsHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) ListMultipartUploadsHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until ticket master gives us a go @@ -133,10 +133,10 @@ func (api Minio) ListMultipartUploadsHandler(w http.ResponseWriter, req *http.Re // of the objects in a bucket. You can use the request parameters as selection // criteria to return a subset of the objects in a bucket. 
// -func (api Minio) ListObjectsHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) ListObjectsHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -208,10 +208,10 @@ func (api Minio) ListObjectsHandler(w http.ResponseWriter, req *http.Request) { // ----------- // This implementation of the GET operation returns a list of all buckets // owned by the authenticated sender of the request. -func (api Minio) ListBucketsHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) ListBucketsHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -260,10 +260,10 @@ func (api Minio) ListBucketsHandler(w http.ResponseWriter, req *http.Request) { // PutBucketHandler - PUT Bucket // ---------- // This implementation of the PUT operation creates a new bucket for authenticated request -func (api Minio) PutBucketHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) PutBucketHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -338,10 +338,10 @@ func (api Minio) PutBucketHandler(w http.ResponseWriter, req *http.Request) { // PutBucketACLHandler - PUT Bucket ACL // ---------- // This implementation of the PUT operation modifies the bucketACL for authenticated request -func (api Minio) PutBucketACLHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) PutBucketACLHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until 
Ticket master gives us a go @@ -395,10 +395,10 @@ func (api Minio) PutBucketACLHandler(w http.ResponseWriter, req *http.Request) { // The operation returns a 200 OK if the bucket exists and you // have permission to access it. Otherwise, the operation might // return responses such as 404 Not Found and 403 Forbidden. -func (api Minio) HeadBucketHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) HeadBucketHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go diff --git a/pkg/server/api/contenttype.go b/server-api-contenttype.go similarity index 98% rename from pkg/server/api/contenttype.go rename to server-api-contenttype.go index 954f430fd..49e95d00e 100644 --- a/pkg/server/api/contenttype.go +++ b/server-api-contenttype.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import "net/http" diff --git a/pkg/server/api/definitions.go b/server-api-definitions.go similarity index 98% rename from pkg/server/api/definitions.go rename to server-api-definitions.go index ecbf8b310..727b10c22 100644 --- a/pkg/server/api/definitions.go +++ b/server-api-definitions.go @@ -14,12 +14,12 @@ * limitations under the License. */ -package api +package main import "encoding/xml" -// Config - http server config -type Config struct { +// APIConfig - http server config +type APIConfig struct { Address string TLS bool CertFile string diff --git a/pkg/server/api/errors.go b/server-api-errors.go similarity index 95% rename from pkg/server/api/errors.go rename to server-api-errors.go index df81d198f..cfea0441b 100644 --- a/pkg/server/api/errors.go +++ b/server-api-errors.go @@ -14,22 +14,22 @@ * limitations under the License. 
*/ -package api +package main import ( "encoding/xml" "net/http" ) -// Error structure -type Error struct { +// APIError structure +type APIError struct { Code string Description string HTTPStatusCode int } -// ErrorResponse - error response format -type ErrorResponse struct { +// APIErrorResponse - error response format +type APIErrorResponse struct { XMLName xml.Name `xml:"Error" json:"-"` Code string Message string @@ -77,8 +77,8 @@ const ( NotAcceptable = iota + 30 ) -// Error code to Error structure map -var errorCodeResponse = map[int]Error{ +// APIError code to Error structure map +var errorCodeResponse = map[int]APIError{ InvalidMaxUploads: { Code: "InvalidArgument", Description: "Argument maxUploads must be an integer between 0 and 2147483647.", @@ -232,14 +232,14 @@ var errorCodeResponse = map[int]Error{ } // errorCodeError provides errorCode to Error. It returns empty if the code provided is unknown -func getErrorCode(code int) Error { +func getErrorCode(code int) APIError { return errorCodeResponse[code] } // getErrorResponse gets in standard error and resource value and // provides a encodable populated response values -func getErrorResponse(err Error, resource string) ErrorResponse { - var data = ErrorResponse{} +func getErrorResponse(err APIError, resource string) APIErrorResponse { + var data = APIErrorResponse{} data.Code = err.Code data.Message = err.Description if resource != "" { diff --git a/pkg/server/api/generic-handlers.go b/server-api-generic-handlers.go similarity index 99% rename from pkg/server/api/generic-handlers.go rename to server-api-generic-handlers.go index dca2e4996..0f06848bd 100644 --- a/pkg/server/api/generic-handlers.go +++ b/server-api-generic-handlers.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package api +package main import ( "errors" @@ -44,10 +44,6 @@ type resourceHandler struct { handler http.Handler } -const ( - iso8601Format = "20060102T150405Z" -) - func parseDate(req *http.Request) (time.Time, error) { amzDate := req.Header.Get(http.CanonicalHeaderKey("x-amz-date")) switch { diff --git a/pkg/server/api/headers.go b/server-api-headers.go similarity index 93% rename from pkg/server/api/headers.go rename to server-api-headers.go index f220cfd4b..c05e8685e 100644 --- a/pkg/server/api/headers.go +++ b/server-api-headers.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import ( "bytes" @@ -22,6 +22,7 @@ import ( "encoding/json" "encoding/xml" "net/http" + "runtime" "strconv" "github.com/minio/minio/pkg/donut" @@ -51,9 +52,7 @@ func generateRequestID() []byte { func setCommonHeaders(w http.ResponseWriter, acceptsType string, contentLength int) { // set unique request ID for each reply w.Header().Set("X-Amz-Request-Id", string(generateRequestID())) - - // TODO: Modularity comes in the way of passing global state like "version". A better approach needed here. -ab - // w.Header().Set("Server", ("Minio/" + version + " (" + runtime.GOOS + ";" + runtime.GOARCH + ")")) + w.Header().Set("Server", ("Minio/" + minioReleaseTag + " (" + runtime.GOOS + ";" + runtime.GOARCH + ")")) w.Header().Set("Accept-Ranges", "bytes") w.Header().Set("Content-Type", acceptsType) diff --git a/pkg/server/api/range.go b/server-api-httprange.go similarity index 99% rename from pkg/server/api/range.go rename to server-api-httprange.go index e90dd41ec..87dd06c80 100644 --- a/pkg/server/api/range.go +++ b/server-api-httprange.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package api +package main import ( "errors" diff --git a/pkg/server/api/logging-handlers.go b/server-api-logging-handlers.go similarity index 99% rename from pkg/server/api/logging-handlers.go rename to server-api-logging-handlers.go index 7b7b365f5..2394f9014 100644 --- a/pkg/server/api/logging-handlers.go +++ b/server-api-logging-handlers.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import ( "bytes" diff --git a/pkg/server/api/object-handlers.go b/server-api-object-handlers.go similarity index 95% rename from pkg/server/api/object-handlers.go rename to server-api-object-handlers.go index fa7fee135..ee4270b2b 100644 --- a/pkg/server/api/object-handlers.go +++ b/server-api-object-handlers.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import ( "net/http" @@ -33,10 +33,10 @@ const ( // ---------- // This implementation of the GET operation retrieves object. To use GET, // you must have READ access to the object. -func (api Minio) GetObjectHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) GetObjectHandler(w http.ResponseWriter, req *http.Request) { // ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -100,10 +100,10 @@ func (api Minio) GetObjectHandler(w http.ResponseWriter, req *http.Request) { // HeadObjectHandler - HEAD Object // ----------- // The HEAD operation retrieves metadata from an object without returning the object itself. 
-func (api Minio) HeadObjectHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) HeadObjectHandler(w http.ResponseWriter, req *http.Request) { // ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -157,10 +157,10 @@ func (api Minio) HeadObjectHandler(w http.ResponseWriter, req *http.Request) { // PutObjectHandler - PUT Object // ---------- // This implementation of the PUT operation adds an object to a bucket. -func (api Minio) PutObjectHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) PutObjectHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -259,10 +259,10 @@ func (api Minio) PutObjectHandler(w http.ResponseWriter, req *http.Request) { /// Multipart API // NewMultipartUploadHandler - New multipart upload -func (api Minio) NewMultipartUploadHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) NewMultipartUploadHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -317,10 +317,10 @@ func (api Minio) NewMultipartUploadHandler(w http.ResponseWriter, req *http.Requ } // PutObjectPartHandler - Upload part -func (api Minio) PutObjectPartHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) PutObjectPartHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -417,10 +417,10 @@ func (api Minio) PutObjectPartHandler(w http.ResponseWriter, req *http.Request) } // AbortMultipartUploadHandler - Abort multipart 
upload -func (api Minio) AbortMultipartUploadHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) AbortMultipartUploadHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -467,10 +467,10 @@ func (api Minio) AbortMultipartUploadHandler(w http.ResponseWriter, req *http.Re } // ListObjectPartsHandler - List object parts -func (api Minio) ListObjectPartsHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) ListObjectPartsHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -532,10 +532,10 @@ func (api Minio) ListObjectPartsHandler(w http.ResponseWriter, req *http.Request } // CompleteMultipartUploadHandler - Complete multipart upload -func (api Minio) CompleteMultipartUploadHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) CompleteMultipartUploadHandler(w http.ResponseWriter, req *http.Request) { // Ticket master block { - op := Operation{} + op := APIOperation{} op.ProceedCh = make(chan struct{}) api.OP <- op // block until Ticket master gives us a go @@ -597,13 +597,13 @@ func (api Minio) CompleteMultipartUploadHandler(w http.ResponseWriter, req *http /// Delete API // DeleteBucketHandler - Delete bucket -func (api Minio) DeleteBucketHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) DeleteBucketHandler(w http.ResponseWriter, req *http.Request) { error := getErrorCode(MethodNotAllowed) w.WriteHeader(error.HTTPStatusCode) } // DeleteObjectHandler - Delete object -func (api Minio) DeleteObjectHandler(w http.ResponseWriter, req *http.Request) { +func (api MinioAPI) DeleteObjectHandler(w http.ResponseWriter, req *http.Request) { error := getErrorCode(MethodNotAllowed) 
w.WriteHeader(error.HTTPStatusCode) } diff --git a/pkg/server/api/resources.go b/server-api-resources.go similarity index 99% rename from pkg/server/api/resources.go rename to server-api-resources.go index fb5aea252..4b03d775c 100644 --- a/pkg/server/api/resources.go +++ b/server-api-resources.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import ( "net/url" diff --git a/pkg/server/api/response.go b/server-api-response.go similarity index 99% rename from pkg/server/api/response.go rename to server-api-response.go index bb6aa5f11..afc826623 100644 --- a/pkg/server/api/response.go +++ b/server-api-response.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import ( "net/http" diff --git a/pkg/server/api/signature.go b/server-api-signature.go similarity index 97% rename from pkg/server/api/signature.go rename to server-api-signature.go index 6d83897b4..3d1b4b0b8 100644 --- a/pkg/server/api/signature.go +++ b/server-api-signature.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import ( "errors" @@ -28,6 +28,8 @@ import ( const ( authHeaderPrefix = "AWS4-HMAC-SHA256" + iso8601Format = "20060102T150405Z" + yyyymmdd = "20060102" ) // getCredentialsFromAuth parse credentials tag from authorization value diff --git a/pkg/server/api/typed-errors.go b/server-api-typed-errors.go similarity index 99% rename from pkg/server/api/typed-errors.go rename to server-api-typed-errors.go index d02fad6b5..ea5b554c6 100644 --- a/pkg/server/api/typed-errors.go +++ b/server-api-typed-errors.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package api +package main import "errors" diff --git a/pkg/server/api/utils.go b/server-api-utils.go similarity index 99% rename from pkg/server/api/utils.go rename to server-api-utils.go index 8bc26e4dc..f57743ac6 100644 --- a/pkg/server/api/utils.go +++ b/server-api-utils.go @@ -14,7 +14,7 @@ * limitations under the License. 
*/ -package api +package main import ( "encoding/base64" diff --git a/pkg/server/api/api.go b/server-api.go similarity index 69% rename from pkg/server/api/api.go rename to server-api.go index 5cd2cd967..ad50c6a7f 100644 --- a/pkg/server/api/api.go +++ b/server-api.go @@ -14,30 +14,30 @@ * limitations under the License. */ -package api +package main import "github.com/minio/minio/pkg/donut" -// Operation container for individual operations read by Ticket Master -type Operation struct { +// APIOperation container for individual operations read by Ticket Master +type APIOperation struct { ProceedCh chan struct{} } -// Minio container for API and also carries OP (operation) channel -type Minio struct { - OP chan Operation +// MinioAPI container for API and also carries OP (operation) channel +type MinioAPI struct { + OP chan APIOperation Donut donut.Interface } -// New instantiate a new minio API -func New() Minio { +// NewAPI instantiate a new minio API +func NewAPI() MinioAPI { // ignore errors for now d, err := donut.New() if err != nil { panic(err) } - return Minio{ - OP: make(chan Operation), + return MinioAPI{ + OP: make(chan APIOperation), Donut: d, } } diff --git a/server-main.go b/server-main.go index da7150192..8d8947dc0 100644 --- a/server-main.go +++ b/server-main.go @@ -16,11 +16,7 @@ package main -import ( - "github.com/minio/cli" - "github.com/minio/minio/pkg/server" - "github.com/minio/minio/pkg/server/api" -) +import "github.com/minio/cli" var serverCmd = cli.Command{ Name: "server", @@ -39,14 +35,14 @@ EXAMPLES: `, } -func getServerConfig(c *cli.Context) api.Config { +func getServerConfig(c *cli.Context) APIConfig { certFile := c.GlobalString("cert") keyFile := c.GlobalString("key") if (certFile != "" && keyFile == "") || (certFile == "" && keyFile != "") { Fatalln("Both certificate and key are required to enable https.") } tls := (certFile != "" && keyFile != "") - return api.Config{ + return APIConfig{ Address: c.GlobalString("address"), TLS: tls, 
CertFile: certFile, @@ -61,6 +57,6 @@ func serverMain(c *cli.Context) { } apiServerConfig := getServerConfig(c) - err := server.Start(apiServerConfig) + err := StartServer(apiServerConfig) errorIf(err.Trace(), "Failed to start the minio server.", nil) } diff --git a/pkg/server/router.go b/server-router.go similarity index 85% rename from pkg/server/router.go rename to server-router.go index db00f210b..92e6b118c 100644 --- a/pkg/server/router.go +++ b/server-router.go @@ -14,17 +14,16 @@ * limitations under the License. */ -package server +package main import ( "net/http" router "github.com/gorilla/mux" - "github.com/minio/minio/pkg/server/api" ) // registerAPI - register all the object API handlers to their respective paths -func registerAPI(mux *router.Router, a api.Minio) { +func registerAPI(mux *router.Router, a MinioAPI) { mux.HandleFunc("/", a.ListBucketsHandler).Methods("GET") mux.HandleFunc("/{bucket}", a.ListObjectsHandler).Methods("GET") mux.HandleFunc("/{bucket}", a.PutBucketHandler).Methods("PUT") @@ -45,7 +44,7 @@ func registerAPI(mux *router.Router, a api.Minio) { mux.HandleFunc("/{bucket}/{object:.*}", a.DeleteObjectHandler).Methods("DELETE") } -func registerCustomMiddleware(mux *router.Router, mwHandlers ...api.MiddlewareHandler) http.Handler { +func registerCustomMiddleware(mux *router.Router, mwHandlers ...MiddlewareHandler) http.Handler { var f http.Handler f = mux for _, mw := range mwHandlers { @@ -55,18 +54,18 @@ func registerCustomMiddleware(mux *router.Router, mwHandlers ...api.MiddlewareHa } // getAPIHandler api handler -func getAPIHandler(conf api.Config) (http.Handler, api.Minio) { - var mwHandlers = []api.MiddlewareHandler{ - api.ValidContentTypeHandler, - api.TimeValidityHandler, - api.IgnoreResourcesHandler, - api.ValidateAuthHeaderHandler, +func getAPIHandler(conf APIConfig) (http.Handler, MinioAPI) { + var mwHandlers = []MiddlewareHandler{ + ValidContentTypeHandler, + TimeValidityHandler, + IgnoreResourcesHandler, + 
ValidateAuthHeaderHandler, // api.LoggingHandler, // Disabled logging until we bring in external logging support - api.CorsHandler, + CorsHandler, } mux := router.NewRouter() - minioAPI := api.New() + minioAPI := NewAPI() registerAPI(mux, minioAPI) apiHandler := registerCustomMiddleware(mux, mwHandlers...) return apiHandler, minioAPI diff --git a/pkg/server/signature-v4_test.go b/server-signature-v4_test.go similarity index 95% rename from pkg/server/signature-v4_test.go rename to server-signature-v4_test.go index d3842f6d1..9de6df881 100644 --- a/pkg/server/signature-v4_test.go +++ b/server-signature-v4_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package server +package main import ( "bytes" @@ -28,20 +28,8 @@ import ( "regexp" "sort" "strings" - "testing" "time" "unicode/utf8" - - . "gopkg.in/check.v1" -) - -// Hook up gocheck into the "go test" runner. -func Test(t *testing.T) { TestingT(t) } - -const ( - authHeader = "AWS4-HMAC-SHA256" - iso8601Format = "20060102T150405Z" - yyyymmdd = "20060102" ) /// @@ -254,7 +242,7 @@ func (s *MyAPISignatureV4Suite) newRequest(method, urlStr string, contentLength "aws4_request", }, "/") - stringToSign := authHeader + "\n" + t.Format(iso8601Format) + "\n" + stringToSign := authHeaderPrefix + "\n" + t.Format(iso8601Format) + "\n" stringToSign = stringToSign + scope + "\n" stringToSign = stringToSign + hex.EncodeToString(sum256([]byte(canonicalRequest))) @@ -267,7 +255,7 @@ func (s *MyAPISignatureV4Suite) newRequest(method, urlStr string, contentLength // final Authorization header parts := []string{ - authHeader + " Credential=" + s.accessKeyID + "/" + scope, + authHeaderPrefix + " Credential=" + s.accessKeyID + "/" + scope, "SignedHeaders=" + signedHeaders, "Signature=" + signature, } diff --git a/pkg/server/server.go b/server.go similarity index 89% rename from pkg/server/server.go rename to server.go index 9dc699ae7..a8ae6605a 100644 --- a/pkg/server/server.go +++ b/server.go @@ -14,7 +14,7 @@ * 
limitations under the License. */ -package server +package main import ( "crypto/tls" @@ -26,11 +26,10 @@ import ( "github.com/minio/minio/pkg/minhttp" "github.com/minio/minio/pkg/probe" - "github.com/minio/minio/pkg/server/api" ) // getAPI server instance -func getAPIServer(conf api.Config, apiHandler http.Handler) (*http.Server, *probe.Error) { +func getAPIServer(conf APIConfig, apiHandler http.Handler) (*http.Server, *probe.Error) { // Minio server config httpServer := &http.Server{ Addr: conf.Address, @@ -84,7 +83,7 @@ func getAPIServer(conf api.Config, apiHandler http.Handler) (*http.Server, *prob } // Start ticket master -func startTM(a api.Minio) { +func startTM(a MinioAPI) { for { for op := range a.OP { op.ProceedCh <- struct{}{} @@ -92,8 +91,8 @@ func startTM(a api.Minio) { } } -// Start starts a s3 compatible cloud storage server -func Start(conf api.Config) *probe.Error { +// StartServer starts an s3 compatible cloud storage server +func StartServer(conf APIConfig) *probe.Error { apiHandler, minioAPI := getAPIHandler(conf) apiServer, err := getAPIServer(conf, apiHandler) if err != nil { diff --git a/pkg/server/api_donut_cache_test.go b/server_donut_cache_test.go similarity index 98% rename from pkg/server/api_donut_cache_test.go rename to server_donut_cache_test.go index 024a515f7..cb2354dc2 100644 --- a/pkg/server/api_donut_cache_test.go +++ b/server_donut_cache_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package server +package main import ( "bytes" @@ -28,7 +28,6 @@ import ( "net/http/httptest" "github.com/minio/minio/pkg/donut" - "github.com/minio/minio/pkg/server/api" . 
"gopkg.in/check.v1" ) @@ -52,7 +51,7 @@ func (s *MyAPIDonutCacheSuite) SetUpSuite(c *C) { perr := donut.SaveConfig(conf) c.Assert(perr, IsNil) - httpHandler, minioAPI := getAPIHandler(api.Config{RateLimit: 16}) + httpHandler, minioAPI := getAPIHandler(APIConfig{RateLimit: 16}) go startTM(minioAPI) testAPIDonutCacheServer = httptest.NewServer(httpHandler) } @@ -319,7 +318,7 @@ func (s *MyAPIDonutCacheSuite) TestListBuckets(c *C) { c.Assert(err, IsNil) c.Assert(response.StatusCode, Equals, http.StatusOK) - var results api.ListBucketsResponse + var results ListBucketsResponse decoder := xml.NewDecoder(response.Body) err = decoder.Decode(&results) c.Assert(err, IsNil) @@ -676,7 +675,7 @@ func (s *MyAPIDonutCacheSuite) TestObjectMultipartAbort(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -721,7 +720,7 @@ func (s *MyAPIDonutCacheSuite) TestBucketMultipartList(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -750,7 +749,7 @@ func (s *MyAPIDonutCacheSuite) TestBucketMultipartList(c *C) { c.Assert(response3.StatusCode, Equals, http.StatusOK) decoder = xml.NewDecoder(response3.Body) - newResponse3 := &api.ListMultipartUploadsResponse{} + newResponse3 := &ListMultipartUploadsResponse{} err = decoder.Decode(newResponse3) c.Assert(err, IsNil) c.Assert(newResponse3.Bucket, Equals, "bucketmultipartlist") @@ -772,7 +771,7 @@ func (s *MyAPIDonutCacheSuite) TestObjectMultipartList(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse 
:= &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -826,7 +825,7 @@ func (s *MyAPIDonutCacheSuite) TestObjectMultipart(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -888,7 +887,7 @@ func (s *MyAPIDonutCacheSuite) TestObjectMultipart(c *C) { func verifyError(c *C, response *http.Response, code, description string, statusCode int) { data, err := ioutil.ReadAll(response.Body) c.Assert(err, IsNil) - errorResponse := api.ErrorResponse{} + errorResponse := APIErrorResponse{} err = xml.Unmarshal(data, &errorResponse) c.Assert(err, IsNil) c.Assert(errorResponse.Code, Equals, code) diff --git a/pkg/server/api_donut_test.go b/server_donut_test.go similarity index 98% rename from pkg/server/api_donut_test.go rename to server_donut_test.go index 570eb1029..bd9cd2fa1 100644 --- a/pkg/server/api_donut_test.go +++ b/server_donut_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package server +package main import ( "bytes" @@ -29,7 +29,6 @@ import ( "net/http/httptest" "github.com/minio/minio/pkg/donut" - "github.com/minio/minio/pkg/server/api" . 
"gopkg.in/check.v1" ) @@ -71,7 +70,7 @@ func (s *MyAPIDonutSuite) SetUpSuite(c *C) { perr := donut.SaveConfig(conf) c.Assert(perr, IsNil) - httpHandler, minioAPI := getAPIHandler(api.Config{RateLimit: 16}) + httpHandler, minioAPI := getAPIHandler(APIConfig{RateLimit: 16}) go startTM(minioAPI) testAPIDonutServer = httptest.NewServer(httpHandler) } @@ -338,7 +337,7 @@ func (s *MyAPIDonutSuite) TestListBuckets(c *C) { c.Assert(err, IsNil) c.Assert(response.StatusCode, Equals, http.StatusOK) - var results api.ListBucketsResponse + var results ListBucketsResponse decoder := xml.NewDecoder(response.Body) err = decoder.Decode(&results) c.Assert(err, IsNil) @@ -696,7 +695,7 @@ func (s *MyAPIDonutSuite) TestObjectMultipartAbort(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -741,7 +740,7 @@ func (s *MyAPIDonutSuite) TestBucketMultipartList(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -770,7 +769,7 @@ func (s *MyAPIDonutSuite) TestBucketMultipartList(c *C) { c.Assert(response3.StatusCode, Equals, http.StatusOK) decoder = xml.NewDecoder(response3.Body) - newResponse3 := &api.ListMultipartUploadsResponse{} + newResponse3 := &ListMultipartUploadsResponse{} err = decoder.Decode(newResponse3) c.Assert(err, IsNil) c.Assert(newResponse3.Bucket, Equals, "bucketmultipartlist") @@ -792,7 +791,7 @@ func (s *MyAPIDonutSuite) TestObjectMultipartList(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := 
&InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -846,7 +845,7 @@ func (s *MyAPIDonutSuite) TestObjectMultipart(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) diff --git a/pkg/server/api_signature_v4_test.go b/server_signature_v4_test.go similarity index 98% rename from pkg/server/api_signature_v4_test.go rename to server_signature_v4_test.go index 61258b367..2f39ed7d2 100644 --- a/pkg/server/api_signature_v4_test.go +++ b/server_signature_v4_test.go @@ -14,7 +14,7 @@ * limitations under the License. */ -package server +package main import ( "bytes" @@ -30,7 +30,6 @@ import ( "github.com/minio/minio/pkg/auth" "github.com/minio/minio/pkg/donut" - "github.com/minio/minio/pkg/server/api" . "gopkg.in/check.v1" ) @@ -79,7 +78,7 @@ func (s *MyAPISignatureV4Suite) SetUpSuite(c *C) { perr = auth.SaveConfig(authConf) c.Assert(perr, IsNil) - httpHandler, minioAPI := getAPIHandler(api.Config{RateLimit: 16}) + httpHandler, minioAPI := getAPIHandler(APIConfig{RateLimit: 16}) go startTM(minioAPI) testSignatureV4Server = httptest.NewServer(httpHandler) } @@ -347,7 +346,7 @@ func (s *MyAPISignatureV4Suite) TestListBuckets(c *C) { c.Assert(err, IsNil) c.Assert(response.StatusCode, Equals, http.StatusOK) - var results api.ListBucketsResponse + var results ListBucketsResponse decoder := xml.NewDecoder(response.Body) err = decoder.Decode(&results) c.Assert(err, IsNil) @@ -689,7 +688,7 @@ func (s *MyAPISignatureV4Suite) TestObjectMultipartAbort(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -736,7 +735,7 @@ func 
(s *MyAPISignatureV4Suite) TestBucketMultipartList(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -767,7 +766,7 @@ func (s *MyAPISignatureV4Suite) TestBucketMultipartList(c *C) { c.Assert(response3.StatusCode, Equals, http.StatusOK) decoder = xml.NewDecoder(response3.Body) - newResponse3 := &api.ListMultipartUploadsResponse{} + newResponse3 := &ListMultipartUploadsResponse{} err = decoder.Decode(newResponse3) c.Assert(err, IsNil) c.Assert(newResponse3.Bucket, Equals, "bucketmultipartlist") @@ -789,7 +788,7 @@ func (s *MyAPISignatureV4Suite) TestObjectMultipartList(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) @@ -845,7 +844,7 @@ func (s *MyAPISignatureV4Suite) TestObjectMultipart(c *C) { c.Assert(response.StatusCode, Equals, http.StatusOK) decoder := xml.NewDecoder(response.Body) - newResponse := &api.InitiateMultipartUploadResponse{} + newResponse := &InitiateMultipartUploadResponse{} err = decoder.Decode(newResponse) c.Assert(err, IsNil) diff --git a/vendor.json b/vendor.json index 2f343f2ba..5a096f541 100755 --- a/vendor.json +++ b/vendor.json @@ -99,6 +99,34 @@ "local": "vendor/gopkg.in/check.v1", "revision": "11d3bc7aa68e238947792f30573146a3231fc0f1", "revisionTime": "2015-07-29T10:04:31+02:00" + }, + { + "canonical": "gopkg.in/mgo.v2", + "comment": "", + "local": "vendor/gopkg.in/mgo.v2", + "revision": "f4923a569136442e900b8cf5c1a706c0a8b0883c", + "revisionTime": "2015-08-21T12:30:02-03:00" + }, + { + "canonical": "gopkg.in/mgo.v2/bson", + "comment": "", + "local": "vendor/gopkg.in/mgo.v2/bson", + "revision": 
"f4923a569136442e900b8cf5c1a706c0a8b0883c", + "revisionTime": "2015-08-21T12:30:02-03:00" + }, + { + "canonical": "gopkg.in/mgo.v2/internal/sasl", + "comment": "", + "local": "vendor/gopkg.in/mgo.v2/internal/sasl", + "revision": "f4923a569136442e900b8cf5c1a706c0a8b0883c", + "revisionTime": "2015-08-21T12:30:02-03:00" + }, + { + "canonical": "gopkg.in/mgo.v2/internal/scram", + "comment": "", + "local": "vendor/gopkg.in/mgo.v2/internal/scram", + "revision": "f4923a569136442e900b8cf5c1a706c0a8b0883c", + "revisionTime": "2015-08-21T12:30:02-03:00" } ] } \ No newline at end of file diff --git a/vendor/gopkg.in/mgo.v2/LICENSE b/vendor/gopkg.in/mgo.v2/LICENSE new file mode 100644 index 000000000..770c7672b --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/LICENSE @@ -0,0 +1,25 @@ +mgo - MongoDB driver for Go + +Copyright (c) 2010-2013 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/mgo.v2/Makefile b/vendor/gopkg.in/mgo.v2/Makefile new file mode 100644 index 000000000..51bee7322 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/Makefile @@ -0,0 +1,5 @@ +startdb: + @testdb/setup.sh start + +stopdb: + @testdb/setup.sh stop diff --git a/vendor/gopkg.in/mgo.v2/README.md b/vendor/gopkg.in/mgo.v2/README.md new file mode 100644 index 000000000..f4e452c04 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/README.md @@ -0,0 +1,4 @@ +The MongoDB driver for Go +------------------------- + +Please go to [http://labix.org/mgo](http://labix.org/mgo) for all project details. diff --git a/vendor/gopkg.in/mgo.v2/auth.go b/vendor/gopkg.in/mgo.v2/auth.go new file mode 100644 index 000000000..dc26e52f5 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/auth.go @@ -0,0 +1,467 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "crypto/sha1" + "encoding/hex" + "errors" + "fmt" + "sync" + + "gopkg.in/mgo.v2/bson" + "gopkg.in/mgo.v2/internal/scram" +) + +type authCmd struct { + Authenticate int + + Nonce string + User string + Key string +} + +type startSaslCmd struct { + StartSASL int `bson:"startSasl"` +} + +type authResult struct { + ErrMsg string + Ok bool +} + +type getNonceCmd struct { + GetNonce int +} + +type getNonceResult struct { + Nonce string + Err string "$err" + Code int +} + +type logoutCmd struct { + Logout int +} + +type saslCmd struct { + Start int `bson:"saslStart,omitempty"` + Continue int `bson:"saslContinue,omitempty"` + ConversationId int `bson:"conversationId,omitempty"` + Mechanism string `bson:"mechanism,omitempty"` + Payload []byte +} + +type saslResult struct { + Ok bool `bson:"ok"` + NotOk bool `bson:"code"` // Server <= 2.3.2 returns ok=1 & code>0 on errors (WTF?) 
+ Done bool + + ConversationId int `bson:"conversationId"` + Payload []byte + ErrMsg string +} + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +func (socket *mongoSocket) getNonce() (nonce string, err error) { + socket.Lock() + for socket.cachedNonce == "" && socket.dead == nil { + debugf("Socket %p to %s: waiting for nonce", socket, socket.addr) + socket.gotNonce.Wait() + } + if socket.cachedNonce == "mongos" { + socket.Unlock() + return "", errors.New("Can't authenticate with mongos; see http://j.mp/mongos-auth") + } + debugf("Socket %p to %s: got nonce", socket, socket.addr) + nonce, err = socket.cachedNonce, socket.dead + socket.cachedNonce = "" + socket.Unlock() + if err != nil { + nonce = "" + } + return +} + +func (socket *mongoSocket) resetNonce() { + debugf("Socket %p to %s: requesting a new nonce", socket, socket.addr) + op := &queryOp{} + op.query = &getNonceCmd{GetNonce: 1} + op.collection = "admin.$cmd" + op.limit = -1 + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + if err != nil { + socket.kill(errors.New("getNonce: "+err.Error()), true) + return + } + result := &getNonceResult{} + err = bson.Unmarshal(docData, &result) + if err != nil { + socket.kill(errors.New("Failed to unmarshal nonce: "+err.Error()), true) + return + } + debugf("Socket %p to %s: nonce unmarshalled: %#v", socket, socket.addr, result) + if result.Code == 13390 { + // mongos doesn't yet support auth (see http://j.mp/mongos-auth) + result.Nonce = "mongos" + } else if result.Nonce == "" { + var msg string + if result.Err != "" { + msg = fmt.Sprintf("Got an empty nonce: %s (%d)", result.Err, result.Code) + } else { + msg = "Got an empty nonce" + } + socket.kill(errors.New(msg), true) + return + } + socket.Lock() + if socket.cachedNonce != "" { + socket.Unlock() + panic("resetNonce: nonce already cached") + } + socket.cachedNonce = result.Nonce + socket.gotNonce.Signal() + socket.Unlock() 
+ } + err := socket.Query(op) + if err != nil { + socket.kill(errors.New("resetNonce: "+err.Error()), true) + } +} + +func (socket *mongoSocket) Login(cred Credential) error { + socket.Lock() + if cred.Mechanism == "" && socket.serverInfo.MaxWireVersion >= 3 { + cred.Mechanism = "SCRAM-SHA-1" + } + for _, sockCred := range socket.creds { + if sockCred == cred { + debugf("Socket %p to %s: login: db=%q user=%q (already logged in)", socket, socket.addr, cred.Source, cred.Username) + socket.Unlock() + return nil + } + } + if socket.dropLogout(cred) { + debugf("Socket %p to %s: login: db=%q user=%q (cached)", socket, socket.addr, cred.Source, cred.Username) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + } + socket.Unlock() + + debugf("Socket %p to %s: login: db=%q user=%q", socket, socket.addr, cred.Source, cred.Username) + + var err error + switch cred.Mechanism { + case "", "MONGODB-CR", "MONGO-CR": // Name changed to MONGODB-CR in SERVER-8501. + err = socket.loginClassic(cred) + case "PLAIN": + err = socket.loginPlain(cred) + case "MONGODB-X509": + err = socket.loginX509(cred) + default: + // Try SASL for everything else, if it is available. + err = socket.loginSASL(cred) + } + + if err != nil { + debugf("Socket %p to %s: login error: %s", socket, socket.addr, err) + } else { + debugf("Socket %p to %s: login successful", socket, socket.addr) + } + return err +} + +func (socket *mongoSocket) loginClassic(cred Credential) error { + // Note that this only works properly because this function is + // synchronous, which means the nonce won't get reset while we're + // using it and any other login requests will block waiting for a + // new nonce provided in the defer call below. 
+ nonce, err := socket.getNonce() + if err != nil { + return err + } + defer socket.resetNonce() + + psum := md5.New() + psum.Write([]byte(cred.Username + ":mongo:" + cred.Password)) + + ksum := md5.New() + ksum.Write([]byte(nonce + cred.Username)) + ksum.Write([]byte(hex.EncodeToString(psum.Sum(nil)))) + + key := hex.EncodeToString(ksum.Sum(nil)) + + cmd := authCmd{Authenticate: 1, User: cred.Username, Nonce: nonce, Key: key} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +type authX509Cmd struct { + Authenticate int + User string + Mechanism string +} + +func (socket *mongoSocket) loginX509(cred Credential) error { + cmd := authX509Cmd{Authenticate: 1, User: cred.Username, Mechanism: "MONGODB-X509"} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +func (socket *mongoSocket) loginPlain(cred Credential) error { + cmd := saslCmd{Start: 1, Mechanism: "PLAIN", Payload: []byte("\x00" + cred.Username + "\x00" + cred.Password)} + res := authResult{} + return socket.loginRun(cred.Source, &cmd, &res, func() error { + if !res.Ok { + return errors.New(res.ErrMsg) + } + socket.Lock() + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + socket.Unlock() + return nil + }) +} + +func (socket *mongoSocket) loginSASL(cred Credential) error { + var sasl saslStepper + var err error + if cred.Mechanism == "SCRAM-SHA-1" { + // SCRAM is handled without external libraries. 
+ sasl = saslNewScram(cred) + } else if len(cred.ServiceHost) > 0 { + sasl, err = saslNew(cred, cred.ServiceHost) + } else { + sasl, err = saslNew(cred, socket.Server().Addr) + } + if err != nil { + return err + } + defer sasl.Close() + + // The goal of this logic is to carry a locked socket until the + // local SASL step confirms the auth is valid; the socket needs to be + // locked so that concurrent action doesn't leave the socket in an + // auth state that doesn't reflect the operations that took place. + // As a simple case, imagine inverting login=>logout to logout=>login. + // + // The logic below works because the lock func isn't called concurrently. + locked := false + lock := func(b bool) { + if locked != b { + locked = b + if b { + socket.Lock() + } else { + socket.Unlock() + } + } + } + + lock(true) + defer lock(false) + + start := 1 + cmd := saslCmd{} + res := saslResult{} + for { + payload, done, err := sasl.Step(res.Payload) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + lock(false) + + cmd = saslCmd{ + Start: start, + Continue: 1 - start, + ConversationId: res.ConversationId, + Mechanism: cred.Mechanism, + Payload: payload, + } + start = 0 + err = socket.loginRun(cred.Source, &cmd, &res, func() error { + // See the comment on lock for why this is necessary. 
+ lock(true) + if !res.Ok || res.NotOk { + return fmt.Errorf("server returned error on SASL authentication step: %s", res.ErrMsg) + } + return nil + }) + if err != nil { + return err + } + if done && res.Done { + socket.dropAuth(cred.Source) + socket.creds = append(socket.creds, cred) + break + } + } + + return nil +} + +func saslNewScram(cred Credential) *saslScram { + credsum := md5.New() + credsum.Write([]byte(cred.Username + ":mongo:" + cred.Password)) + client := scram.NewClient(sha1.New, cred.Username, hex.EncodeToString(credsum.Sum(nil))) + return &saslScram{cred: cred, client: client} +} + +type saslScram struct { + cred Credential + client *scram.Client +} + +func (s *saslScram) Close() {} + +func (s *saslScram) Step(serverData []byte) (clientData []byte, done bool, err error) { + more := s.client.Step(serverData) + return s.client.Out(), !more, s.client.Err() +} + +func (socket *mongoSocket) loginRun(db string, query, result interface{}, f func() error) error { + var mutex sync.Mutex + var replyErr error + mutex.Lock() + + op := queryOp{} + op.query = query + op.collection = db + ".$cmd" + op.limit = -1 + op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + defer mutex.Unlock() + + if err != nil { + replyErr = err + return + } + + err = bson.Unmarshal(docData, result) + if err != nil { + replyErr = err + } else { + // Must handle this within the read loop for the socket, so + // that concurrent login requests are properly ordered. + replyErr = f() + } + } + + err := socket.Query(&op) + if err != nil { + return err + } + mutex.Lock() // Wait. 
+ return replyErr +} + +func (socket *mongoSocket) Logout(db string) { + socket.Lock() + cred, found := socket.dropAuth(db) + if found { + debugf("Socket %p to %s: logout: db=%q (flagged)", socket, socket.addr, db) + socket.logout = append(socket.logout, cred) + } + socket.Unlock() +} + +func (socket *mongoSocket) LogoutAll() { + socket.Lock() + if l := len(socket.creds); l > 0 { + debugf("Socket %p to %s: logout all (flagged %d)", socket, socket.addr, l) + socket.logout = append(socket.logout, socket.creds...) + socket.creds = socket.creds[0:0] + } + socket.Unlock() +} + +func (socket *mongoSocket) flushLogout() (ops []interface{}) { + socket.Lock() + if l := len(socket.logout); l > 0 { + debugf("Socket %p to %s: logout all (flushing %d)", socket, socket.addr, l) + for i := 0; i != l; i++ { + op := queryOp{} + op.query = &logoutCmd{1} + op.collection = socket.logout[i].Source + ".$cmd" + op.limit = -1 + ops = append(ops, &op) + } + socket.logout = socket.logout[0:0] + } + socket.Unlock() + return +} + +func (socket *mongoSocket) dropAuth(db string) (cred Credential, found bool) { + for i, sockCred := range socket.creds { + if sockCred.Source == db { + copy(socket.creds[i:], socket.creds[i+1:]) + socket.creds = socket.creds[:len(socket.creds)-1] + return sockCred, true + } + } + return cred, false +} + +func (socket *mongoSocket) dropLogout(cred Credential) (found bool) { + for i, sockCred := range socket.logout { + if sockCred == cred { + copy(socket.logout[i:], socket.logout[i+1:]) + socket.logout = socket.logout[:len(socket.logout)-1] + return true + } + } + return false +} diff --git a/vendor/gopkg.in/mgo.v2/auth_test.go b/vendor/gopkg.in/mgo.v2/auth_test.go new file mode 100644 index 000000000..eb42ab1d6 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/auth_test.go @@ -0,0 +1,1180 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. 
+// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "crypto/tls" + "flag" + "fmt" + "io/ioutil" + "net" + "net/url" + "os" + "runtime" + "sync" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" +) + +func (s *S) TestAuthLoginDatabase(c *C) { + // Test both with a normal database and with an authenticated shard. 
+ for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + admindb := session.DB("admin") + + err = admindb.Login("root", "wrong") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + } +} + +func (s *S) TestAuthLoginSession(c *C) { + // Test both with a normal database and with an authenticated shard. + for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + cred := mgo.Credential{ + Username: "root", + Password: "wrong", + } + err = session.Login(&cred) + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + + cred.Password = "rapadura" + + err = session.Login(&cred) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + } +} + +func (s *S) TestAuthLoginLogout(c *C) { + // Test both with a normal database and with an authenticated shard. + for _, addr := range []string{"localhost:40002", "localhost:40203"} { + session, err := mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + admindb.Logout() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + // Must have dropped auth from the session too. 
+ session = session.Copy() + defer session.Close() + + coll = session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + } +} + +func (s *S) TestAuthLoginLogoutAll(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session.LogoutAll() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") + + // Must have dropped auth from the session too. + session = session.Copy() + defer session.Close() + + coll = session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") +} + +func (s *S) TestAuthUpsertUserErrors(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + + err = mydb.UpsertUser(&mgo.User{}) + c.Assert(err, ErrorMatches, "user has no Username") + + err = mydb.UpsertUser(&mgo.User{Username: "user", Password: "pass", UserSource: "source"}) + c.Assert(err, ErrorMatches, "user has both Password/PasswordHash and UserSource set") + + err = mydb.UpsertUser(&mgo.User{Username: "user", Password: "pass", OtherDBRoles: map[string][]mgo.Role{"db": nil}}) + c.Assert(err, ErrorMatches, "user with OtherDBRoles is only supported in the admin or \\$external databases") +} + +func (s *S) TestAuthUpsertUser(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := 
session.DB("mydb") + + ruser := &mgo.User{ + Username: "myruser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleRead}, + } + rwuser := &mgo.User{ + Username: "myrwuser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleReadWrite}, + } + + err = mydb.UpsertUser(ruser) + c.Assert(err, IsNil) + err = mydb.UpsertUser(rwuser) + c.Assert(err, IsNil) + + err = mydb.Login("myruser", "mypass") + c.Assert(err, IsNil) + + admindb.Logout() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = mydb.Login("myrwuser", "mypass") + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + myotherdb := session.DB("myotherdb") + + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + // Test UserSource. + rwuserother := &mgo.User{ + Username: "myrwuser", + UserSource: "mydb", + Roles: []mgo.Role{mgo.RoleRead}, + } + + err = myotherdb.UpsertUser(rwuserother) + if s.versionAtLeast(2, 6) { + c.Assert(err, ErrorMatches, `MongoDB 2.6\+ does not support the UserSource setting`) + return + } + c.Assert(err, IsNil) + + admindb.Logout() + + // Test indirection via UserSource: we can't write to it, because + // the roles for myrwuser are different there. + othercoll := myotherdb.C("myothercoll") + err = othercoll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // Reading works, though. + err = othercoll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + + // Can't login directly into the database using UserSource, though. 
+ err = myotherdb.Login("myrwuser", "mypass") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") +} + +func (s *S) TestAuthUpsertUserOtherDBRoles(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + ruser := &mgo.User{ + Username: "myruser", + Password: "mypass", + OtherDBRoles: map[string][]mgo.Role{"mydb": []mgo.Role{mgo.RoleRead}}, + } + + err = admindb.UpsertUser(ruser) + c.Assert(err, IsNil) + defer admindb.RemoveUser("myruser") + + admindb.Logout() + err = admindb.Login("myruser", "mypass") + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = coll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthUpsertUserUpdates(c *C) { + if !s.versionAtLeast(2, 4) { + c.Skip("UpsertUser only works on 2.4+") + } + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + + // Insert a user that can read. + user := &mgo.User{ + Username: "myruser", + Password: "mypass", + Roles: []mgo.Role{mgo.RoleRead}, + } + err = mydb.UpsertUser(user) + c.Assert(err, IsNil) + + // Now update the user password. + user = &mgo.User{ + Username: "myruser", + Password: "mynewpass", + } + err = mydb.UpsertUser(user) + c.Assert(err, IsNil) + + // Login with the new user. + usession, err := mgo.Dial("myruser:mynewpass@localhost:40002/mydb") + c.Assert(err, IsNil) + defer usession.Close() + + // Can read, but not write. 
+ err = usession.DB("mydb").C("mycoll").Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + err = usession.DB("mydb").C("mycoll").Insert(M{"ok": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // Update the user role. + user = &mgo.User{ + Username: "myruser", + Roles: []mgo.Role{mgo.RoleReadWrite}, + } + err = mydb.UpsertUser(user) + c.Assert(err, IsNil) + + // Dial again to ensure the password hasn't changed. + usession, err = mgo.Dial("myruser:mynewpass@localhost:40002/mydb") + c.Assert(err, IsNil) + defer usession.Close() + + // Now it can write. + err = usession.DB("mydb").C("mycoll").Insert(M{"ok": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthAddUser(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myruser", "mypass", true) + c.Assert(err, IsNil) + err = mydb.AddUser("mywuser", "mypass", false) + c.Assert(err, IsNil) + + err = mydb.Login("myruser", "mypass") + c.Assert(err, IsNil) + + admindb.Logout() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + err = mydb.Login("mywuser", "mypass") + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthAddUserReplaces(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "myoldpass", false) + c.Assert(err, IsNil) + err = mydb.AddUser("myuser", "mynewpass", true) + c.Assert(err, IsNil) + + admindb.Logout() + + err = mydb.Login("myuser", "myoldpass") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + err = 
mydb.Login("myuser", "mynewpass") + c.Assert(err, IsNil) + + // ReadOnly flag was changed too. + err = mydb.C("mycoll").Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") +} + +func (s *S) TestAuthRemoveUser(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "mypass", true) + c.Assert(err, IsNil) + err = mydb.RemoveUser("myuser") + c.Assert(err, IsNil) + err = mydb.RemoveUser("myuser") + c.Assert(err, Equals, mgo.ErrNotFound) + + err = mydb.Login("myuser", "mypass") + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") +} + +func (s *S) TestAuthLoginTwiceDoesNothing(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + oldStats := mgo.GetStats() + + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) +} + +func (s *S) TestAuthLoginLogoutLoginDoesNothing(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + oldStats := mgo.GetStats() + + admindb.Logout() + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) +} + +func (s *S) TestAuthLoginSwitchUser(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + 
c.Assert(err, IsNil) + + err = admindb.Login("reader", "rapadura") + c.Assert(err, IsNil) + + // Can't write. + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // But can read. + result := struct{ N int }{} + err = coll.Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) +} + +func (s *S) TestAuthLoginChangePassword(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "myoldpass", false) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "myoldpass") + c.Assert(err, IsNil) + + err = mydb.AddUser("myuser", "mynewpass", true) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "mynewpass") + c.Assert(err, IsNil) + + admindb.Logout() + + // The second login must be in effect, which means read-only. + err = mydb.C("mycoll").Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") +} + +func (s *S) TestAuthLoginCachingWithSessionRefresh(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session.Refresh() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingWithSessionCopy(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session = session.Copy() + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingWithSessionClone(c *C) { + session, err := mgo.Dial("localhost:40002") + 
 c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session = session.Clone() + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingWithNewSession(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + session = session.New() + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|need to login|not authorized .*") +} + +func (s *S) TestAuthLoginCachingAcrossPool(c *C) { + // Logins are cached even when the connection goes back + // into the pool. + + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + // Add another user to test the logout case at the same time. + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "mypass", false) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Logout root explicitly, to test both cases. + admindb.Logout() + + // Give socket back to pool. + session.Refresh() + + // Brand new session, should use socket from the pool. + other := session.New() + defer other.Close() + + oldStats := mgo.GetStats() + + err = other.DB("admin").Login("root", "rapadura") + c.Assert(err, IsNil) + err = other.DB("mydb").Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Both logins were cached, so no ops. + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) + + // And they actually worked. 
+ err = other.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + + other.DB("admin").Logout() + + err = other.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthLoginCachingAcrossPoolWithLogout(c *C) { + // Now verify that logouts are properly flushed if they + // are not revalidated after leaving the pool. + + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + // Add another user to test the logout case at the same time. + mydb := session.DB("mydb") + err = mydb.AddUser("myuser", "mypass", true) + c.Assert(err, IsNil) + + err = mydb.Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Just some data to query later. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + + // Give socket back to pool. + session.Refresh() + + // Brand new session, should use socket from the pool. + other := session.New() + defer other.Close() + + oldStats := mgo.GetStats() + + err = other.DB("mydb").Login("myuser", "mypass") + c.Assert(err, IsNil) + + // Login was cached, so no ops. + newStats := mgo.GetStats() + c.Assert(newStats.SentOps, Equals, oldStats.SentOps) + + // Can't write, since root has been implicitly logged out + // when the collection went into the pool, and not revalidated. + err = other.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + // But can read due to the revalidated myuser login. + result := struct{ N int }{} + err = other.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) +} + +func (s *S) TestAuthEventual(c *C) { + // Eventual sessions don't keep sockets around, so they are + // an interesting test case. 
+ session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + admindb := session.DB("admin") + err = admindb.Login("root", "rapadura") + c.Assert(err, IsNil) + + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + + var wg sync.WaitGroup + wg.Add(20) + + for i := 0; i != 10; i++ { + go func() { + defer wg.Done() + var result struct{ N int } + err := session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) + }() + } + + for i := 0; i != 10; i++ { + go func() { + defer wg.Done() + err := session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) + }() + } + + wg.Wait() +} + +func (s *S) TestAuthURL(c *C) { + session, err := mgo.Dial("mongodb://root:rapadura@localhost:40002/") + c.Assert(err, IsNil) + defer session.Close() + + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthURLWrongCredentials(c *C) { + session, err := mgo.Dial("mongodb://root:wrong@localhost:40002/") + if session != nil { + session.Close() + } + c.Assert(err, ErrorMatches, "auth fail(s|ed)|.*Authentication failed.") + c.Assert(session, IsNil) +} + +func (s *S) TestAuthURLWithNewSession(c *C) { + // When authentication is in the URL, the new session will + // actually carry it on as well, even if logged out explicitly. + session, err := mgo.Dial("mongodb://root:rapadura@localhost:40002/") + c.Assert(err, IsNil) + defer session.Close() + + session.DB("admin").Logout() + + // Do it twice to ensure it passes the needed data on. 
+ session = session.New() + defer session.Close() + session = session.New() + defer session.Close() + + err = session.DB("mydb").C("mycoll").Insert(M{"n": 1}) + c.Assert(err, IsNil) +} + +func (s *S) TestAuthURLWithDatabase(c *C) { + session, err := mgo.Dial("mongodb://root:rapadura@localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + mydb := session.DB("mydb") + err = mydb.AddUser("myruser", "mypass", true) + c.Assert(err, IsNil) + + // Test once with database, and once with source. + for i := 0; i < 2; i++ { + var url string + if i == 0 { + url = "mongodb://myruser:mypass@localhost:40002/mydb" + } else { + url = "mongodb://myruser:mypass@localhost:40002/admin?authSource=mydb" + } + usession, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer usession.Close() + + ucoll := usession.DB("mydb").C("mycoll") + err = ucoll.FindId(0).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) + err = ucoll.Insert(M{"n": 1}) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + } +} + +func (s *S) TestDefaultDatabase(c *C) { + tests := []struct{ url, db string }{ + {"mongodb://root:rapadura@localhost:40002", "test"}, + {"mongodb://root:rapadura@localhost:40002/admin", "admin"}, + {"mongodb://localhost:40001", "test"}, + {"mongodb://localhost:40001/", "test"}, + {"mongodb://localhost:40001/mydb", "mydb"}, + } + + for _, test := range tests { + session, err := mgo.Dial(test.url) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("test: %#v", test) + c.Assert(session.DB("").Name, Equals, test.db) + + scopy := session.Copy() + c.Check(scopy.DB("").Name, Equals, test.db) + scopy.Close() + } +} + +func (s *S) TestAuthDirect(c *C) { + // Direct connections must work to the master and slaves. 
+ for _, port := range []string{"40031", "40032", "40033"} { + url := fmt.Sprintf("mongodb://root:rapadura@localhost:%s/?connect=direct", port) + session, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + var result struct{} + err = session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + } +} + +func (s *S) TestAuthDirectWithLogin(c *C) { + // Direct connections must work to the master and slaves. + for _, port := range []string{"40031", "40032", "40033"} { + url := fmt.Sprintf("mongodb://localhost:%s/?connect=direct", port) + session, err := mgo.Dial(url) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + session.SetSyncTimeout(3 * time.Second) + + err = session.DB("admin").Login("root", "rapadura") + c.Assert(err, IsNil) + + var result struct{} + err = session.DB("mydb").C("mycoll").Find(nil).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + } +} + +func (s *S) TestAuthScramSha1Cred(c *C) { + if !s.versionAtLeast(2, 7, 7) { + c.Skip("SCRAM-SHA-1 tests depend on 2.7.7") + } + cred := &mgo.Credential{ + Username: "root", + Password: "rapadura", + Mechanism: "SCRAM-SHA-1", + Source: "admin", + } + host := "localhost:40002" + c.Logf("Connecting to %s...", host) + session, err := mgo.Dial(host) + c.Assert(err, IsNil) + defer session.Close() + + mycoll := session.DB("admin").C("mycoll") + + c.Logf("Connected! Testing the need for authentication...") + err = mycoll.Find(nil).One(nil) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + c.Logf("Connected! 
Testing the need for authentication...") + err = mycoll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthScramSha1URL(c *C) { + if !s.versionAtLeast(2, 7, 7) { + c.Skip("SCRAM-SHA-1 tests depend on 2.7.7") + } + host := "localhost:40002" + c.Logf("Connecting to %s...", host) + session, err := mgo.Dial(fmt.Sprintf("root:rapadura@%s?authMechanism=SCRAM-SHA-1", host)) + c.Assert(err, IsNil) + defer session.Close() + + mycoll := session.DB("admin").C("mycoll") + + c.Logf("Connected! Testing the need for authentication...") + err = mycoll.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthX509Cred(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + binfo, err := session.BuildInfo() + c.Assert(err, IsNil) + if binfo.OpenSSLVersion == "" { + c.Skip("server does not support SSL") + } + + clientCertPEM, err := ioutil.ReadFile("testdb/client.pem") + c.Assert(err, IsNil) + + clientCert, err := tls.X509KeyPair(clientCertPEM, clientCertPEM) + c.Assert(err, IsNil) + + tlsConfig := &tls.Config{ + // Isolating tests to client certs, don't care about server validation. 
+ InsecureSkipVerify: true, + Certificates: []tls.Certificate{clientCert}, + } + + var host = "localhost:40003" + c.Logf("Connecting to %s...", host) + session, err = mgo.DialWithInfo(&mgo.DialInfo{ + Addrs: []string{host}, + DialServer: func(addr *mgo.ServerAddr) (net.Conn, error) { + return tls.Dial("tcp", addr.String(), tlsConfig) + }, + }) + c.Assert(err, IsNil) + defer session.Close() + + err = session.Login(&mgo.Credential{Username: "root", Password: "rapadura"}) + c.Assert(err, IsNil) + + // This needs to be kept in sync with client.pem + x509Subject := "CN=localhost,OU=Client,O=MGO,L=MGO,ST=MGO,C=GO" + + externalDB := session.DB("$external") + var x509User mgo.User = mgo.User{ + Username: x509Subject, + OtherDBRoles: map[string][]mgo.Role{"admin": []mgo.Role{mgo.RoleRoot}}, + } + err = externalDB.UpsertUser(&x509User) + c.Assert(err, IsNil) + + session.LogoutAll() + + c.Logf("Connected! Ensuring authentication is required...") + names, err := session.DatabaseNames() + c.Assert(err, ErrorMatches, "not authorized .*") + + cred := &mgo.Credential{ + Username: x509Subject, + Mechanism: "MONGODB-X509", + Source: "$external", + } + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + names, err = session.DatabaseNames() + c.Assert(err, IsNil) + c.Assert(len(names) > 0, Equals, true) +} + +var ( + plainFlag = flag.String("plain", "", "Host to test PLAIN authentication against (depends on custom environment)") + plainUser = "einstein" + plainPass = "password" +) + +func (s *S) TestAuthPlainCred(c *C) { + if *plainFlag == "" { + c.Skip("no -plain") + } + cred := &mgo.Credential{ + Username: plainUser, + Password: plainPass, + Source: "$external", + Mechanism: "PLAIN", + } + c.Logf("Connecting to %s...", *plainFlag) + session, err := mgo.Dial(*plainFlag) + c.Assert(err, IsNil) + defer session.Close() + + records := session.DB("records").C("records") + + c.Logf("Connected! 
Testing the need for authentication...") + err = records.Find(nil).One(nil) + c.Assert(err, ErrorMatches, "unauthorized|not authorized .*") + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + c.Logf("Connected! Testing the need for authentication...") + err = records.Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestAuthPlainURL(c *C) { + if *plainFlag == "" { + c.Skip("no -plain") + } + c.Logf("Connecting to %s...", *plainFlag) + session, err := mgo.Dial(fmt.Sprintf("%s:%s@%s?authMechanism=PLAIN", url.QueryEscape(plainUser), url.QueryEscape(plainPass), *plainFlag)) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Connected! Testing the need for authentication...") + err = session.DB("records").C("records").Find(nil).One(nil) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +var ( + kerberosFlag = flag.Bool("kerberos", false, "Test Kerberos authentication (depends on custom environment)") + kerberosHost = "ldaptest.10gen.cc" + kerberosUser = "drivers@LDAPTEST.10GEN.CC" + + winKerberosPasswordEnv = "MGO_KERBEROS_PASSWORD" +) + +// Kerberos has its own suite because it talks to a remote server +// that is prepared to authenticate against a kerberos deployment. 
+type KerberosSuite struct{} + +var _ = Suite(&KerberosSuite{}) + +func (kerberosSuite *KerberosSuite) SetUpSuite(c *C) { + mgo.SetDebug(true) + mgo.SetStats(true) +} + +func (kerberosSuite *KerberosSuite) TearDownSuite(c *C) { + mgo.SetDebug(false) + mgo.SetStats(false) +} + +func (kerberosSuite *KerberosSuite) SetUpTest(c *C) { + mgo.SetLogger((*cLogger)(c)) + mgo.ResetStats() +} + +func (kerberosSuite *KerberosSuite) TearDownTest(c *C) { + mgo.SetLogger(nil) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosCred(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + cred := &mgo.Credential{ + Username: kerberosUser, + Mechanism: "GSSAPI", + } + windowsAppendPasswordToCredential(cred) + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(kerberosHost) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Connected! Testing the need for authentication...") + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, ErrorMatches, ".*authorized.*") + + c.Logf("Authenticating...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + n, err = session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosURL(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + c.Logf("Connecting to %s...", kerberosHost) + connectUri := url.QueryEscape(kerberosUser) + "@" + kerberosHost + "?authMechanism=GSSAPI" + if runtime.GOOS == "windows" { + connectUri = url.QueryEscape(kerberosUser) + ":" + url.QueryEscape(getWindowsKerberosPassword()) + "@" + kerberosHost + "?authMechanism=GSSAPI" + } + session, err := mgo.Dial(connectUri) + c.Assert(err, IsNil) + defer session.Close() + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosServiceName(c *C) { + if !*kerberosFlag { + c.Skip("no 
-kerberos") + } + + wrongServiceName := "wrong" + rightServiceName := "mongodb" + + cred := &mgo.Credential{ + Username: kerberosUser, + Mechanism: "GSSAPI", + Service: wrongServiceName, + } + windowsAppendPasswordToCredential(cred) + + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(kerberosHost) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Authenticating with incorrect service name...") + err = session.Login(cred) + c.Assert(err, ErrorMatches, ".*@LDAPTEST.10GEN.CC not found.*") + + cred.Service = rightServiceName + c.Logf("Authenticating with correct service name...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +func (kerberosSuite *KerberosSuite) TestAuthKerberosServiceHost(c *C) { + if !*kerberosFlag { + c.Skip("no -kerberos") + } + + wrongServiceHost := "eggs.bacon.tk" + rightServiceHost := kerberosHost + + cred := &mgo.Credential{ + Username: kerberosUser, + Mechanism: "GSSAPI", + ServiceHost: wrongServiceHost, + } + windowsAppendPasswordToCredential(cred) + + c.Logf("Connecting to %s...", kerberosHost) + session, err := mgo.Dial(kerberosHost) + c.Assert(err, IsNil) + defer session.Close() + + c.Logf("Authenticating with incorrect service host...") + err = session.Login(cred) + c.Assert(err, ErrorMatches, ".*@LDAPTEST.10GEN.CC not found.*") + + cred.ServiceHost = rightServiceHost + c.Logf("Authenticating with correct service host...") + err = session.Login(cred) + c.Assert(err, IsNil) + c.Logf("Authenticated!") + + n, err := session.DB("kerberos").C("test").Find(M{}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +// No kinit on SSPI-style Kerberos, so we need to provide a password. 
In order +// to avoid inlining password, require it to be set as an environment variable, +// for instance: `SET MGO_KERBEROS_PASSWORD=this_isnt_the_password` +func getWindowsKerberosPassword() string { + pw := os.Getenv(winKerberosPasswordEnv) + if pw == "" { + panic(fmt.Sprintf("Need to set %v environment variable to run Kerberos tests on Windows", winKerberosPasswordEnv)) + } + return pw +} + +func windowsAppendPasswordToCredential(cred *mgo.Credential) { + if runtime.GOOS == "windows" { + cred.Password = getWindowsKerberosPassword() + } +} diff --git a/vendor/gopkg.in/mgo.v2/bson/LICENSE b/vendor/gopkg.in/mgo.v2/bson/LICENSE new file mode 100644 index 000000000..890326017 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/LICENSE @@ -0,0 +1,25 @@ +BSON library for Go + +Copyright (c) 2010-2012 - Gustavo Niemeyer + +All rights reserved. + +Redistribution and use in source and binary forms, with or without +modification, are permitted provided that the following conditions are met: + +1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. +2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + +THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +(INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +(INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. diff --git a/vendor/gopkg.in/mgo.v2/bson/bson.go b/vendor/gopkg.in/mgo.v2/bson/bson.go new file mode 100644 index 000000000..41816b874 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/bson.go @@ -0,0 +1,705 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Package bson is an implementation of the BSON specification for Go: +// +// http://bsonspec.org +// +// It was created as part of the mgo MongoDB driver for Go, but is standalone +// and may be used on its own without the driver. +package bson + +import ( + "bytes" + "crypto/md5" + "crypto/rand" + "encoding/binary" + "encoding/hex" + "errors" + "fmt" + "io" + "os" + "reflect" + "runtime" + "strings" + "sync" + "sync/atomic" + "time" +) + +// -------------------------------------------------------------------------- +// The public API. + +// A value implementing the bson.Getter interface will have its GetBSON +// method called when the given value has to be marshalled, and the result +// of this method will be marshaled in place of the actual object. +// +// If GetBSON returns return a non-nil error, the marshalling procedure +// will stop and error out with the provided value. +type Getter interface { + GetBSON() (interface{}, error) +} + +// A value implementing the bson.Setter interface will receive the BSON +// value via the SetBSON method during unmarshaling, and the object +// itself will not be changed as usual. +// +// If setting the value works, the method should return nil or alternatively +// bson.SetZero to set the respective field to its zero value (nil for +// pointer types). 
If SetBSON returns a value of type bson.TypeError, the +// BSON value will be omitted from a map or slice being decoded and the +// unmarshalling will continue. If it returns any other non-nil error, the +// unmarshalling procedure will stop and error out with the provided value. +// +// This interface is generally useful in pointer receivers, since the method +// will want to change the receiver. A type field that implements the Setter +// interface doesn't have to be a pointer, though. +// +// Unlike the usual behavior, unmarshalling onto a value that implements a +// Setter interface will NOT reset the value to its zero state. This allows +// the value to decide by itself how to be unmarshalled. +// +// For example: +// +// type MyString string +// +// func (s *MyString) SetBSON(raw bson.Raw) error { +// return raw.Unmarshal(s) +// } +// +type Setter interface { + SetBSON(raw Raw) error +} + +// SetZero may be returned from a SetBSON method to have the value set to +// its respective zero value. When used in pointer values, this will set the +// field to nil rather than to the pre-allocated value. +var SetZero = errors.New("set to zero") + +// M is a convenient alias for a map[string]interface{} map, useful for +// dealing with BSON in a native way. For instance: +// +// bson.M{"a": 1, "b": true} +// +// There's no special handling for this type in addition to what's done anyway +// for an equivalent map type. Elements in the map will be dumped in an +// undefined ordered. See also the bson.D type for an ordered alternative. +type M map[string]interface{} + +// D represents a BSON document containing ordered elements. For example: +// +// bson.D{{"a", 1}, {"b", true}} +// +// In some situations, such as when creating indexes for MongoDB, the order in +// which the elements are defined is important. If the order is not important, +// using a map is generally more comfortable. See bson.M and bson.RawD. 
+type D []DocElem + +// DocElem is an element of the bson.D document representation. +type DocElem struct { + Name string + Value interface{} +} + +// Map returns a map out of the ordered element name/value pairs in d. +func (d D) Map() (m M) { + m = make(M, len(d)) + for _, item := range d { + m[item.Name] = item.Value + } + return m +} + +// The Raw type represents raw unprocessed BSON documents and elements. +// Kind is the kind of element as defined per the BSON specification, and +// Data is the raw unprocessed data for the respective element. +// Using this type it is possible to unmarshal or marshal values partially. +// +// Relevant documentation: +// +// http://bsonspec.org/#/specification +// +type Raw struct { + Kind byte + Data []byte +} + +// RawD represents a BSON document containing raw unprocessed elements. +// This low-level representation may be useful when lazily processing +// documents of uncertain content, or when manipulating the raw content +// documents in general. +type RawD []RawDocElem + +// See the RawD type. +type RawDocElem struct { + Name string + Value Raw +} + +// ObjectId is a unique ID identifying a BSON value. It must be exactly 12 bytes +// long. MongoDB objects by default have such a property set in their "_id" +// property. +// +// http://www.mongodb.org/display/DOCS/Object+IDs +type ObjectId string + +// ObjectIdHex returns an ObjectId from the provided hex representation. +// Calling this function with an invalid hex representation will +// cause a runtime panic. See the IsObjectIdHex function. +func ObjectIdHex(s string) ObjectId { + d, err := hex.DecodeString(s) + if err != nil || len(d) != 12 { + panic(fmt.Sprintf("Invalid input to ObjectIdHex: %q", s)) + } + return ObjectId(d) +} + +// IsObjectIdHex returns whether s is a valid hex representation of +// an ObjectId. See the ObjectIdHex function. 
+func IsObjectIdHex(s string) bool { + if len(s) != 24 { + return false + } + _, err := hex.DecodeString(s) + return err == nil +} + +// objectIdCounter is atomically incremented when generating a new ObjectId +// using NewObjectId() function. It's used as a counter part of an id. +var objectIdCounter uint32 = 0 + +// machineId stores machine id generated once and used in subsequent calls +// to NewObjectId function. +var machineId = readMachineId() + +// readMachineId generates machine id and puts it into the machineId global +// variable. If this function fails to get the hostname, it will cause +// a runtime error. +func readMachineId() []byte { + var sum [3]byte + id := sum[:] + hostname, err1 := os.Hostname() + if err1 != nil { + _, err2 := io.ReadFull(rand.Reader, id) + if err2 != nil { + panic(fmt.Errorf("cannot get hostname: %v; %v", err1, err2)) + } + return id + } + hw := md5.New() + hw.Write([]byte(hostname)) + copy(id, hw.Sum(nil)) + return id +} + +// NewObjectId returns a new unique ObjectId. +func NewObjectId() ObjectId { + var b [12]byte + // Timestamp, 4 bytes, big endian + binary.BigEndian.PutUint32(b[:], uint32(time.Now().Unix())) + // Machine, first 3 bytes of md5(hostname) + b[4] = machineId[0] + b[5] = machineId[1] + b[6] = machineId[2] + // Pid, 2 bytes, specs don't specify endianness, but we use big endian. + pid := os.Getpid() + b[7] = byte(pid >> 8) + b[8] = byte(pid) + // Increment, 3 bytes, big endian + i := atomic.AddUint32(&objectIdCounter, 1) + b[9] = byte(i >> 16) + b[10] = byte(i >> 8) + b[11] = byte(i) + return ObjectId(b[:]) +} + +// NewObjectIdWithTime returns a dummy ObjectId with the timestamp part filled +// with the provided number of seconds from epoch UTC, and all other parts +// filled with zeroes. It's not safe to insert a document with an id generated +// by this method, it is useful only for queries to find documents with ids +// generated before or after the specified timestamp. 
+func NewObjectIdWithTime(t time.Time) ObjectId { + var b [12]byte + binary.BigEndian.PutUint32(b[:4], uint32(t.Unix())) + return ObjectId(string(b[:])) +} + +// String returns a hex string representation of the id. +// Example: ObjectIdHex("4d88e15b60f486e428412dc9"). +func (id ObjectId) String() string { + return fmt.Sprintf(`ObjectIdHex("%x")`, string(id)) +} + +// Hex returns a hex representation of the ObjectId. +func (id ObjectId) Hex() string { + return hex.EncodeToString([]byte(id)) +} + +// MarshalJSON turns a bson.ObjectId into a json.Marshaller. +func (id ObjectId) MarshalJSON() ([]byte, error) { + return []byte(fmt.Sprintf(`"%x"`, string(id))), nil +} + +var nullBytes = []byte("null") + +// UnmarshalJSON turns *bson.ObjectId into a json.Unmarshaller. +func (id *ObjectId) UnmarshalJSON(data []byte) error { + if len(data) == 2 && data[0] == '"' && data[1] == '"' || bytes.Equal(data, nullBytes) { + *id = "" + return nil + } + if len(data) != 26 || data[0] != '"' || data[25] != '"' { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s", string(data))) + } + var buf [12]byte + _, err := hex.Decode(buf[:], data[1:25]) + if err != nil { + return errors.New(fmt.Sprintf("Invalid ObjectId in JSON: %s (%s)", string(data), err)) + } + *id = ObjectId(string(buf[:])) + return nil +} + +// Valid returns true if id is valid. A valid id must contain exactly 12 bytes. +func (id ObjectId) Valid() bool { + return len(id) == 12 +} + +// byteSlice returns byte slice of id from start to end. +// Calling this function with an invalid id will cause a runtime panic. +func (id ObjectId) byteSlice(start, end int) []byte { + if len(id) != 12 { + panic(fmt.Sprintf("Invalid ObjectId: %q", string(id))) + } + return []byte(string(id)[start:end]) +} + +// Time returns the timestamp part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Time() time.Time { + // First 4 bytes of ObjectId is 32-bit big-endian seconds from epoch. 
+ secs := int64(binary.BigEndian.Uint32(id.byteSlice(0, 4))) + return time.Unix(secs, 0) +} + +// Machine returns the 3-byte machine id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Machine() []byte { + return id.byteSlice(4, 7) +} + +// Pid returns the process id part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Pid() uint16 { + return binary.BigEndian.Uint16(id.byteSlice(7, 9)) +} + +// Counter returns the incrementing value part of the id. +// It's a runtime error to call this method with an invalid id. +func (id ObjectId) Counter() int32 { + b := id.byteSlice(9, 12) + // Counter is stored as big-endian 3-byte value + return int32(uint32(b[0])<<16 | uint32(b[1])<<8 | uint32(b[2])) +} + +// The Symbol type is similar to a string and is used in languages with a +// distinct symbol type. +type Symbol string + +// Now returns the current time with millisecond precision. MongoDB stores +// timestamps with the same precision, so a Time returned from this method +// will not change after a roundtrip to the database. That's the only reason +// why this function exists. Using the time.Now function also works fine +// otherwise. +func Now() time.Time { + return time.Unix(0, time.Now().UnixNano()/1e6*1e6) +} + +// MongoTimestamp is a special internal type used by MongoDB that for some +// strange reason has its own datatype defined in BSON. +type MongoTimestamp int64 + +type orderKey int64 + +// MaxKey is a special value that compares higher than all other possible BSON +// values in a MongoDB database. +var MaxKey = orderKey(1<<63 - 1) + +// MinKey is a special value that compares lower than all other possible BSON +// values in a MongoDB database. +var MinKey = orderKey(-1 << 63) + +type undefined struct{} + +// Undefined represents the undefined BSON value. +var Undefined undefined + +// Binary is a representation for non-standard binary values. 
Any kind should +// work, but the following are known as of this writing: +// +// 0x00 - Generic. This is decoded as []byte(data), not Binary{0x00, data}. +// 0x01 - Function (!?) +// 0x02 - Obsolete generic. +// 0x03 - UUID +// 0x05 - MD5 +// 0x80 - User defined. +// +type Binary struct { + Kind byte + Data []byte +} + +// RegEx represents a regular expression. The Options field may contain +// individual characters defining the way in which the pattern should be +// applied, and must be sorted. Valid options as of this writing are 'i' for +// case insensitive matching, 'm' for multi-line matching, 'x' for verbose +// mode, 'l' to make \w, \W, and similar be locale-dependent, 's' for dot-all +// mode (a '.' matches everything), and 'u' to make \w, \W, and similar match +// unicode. The value of the Options parameter is not verified before being +// marshaled into the BSON format. +type RegEx struct { + Pattern string + Options string +} + +// JavaScript is a type that holds JavaScript code. If Scope is non-nil, it +// will be marshaled as a mapping from identifiers to values that may be +// used when evaluating the provided Code. +type JavaScript struct { + Code string + Scope interface{} +} + +// DBPointer refers to a document id in a namespace. +// +// This type is deprecated in the BSON specification and should not be used +// except for backwards compatibility with ancient applications. +type DBPointer struct { + Namespace string + Id ObjectId +} + +const initialBufferSize = 64 + +func handleErr(err *error) { + if r := recover(); r != nil { + if _, ok := r.(runtime.Error); ok { + panic(r) + } else if _, ok := r.(externalPanic); ok { + panic(r) + } else if s, ok := r.(string); ok { + *err = errors.New(s) + } else if e, ok := r.(error); ok { + *err = e + } else { + panic(r) + } + } +} + +// Marshal serializes the in value, which may be a map or a struct value. +// In the case of struct values, only exported fields will be serialized. 
+// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported: +// +// omitempty Only include the field if it's not set to the zero +// value for the type or to empty slices or maps. +// +// minsize Marshal an int64 value as an int32, if that's feasible +// while preserving the numeric value. +// +// inline Inline the field, which must be a struct or a map, +// causing all of its fields or keys to be processed as if +// they were part of the outer struct. For maps, keys must +// not conflict with the bson keys of other struct fields. +// +// Some examples: +// +// type T struct { +// A bool +// B int "myb" +// C string "myc,omitempty" +// D string `bson:",omitempty" json:"jsonkey"` +// E int64 ",minsize" +// F int64 "myf,omitempty,minsize" +// } +// +func Marshal(in interface{}) (out []byte, err error) { + defer handleErr(&err) + e := &encoder{make([]byte, 0, initialBufferSize)} + e.addDoc(reflect.ValueOf(in)) + return e.out, nil +} + +// Unmarshal deserializes data from in into the out value. The out value +// must be a map, a pointer to a struct, or a pointer to a bson.D value. +// The lowercased field name is used as the key for each exported field, +// but this behavior may be changed using the respective field tag. +// The tag may also contain flags to tweak the marshalling behavior for +// the field. The tag formats accepted are: +// +// "[][,[,]]" +// +// `(...) bson:"[][,[,]]" (...)` +// +// The following flags are currently supported during unmarshal (see the +// Marshal method for other flags): +// +// inline Inline the field, which must be a struct or a map. +// Inlined structs are handled as if its fields were part +// of the outer struct. 
An inlined map causes keys that do +// not match any other struct field to be inserted in the +// map rather than being discarded as usual. +// +// The target field or element types of out may not necessarily match +// the BSON values of the provided data. The following conversions are +// made automatically: +// +// - Numeric types are converted if at least the integer part of the +// value would be preserved correctly +// - Bools are converted to numeric types as 1 or 0 +// - Numeric types are converted to bools as true if not 0 or false otherwise +// - Binary and string BSON data is converted to a string, array or byte slice +// +// If the value would not fit the type and cannot be converted, it's +// silently skipped. +// +// Pointer values are initialized when necessary. +func Unmarshal(in []byte, out interface{}) (err error) { + if raw, ok := out.(*Raw); ok { + raw.Kind = 3 + raw.Data = in + return nil + } + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + fallthrough + case reflect.Map: + d := newDecoder(in) + d.readDocTo(v) + case reflect.Struct: + return errors.New("Unmarshal can't deal with struct values. Use a pointer.") + default: + return errors.New("Unmarshal needs a map or a pointer to a struct.") + } + return nil +} + +// Unmarshal deserializes raw into the out value. If the out value type +// is not compatible with raw, a *bson.TypeError is returned. +// +// See the Unmarshal function documentation for more details on the +// unmarshalling process. +func (raw Raw) Unmarshal(out interface{}) (err error) { + defer handleErr(&err) + v := reflect.ValueOf(out) + switch v.Kind() { + case reflect.Ptr: + v = v.Elem() + fallthrough + case reflect.Map: + d := newDecoder(raw.Data) + good := d.readElemTo(v, raw.Kind) + if !good { + return &TypeError{v.Type(), raw.Kind} + } + case reflect.Struct: + return errors.New("Raw Unmarshal can't deal with struct values. 
Use a pointer.") + default: + return errors.New("Raw Unmarshal needs a map or a valid pointer.") + } + return nil +} + +type TypeError struct { + Type reflect.Type + Kind byte +} + +func (e *TypeError) Error() string { + return fmt.Sprintf("BSON kind 0x%02x isn't compatible with type %s", e.Kind, e.Type.String()) +} + +// -------------------------------------------------------------------------- +// Maintain a mapping of keys to structure field indexes + +type structInfo struct { + FieldsMap map[string]fieldInfo + FieldsList []fieldInfo + InlineMap int + Zero reflect.Value +} + +type fieldInfo struct { + Key string + Num int + OmitEmpty bool + MinSize bool + Inline []int +} + +var structMap = make(map[reflect.Type]*structInfo) +var structMapMutex sync.RWMutex + +type externalPanic string + +func (e externalPanic) String() string { + return string(e) +} + +func getStructInfo(st reflect.Type) (*structInfo, error) { + structMapMutex.RLock() + sinfo, found := structMap[st] + structMapMutex.RUnlock() + if found { + return sinfo, nil + } + n := st.NumField() + fieldsMap := make(map[string]fieldInfo) + fieldsList := make([]fieldInfo, 0, n) + inlineMap := -1 + for i := 0; i != n; i++ { + field := st.Field(i) + if field.PkgPath != "" { + continue // Private field + } + + info := fieldInfo{Num: i} + + tag := field.Tag.Get("bson") + if tag == "" && strings.Index(string(field.Tag), ":") < 0 { + tag = string(field.Tag) + } + if tag == "-" { + continue + } + + // XXX Drop this after a few releases. 
+ if s := strings.Index(tag, "/"); s >= 0 { + recommend := tag[:s] + for _, c := range tag[s+1:] { + switch c { + case 'c': + recommend += ",omitempty" + case 's': + recommend += ",minsize" + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", string([]byte{uint8(c)}), tag, st) + panic(externalPanic(msg)) + } + } + msg := fmt.Sprintf("Replace tag %q in field %s of type %s by %q", tag, field.Name, st, recommend) + panic(externalPanic(msg)) + } + + inline := false + fields := strings.Split(tag, ",") + if len(fields) > 1 { + for _, flag := range fields[1:] { + switch flag { + case "omitempty": + info.OmitEmpty = true + case "minsize": + info.MinSize = true + case "inline": + inline = true + default: + msg := fmt.Sprintf("Unsupported flag %q in tag %q of type %s", flag, tag, st) + panic(externalPanic(msg)) + } + } + tag = fields[0] + } + + if inline { + switch field.Type.Kind() { + case reflect.Map: + if inlineMap >= 0 { + return nil, errors.New("Multiple ,inline maps in struct " + st.String()) + } + if field.Type.Key() != reflect.TypeOf("") { + return nil, errors.New("Option ,inline needs a map with string keys in struct " + st.String()) + } + inlineMap = info.Num + case reflect.Struct: + sinfo, err := getStructInfo(field.Type) + if err != nil { + return nil, err + } + for _, finfo := range sinfo.FieldsList { + if _, found := fieldsMap[finfo.Key]; found { + msg := "Duplicated key '" + finfo.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + if finfo.Inline == nil { + finfo.Inline = []int{i, finfo.Num} + } else { + finfo.Inline = append([]int{i}, finfo.Inline...) 
+ } + fieldsMap[finfo.Key] = finfo + fieldsList = append(fieldsList, finfo) + } + default: + panic("Option ,inline needs a struct value or map field") + } + continue + } + + if tag != "" { + info.Key = tag + } else { + info.Key = strings.ToLower(field.Name) + } + + if _, found = fieldsMap[info.Key]; found { + msg := "Duplicated key '" + info.Key + "' in struct " + st.String() + return nil, errors.New(msg) + } + + fieldsList = append(fieldsList, info) + fieldsMap[info.Key] = info + } + sinfo = &structInfo{ + fieldsMap, + fieldsList, + inlineMap, + reflect.New(st).Elem(), + } + structMapMutex.Lock() + structMap[st] = sinfo + structMapMutex.Unlock() + return sinfo, nil +} diff --git a/vendor/gopkg.in/mgo.v2/bson/bson_test.go b/vendor/gopkg.in/mgo.v2/bson/bson_test.go new file mode 100644 index 000000000..eb2e9f41e --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/bson_test.go @@ -0,0 +1,1605 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson_test + +import ( + "encoding/binary" + "encoding/json" + "errors" + "net/url" + "reflect" + "testing" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2/bson" +) + +func TestAll(t *testing.T) { + TestingT(t) +} + +type S struct{} + +var _ = Suite(&S{}) + +// Wrap up the document elements contained in data, prepending the int32 +// length of the data, and appending the '\x00' value closing the document. +func wrapInDoc(data string) string { + result := make([]byte, len(data)+5) + binary.LittleEndian.PutUint32(result, uint32(len(result))) + copy(result[4:], []byte(data)) + return string(result) +} + +func makeZeroDoc(value interface{}) (zero interface{}) { + v := reflect.ValueOf(value) + t := v.Type() + switch t.Kind() { + case reflect.Map: + mv := reflect.MakeMap(t) + zero = mv.Interface() + case reflect.Ptr: + pv := reflect.New(v.Type().Elem()) + zero = pv.Interface() + case reflect.Slice, reflect.Int: + zero = reflect.New(t).Interface() + default: + panic("unsupported doc type") + } + return zero +} + +func testUnmarshal(c *C, data string, obj interface{}) { + zero := makeZeroDoc(obj) + err := bson.Unmarshal([]byte(data), zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, obj) +} + +type testItemType struct { + obj interface{} + data string +} + +// -------------------------------------------------------------------------- +// Samples from bsonspec.org: + +var sampleItems = 
[]testItemType{ + {bson.M{"hello": "world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {bson.M{"BSON": []interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalSampleItems(c *C) { + for i, item := range sampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalSampleItems(c *C) { + for i, item := range sampleItems { + value := bson.M{} + err := bson.Unmarshal([]byte(item.data), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Every type, ordered by the type flag. These are not wrapped with the +// length and last \x00 from the document. wrapInDoc() computes them. +// Note that all of them should be supported as two-way conversions. + +var allItems = []testItemType{ + {bson.M{}, + ""}, + {bson.M{"_": float64(5.05)}, + "\x01_\x00333333\x14@"}, + {bson.M{"_": "yo"}, + "\x02_\x00\x03\x00\x00\x00yo\x00"}, + {bson.M{"_": bson.M{"a": true}}, + "\x03_\x00\x09\x00\x00\x00\x08a\x00\x01\x00"}, + {bson.M{"_": []interface{}{true, false}}, + "\x04_\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + {bson.M{"_": []byte("yo")}, + "\x05_\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"_": bson.Binary{0x80, []byte("udef")}}, + "\x05_\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"_": bson.Undefined}, // Obsolete, but still seen in the wild. 
+ "\x06_\x00"}, + {bson.M{"_": bson.ObjectId("0123456789ab")}, + "\x07_\x000123456789ab"}, + {bson.M{"_": bson.DBPointer{"testnamespace", bson.ObjectId("0123456789ab")}}, + "\x0C_\x00\x0e\x00\x00\x00testnamespace\x000123456789ab"}, + {bson.M{"_": false}, + "\x08_\x00\x00"}, + {bson.M{"_": true}, + "\x08_\x00\x01"}, + {bson.M{"_": time.Unix(0, 258e6)}, // Note the NS <=> MS conversion. + "\x09_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": nil}, + "\x0A_\x00"}, + {bson.M{"_": bson.RegEx{"ab", "cd"}}, + "\x0B_\x00ab\x00cd\x00"}, + {bson.M{"_": bson.JavaScript{"code", nil}}, + "\x0D_\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"_": bson.Symbol("sym")}, + "\x0E_\x00\x04\x00\x00\x00sym\x00"}, + {bson.M{"_": bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F_\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + {bson.M{"_": 258}, + "\x10_\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MongoTimestamp(258)}, + "\x11_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258)}, + "\x12_\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"_": int64(258 << 32)}, + "\x12_\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + {bson.M{"_": bson.MaxKey}, + "\x7F_\x00"}, + {bson.M{"_": bson.MinKey}, + "\xFF_\x00"}, +} + +func (s *S) TestMarshalAllItems(c *C) { + for i, item := range allItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalAllItems(c *C) { + for i, item := range allItems { + value := bson.M{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), value) + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawAllItems(c *C) { + for i, item := range allItems { + if len(item.data) == 0 { + continue + } + value := item.obj.(bson.M)["_"] + if value == nil { + continue + } + pv := 
reflect.New(reflect.ValueOf(value).Type()) + raw := bson.Raw{item.data[0], []byte(item.data[3:])} + c.Logf("Unmarshal raw: %#v, %#v", raw, pv.Interface()) + err := raw.Unmarshal(pv.Interface()) + c.Assert(err, IsNil) + c.Assert(pv.Elem().Interface(), DeepEquals, value, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawIncompatible(c *C) { + raw := bson.Raw{0x08, []byte{0x01}} // true + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, ErrorMatches, "BSON kind 0x08 isn't compatible with type struct \\{\\}") +} + +func (s *S) TestUnmarshalZeroesStruct(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + type T struct{ A, B int } + v := T{A: 1} + err = bson.Unmarshal(data, &v) + c.Assert(err, IsNil) + c.Assert(v.A, Equals, 0) + c.Assert(v.B, Equals, 2) +} + +func (s *S) TestUnmarshalZeroesMap(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + err = bson.Unmarshal(data, &m) + c.Assert(err, IsNil) + c.Assert(m, DeepEquals, bson.M{"b": 2}) +} + +func (s *S) TestUnmarshalNonNilInterface(c *C) { + data, err := bson.Marshal(bson.M{"b": 2}) + c.Assert(err, IsNil) + m := bson.M{"a": 1} + var i interface{} + i = m + err = bson.Unmarshal(data, &i) + c.Assert(err, IsNil) + c.Assert(i, DeepEquals, bson.M{"b": 2}) + c.Assert(m, DeepEquals, bson.M{"a": 1}) +} + +// -------------------------------------------------------------------------- +// Some one way marshaling operations which would unmarshal differently. + +var oneWayMarshalItems = []testItemType{ + // These are being passed as pointers, and will unmarshal as values. 
+ {bson.M{"": &bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + {bson.M{"": &bson.Binary{0x80, []byte("udef")}}, + "\x05\x00\x04\x00\x00\x00\x80udef"}, + {bson.M{"": &bson.RegEx{"ab", "cd"}}, + "\x0B\x00ab\x00cd\x00"}, + {bson.M{"": &bson.JavaScript{"code", nil}}, + "\x0D\x00\x05\x00\x00\x00code\x00"}, + {bson.M{"": &bson.JavaScript{"code", bson.M{"": nil}}}, + "\x0F\x00\x14\x00\x00\x00\x05\x00\x00\x00code\x00" + + "\x07\x00\x00\x00\x0A\x00\x00"}, + + // There's no float32 type in BSON. Will encode as a float64. + {bson.M{"": float32(5.05)}, + "\x01\x00\x00\x00\x00@33\x14@"}, + + // The array will be unmarshaled as a slice instead. + {bson.M{"": [2]bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // The typed slice will be unmarshaled as []interface{}. + {bson.M{"": []bool{true, false}}, + "\x04\x00\r\x00\x00\x00\x080\x00\x01\x081\x00\x00\x00"}, + + // Will unmarshal as a []byte. + {bson.M{"": bson.Binary{0x00, []byte("yo")}}, + "\x05\x00\x02\x00\x00\x00\x00yo"}, + {bson.M{"": bson.Binary{0x02, []byte("old")}}, + "\x05\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // No way to preserve the type information here. We might encode as a zero + // value, but this would mean that pointer values in structs wouldn't be + // able to correctly distinguish between unset and set to the zero value. + {bson.M{"": (*byte)(nil)}, + "\x0A\x00"}, + + // No int types smaller than int32 in BSON. Could encode this as a char, + // but it would still be ambiguous, take more, and be awkward in Go when + // loaded without typing information. + {bson.M{"": byte(8)}, + "\x10\x00\x08\x00\x00\x00"}, + + // There are no unsigned types in BSON. Will unmarshal as int32 or int64. 
+ {bson.M{"": uint32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + {bson.M{"": uint64(258)}, + "\x12\x00\x02\x01\x00\x00\x00\x00\x00\x00"}, + {bson.M{"": uint64(258 << 32)}, + "\x12\x00\x00\x00\x00\x00\x02\x01\x00\x00"}, + + // This will unmarshal as int. + {bson.M{"": int32(258)}, + "\x10\x00\x02\x01\x00\x00"}, + + // That's a special case. The unsigned value is too large for an int32, + // so an int64 is used instead. + {bson.M{"": uint32(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, + {bson.M{"": uint(1<<32 - 1)}, + "\x12\x00\xFF\xFF\xFF\xFF\x00\x00\x00\x00"}, +} + +func (s *S) TestOneWayMarshalItems(c *C) { + for i, item := range oneWayMarshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +// -------------------------------------------------------------------------- +// Two-way tests for user-defined structures using the samples +// from bsonspec.org. + +type specSample1 struct { + Hello string +} + +type specSample2 struct { + BSON []interface{} "BSON" +} + +var structSampleItems = []testItemType{ + {&specSample1{"world"}, + "\x16\x00\x00\x00\x02hello\x00\x06\x00\x00\x00world\x00\x00"}, + {&specSample2{[]interface{}{"awesome", float64(5.05), 1986}}, + "1\x00\x00\x00\x04BSON\x00&\x00\x00\x00\x020\x00\x08\x00\x00\x00" + + "awesome\x00\x011\x00333333\x14@\x102\x00\xc2\x07\x00\x00\x00\x00"}, +} + +func (s *S) TestMarshalStructSampleItems(c *C) { + for i, item := range structSampleItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, item.data, + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructSampleItems(c *C) { + for _, item := range structSampleItems { + testUnmarshal(c, item.data, item.obj) + } +} + +func (s *S) Test64bitInt(c *C) { + var i int64 = (1 << 31) + if int(i) > 0 { + data, err := bson.Marshal(bson.M{"i": int(i)}) + c.Assert(err, IsNil) + 
c.Assert(string(data), Equals, wrapInDoc("\x12i\x00\x00\x00\x00\x80\x00\x00\x00\x00")) + + var result struct{ I int } + err = bson.Unmarshal(data, &result) + c.Assert(err, IsNil) + c.Assert(int64(result.I), Equals, i) + } +} + +// -------------------------------------------------------------------------- +// Generic two-way struct marshaling tests. + +var bytevar = byte(8) +var byteptr = &bytevar + +var structItems = []testItemType{ + {&struct{ Ptr *byte }{nil}, + "\x0Aptr\x00"}, + {&struct{ Ptr *byte }{&bytevar}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Ptr **byte }{&byteptr}, + "\x10ptr\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{8}, + "\x10byte\x00\x08\x00\x00\x00"}, + {&struct{ Byte byte }{0}, + "\x10byte\x00\x00\x00\x00\x00"}, + {&struct { + V byte "Tag" + }{8}, + "\x10Tag\x00\x08\x00\x00\x00"}, + {&struct { + V *struct { + Byte byte + } + }{&struct{ Byte byte }{8}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ priv byte }{}, ""}, + + // The order of the dumped fields should be the same in the struct. + {&struct{ A, C, B, D, F, E *byte }{}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x0Ae\x00"}, + + {&struct{ V bson.Raw }{bson.Raw{0x03, []byte("\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00")}}, + "\x03v\x00" + "\x0f\x00\x00\x00\x10byte\x00\b\x00\x00\x00\x00"}, + {&struct{ V bson.Raw }{bson.Raw{0x10, []byte("\x00\x00\x00\x00")}}, + "\x10v\x00" + "\x00\x00\x00\x00"}, + + // Byte arrays. 
+ {&struct{ V [2]byte }{[2]byte{'y', 'o'}}, + "\x05v\x00\x02\x00\x00\x00\x00yo"}, +} + +func (s *S) TestMarshalStructItems(c *C) { + for i, item := range structItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item %d", i)) + } +} + +func (s *S) TestUnmarshalStructItems(c *C) { + for _, item := range structItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalRawStructItems(c *C) { + for i, item := range structItems { + raw := bson.Raw{0x03, []byte(wrapInDoc(item.data))} + zero := makeZeroDoc(item.obj) + err := raw.Unmarshal(zero) + c.Assert(err, IsNil) + c.Assert(zero, DeepEquals, item.obj, Commentf("Failed on item %d: %#v", i, item)) + } +} + +func (s *S) TestUnmarshalRawNil(c *C) { + // Regression test: shouldn't try to nil out the pointer itself, + // as it's not settable. + raw := bson.Raw{0x0A, []byte{}} + err := raw.Unmarshal(&struct{}{}) + c.Assert(err, IsNil) +} + +// -------------------------------------------------------------------------- +// One-way marshaling tests. + +type dOnIface struct { + D interface{} +} + +type ignoreField struct { + Before string + Ignore string `bson:"-"` + After string +} + +var marshalItems = []testItemType{ + // Ordered document dump. Will unmarshal as a dictionary by default. 
+ {bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {MyD{{"a", nil}, {"c", nil}, {"b", nil}, {"d", nil}, {"f", nil}, {"e", true}}, + "\x0Aa\x00\x0Ac\x00\x0Ab\x00\x0Ad\x00\x0Af\x00\x08e\x00\x01"}, + {&dOnIface{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + {bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {MyRawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}, + "\x0Aa\x00" + "\x0Ac\x00" + "\x08b\x00\x01"}, + {&dOnIface{bson.RawD{{"a", bson.Raw{0x0A, nil}}, {"c", bson.Raw{0x0A, nil}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00"+"\x0Ac\x00"+"\x08b\x00\x01")}, + + {&ignoreField{"before", "ignore", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Marshalling a Raw document does nothing. + {bson.Raw{0x03, []byte(wrapInDoc("anything"))}, + "anything"}, + {bson.Raw{Data: []byte(wrapInDoc("anything"))}, + "anything"}, +} + +func (s *S) TestMarshalOneWayItems(c *C) { + for _, item := range marshalItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data)) + } +} + +// -------------------------------------------------------------------------- +// One-way unmarshaling tests. + +var unmarshalItems = []testItemType{ + // Field is private. Should not attempt to unmarshal it. + {&struct{ priv byte }{}, + "\x10priv\x00\x08\x00\x00\x00"}, + + // Wrong casing. Field names are lowercased. + {&struct{ Byte byte }{}, + "\x10Byte\x00\x08\x00\x00\x00"}, + + // Ignore non-existing field. 
+ {&struct{ Byte byte }{9}, + "\x10boot\x00\x08\x00\x00\x00" + "\x10byte\x00\x09\x00\x00\x00"}, + + // Do not unmarshal on ignored field. + {&ignoreField{"before", "", "after"}, + "\x02before\x00\a\x00\x00\x00before\x00" + + "\x02-\x00\a\x00\x00\x00ignore\x00" + + "\x02after\x00\x06\x00\x00\x00after\x00"}, + + // Ignore unsuitable types silently. + {map[string]string{"str": "s"}, + "\x02str\x00\x02\x00\x00\x00s\x00" + "\x10int\x00\x01\x00\x00\x00"}, + {map[string][]int{"array": []int{5, 9}}, + "\x04array\x00" + wrapInDoc("\x100\x00\x05\x00\x00\x00"+"\x021\x00\x02\x00\x00\x00s\x00"+"\x102\x00\x09\x00\x00\x00")}, + + // Wrong type. Shouldn't init pointer. + {&struct{ Str *byte }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + {&struct{ Str *struct{ Str string } }{}, + "\x02str\x00\x02\x00\x00\x00s\x00"}, + + // Ordered document. + {&struct{ bson.D }{bson.D{{"a", nil}, {"c", nil}, {"b", nil}, {"d", true}}}, + "\x03d\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x0Ab\x00\x08d\x00\x01")}, + + // Raw document. + {&bson.Raw{0x03, []byte(wrapInDoc("\x10byte\x00\x08\x00\x00\x00"))}, + "\x10byte\x00\x08\x00\x00\x00"}, + + // RawD document. + {&struct{ bson.RawD }{bson.RawD{{"a", bson.Raw{0x0A, []byte{}}}, {"c", bson.Raw{0x0A, []byte{}}}, {"b", bson.Raw{0x08, []byte{0x01}}}}}, + "\x03rawd\x00" + wrapInDoc("\x0Aa\x00\x0Ac\x00\x08b\x00\x01")}, + + // Decode old binary. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x07\x00\x00\x00\x02\x03\x00\x00\x00old"}, + + // Decode old binary without length. According to the spec, this shouldn't happen. + {bson.M{"_": []byte("old")}, + "\x05_\x00\x03\x00\x00\x00\x02old"}, +} + +func (s *S) TestUnmarshalOneWayItems(c *C) { + for _, item := range unmarshalItems { + testUnmarshal(c, wrapInDoc(item.data), item.obj) + } +} + +func (s *S) TestUnmarshalNilInStruct(c *C) { + // Nil is the default value, so we need to ensure it's indeed being set. 
+ b := byte(1) + v := &struct{ Ptr *byte }{&b} + err := bson.Unmarshal([]byte(wrapInDoc("\x0Aptr\x00")), v) + c.Assert(err, IsNil) + c.Assert(v, DeepEquals, &struct{ Ptr *byte }{nil}) +} + +// -------------------------------------------------------------------------- +// Marshalling error cases. + +type structWithDupKeys struct { + Name byte + Other byte "name" // Tag should precede. +} + +var marshalErrorItems = []testItemType{ + {bson.M{"": uint64(1 << 63)}, + "BSON has no uint64 type, and value is too large to fit correctly in an int64"}, + {bson.M{"": bson.ObjectId("tooshort")}, + "ObjectIDs must be exactly 12 bytes long \\(got 8\\)"}, + {int64(123), + "Can't marshal int64 as a BSON document"}, + {bson.M{"": 1i}, + "Can't marshal complex128 in a BSON document"}, + {&structWithDupKeys{}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + {bson.Raw{0x0A, []byte{}}, + "Attempted to unmarshal Raw kind 10 as a document"}, + {&inlineCantPtr{&struct{ A, B int }{1, 2}}, + "Option ,inline needs a struct value or map field"}, + {&inlineDupName{1, struct{ A, B int }{2, 3}}, + "Duplicated key 'a' in struct bson_test.inlineDupName"}, + {&inlineDupMap{}, + "Multiple ,inline maps in struct bson_test.inlineDupMap"}, + {&inlineBadKeyMap{}, + "Option ,inline needs a map with string keys in struct bson_test.inlineBadKeyMap"}, + {&inlineMap{A: 1, M: map[string]interface{}{"a": 1}}, + `Can't have key "a" in inlined map; conflicts with struct field`}, +} + +func (s *S) TestMarshalErrorItems(c *C) { + for _, item := range marshalErrorItems { + data, err := bson.Marshal(item.obj) + c.Assert(err, ErrorMatches, item.data) + c.Assert(data, IsNil) + } +} + +// -------------------------------------------------------------------------- +// Unmarshalling error cases. + +type unmarshalErrorType struct { + obj interface{} + data string + error string +} + +var unmarshalErrorItems = []unmarshalErrorType{ + // Tag name conflicts with existing parameter. 
+ {&structWithDupKeys{}, + "\x10name\x00\x08\x00\x00\x00", + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + // Non-string map key. + {map[int]interface{}{}, + "\x10name\x00\x08\x00\x00\x00", + "BSON map must have string keys. Got: map\\[int\\]interface \\{\\}"}, + + {nil, + "\xEEname\x00", + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal can't deal with struct values. Use a pointer."}, + + {123, + "\x10name\x00\x08\x00\x00\x00", + "Unmarshal needs a map or a pointer to a struct."}, +} + +func (s *S) TestUnmarshalErrorItems(c *C) { + for _, item := range unmarshalErrorItems { + data := []byte(wrapInDoc(item.data)) + var value interface{} + switch reflect.ValueOf(item.obj).Kind() { + case reflect.Map, reflect.Ptr: + value = makeZeroDoc(item.obj) + case reflect.Invalid: + value = bson.M{} + default: + value = item.obj + } + err := bson.Unmarshal(data, value) + c.Assert(err, ErrorMatches, item.error) + } +} + +type unmarshalRawErrorType struct { + obj interface{} + raw bson.Raw + error string +} + +var unmarshalRawErrorItems = []unmarshalRawErrorType{ + // Tag name conflicts with existing parameter. + {&structWithDupKeys{}, + bson.Raw{0x03, []byte("\x10byte\x00\x08\x00\x00\x00")}, + "Duplicated key 'name' in struct bson_test.structWithDupKeys"}, + + {&struct{}{}, + bson.Raw{0xEE, []byte{}}, + "Unknown element kind \\(0xEE\\)"}, + + {struct{ Name bool }{}, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal can't deal with struct values. 
Use a pointer."}, + + {123, + bson.Raw{0x10, []byte("\x08\x00\x00\x00")}, + "Raw Unmarshal needs a map or a valid pointer."}, +} + +func (s *S) TestUnmarshalRawErrorItems(c *C) { + for i, item := range unmarshalRawErrorItems { + err := item.raw.Unmarshal(item.obj) + c.Assert(err, ErrorMatches, item.error, Commentf("Failed on item %d: %#v\n", i, item)) + } +} + +var corruptedData = []string{ + "\x04\x00\x00\x00\x00", // Shorter than minimum + "\x06\x00\x00\x00\x00", // Not enough data + "\x05\x00\x00", // Broken length + "\x05\x00\x00\x00\xff", // Corrupted termination + "\x0A\x00\x00\x00\x0Aooop\x00", // Unfinished C string + + // Array end past end of string (s[2]=0x07 is correct) + wrapInDoc("\x04\x00\x09\x00\x00\x00\x0A\x00\x00"), + + // Array end within string, but past acceptable. + wrapInDoc("\x04\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // Document end within string, but past acceptable. + wrapInDoc("\x03\x00\x08\x00\x00\x00\x0A\x00\x00"), + + // String with corrupted end. + wrapInDoc("\x02\x00\x03\x00\x00\x00yo\xFF"), +} + +func (s *S) TestUnmarshalMapDocumentTooShort(c *C) { + for _, data := range corruptedData { + err := bson.Unmarshal([]byte(data), bson.M{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + + err = bson.Unmarshal([]byte(data), &struct{}{}) + c.Assert(err, ErrorMatches, "Document is corrupted") + } +} + +// -------------------------------------------------------------------------- +// Setter test cases. 
+ +var setterResult = map[string]error{} + +type setterType struct { + received interface{} +} + +func (o *setterType) SetBSON(raw bson.Raw) error { + err := raw.Unmarshal(&o.received) + if err != nil { + panic("The panic:" + err.Error()) + } + if s, ok := o.received.(string); ok { + if result, ok := setterResult[s]; ok { + return result + } + } + return nil +} + +type ptrSetterDoc struct { + Field *setterType "_" +} + +type valSetterDoc struct { + Field setterType "_" +} + +func (s *S) TestUnmarshalAllItemsWithPtrSetter(c *C) { + for _, item := range allItems { + for i := 0; i != 2; i++ { + var field *setterType + if i == 0 { + obj := &ptrSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = obj.Field + } else { + obj := &valSetterDoc{} + err := bson.Unmarshal([]byte(wrapInDoc(item.data)), obj) + c.Assert(err, IsNil) + field = &obj.Field + } + if item.data == "" { + // Nothing to unmarshal. Should be untouched. + if i == 0 { + c.Assert(field, IsNil) + } else { + c.Assert(field.received, IsNil) + } + } else { + expected := item.obj.(bson.M)["_"] + c.Assert(field, NotNil, Commentf("Pointer not initialized (%#v)", expected)) + c.Assert(field.received, DeepEquals, expected) + } + } + } +} + +func (s *S) TestUnmarshalWholeDocumentWithSetter(c *C) { + obj := &setterType{} + err := bson.Unmarshal([]byte(sampleItems[0].data), obj) + c.Assert(err, IsNil) + c.Assert(obj.received, DeepEquals, bson.M{"hello": "world"}) +} + +func (s *S) TestUnmarshalSetterOmits(c *C) { + setterResult["2"] = &bson.TypeError{} + setterResult["4"] = &bson.TypeError{} + defer func() { + delete(setterResult, "2") + delete(setterResult, "4") + }() + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00" + + "\x02jkl\x00\x02\x00\x00\x004\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + c.Assert(m["abc"], NotNil) + 
c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], NotNil) + c.Assert(m["jkl"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") + c.Assert(m["ghi"].received, Equals, "3") +} + +func (s *S) TestUnmarshalSetterErrors(c *C) { + boom := errors.New("BOOM") + setterResult["2"] = boom + defer delete(setterResult, "2") + + m := map[string]*setterType{} + data := wrapInDoc("\x02abc\x00\x02\x00\x00\x001\x00" + + "\x02def\x00\x02\x00\x00\x002\x00" + + "\x02ghi\x00\x02\x00\x00\x003\x00") + err := bson.Unmarshal([]byte(data), m) + c.Assert(err, Equals, boom) + c.Assert(m["abc"], NotNil) + c.Assert(m["def"], IsNil) + c.Assert(m["ghi"], IsNil) + + c.Assert(m["abc"].received, Equals, "1") +} + +func (s *S) TestDMap(c *C) { + d := bson.D{{"a", 1}, {"b", 2}} + c.Assert(d.Map(), DeepEquals, bson.M{"a": 1, "b": 2}) +} + +func (s *S) TestUnmarshalSetterSetZero(c *C) { + setterResult["foo"] = bson.SetZero + defer delete(setterResult, "field") + + data, err := bson.Marshal(bson.M{"field": "foo"}) + c.Assert(err, IsNil) + + m := map[string]*setterType{} + err = bson.Unmarshal([]byte(data), m) + c.Assert(err, IsNil) + + value, ok := m["field"] + c.Assert(ok, Equals, true) + c.Assert(value, IsNil) +} + +// -------------------------------------------------------------------------- +// Getter test cases. 
+ +type typeWithGetter struct { + result interface{} + err error +} + +func (t *typeWithGetter) GetBSON() (interface{}, error) { + if t == nil { + return "", nil + } + return t.result, t.err +} + +type docWithGetterField struct { + Field *typeWithGetter "_" +} + +func (s *S) TestMarshalAllItemsWithGetter(c *C) { + for i, item := range allItems { + if item.data == "" { + continue + } + obj := &docWithGetterField{} + obj.Field = &typeWithGetter{result: item.obj.(bson.M)["_"]} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, wrapInDoc(item.data), + Commentf("Failed on item #%d", i)) + } +} + +func (s *S) TestMarshalWholeDocumentWithGetter(c *C) { + obj := &typeWithGetter{result: sampleItems[0].obj} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + c.Assert(string(data), Equals, sampleItems[0].data) +} + +func (s *S) TestGetterErrors(c *C) { + e := errors.New("oops") + + obj1 := &docWithGetterField{} + obj1.Field = &typeWithGetter{sampleItems[0].obj, e} + data, err := bson.Marshal(obj1) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) + + obj2 := &typeWithGetter{sampleItems[0].obj, e} + data, err = bson.Marshal(obj2) + c.Assert(err, ErrorMatches, "oops") + c.Assert(data, IsNil) +} + +type intGetter int64 + +func (t intGetter) GetBSON() (interface{}, error) { + return int64(t), nil +} + +type typeWithIntGetter struct { + V intGetter ",minsize" +} + +func (s *S) TestMarshalShortWithGetter(c *C) { + obj := typeWithIntGetter{42} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + m := bson.M{} + err = bson.Unmarshal(data, m) + c.Assert(err, IsNil) + c.Assert(m["v"], Equals, 42) +} + +func (s *S) TestMarshalWithGetterNil(c *C) { + obj := docWithGetterField{} + data, err := bson.Marshal(obj) + c.Assert(err, IsNil) + m := bson.M{} + err = bson.Unmarshal(data, m) + c.Assert(err, IsNil) + c.Assert(m, DeepEquals, bson.M{"_": ""}) +} + +// -------------------------------------------------------------------------- 
+// Cross-type conversion tests. + +type crossTypeItem struct { + obj1 interface{} + obj2 interface{} +} + +type condStr struct { + V string ",omitempty" +} +type condStrNS struct { + V string `a:"A" bson:",omitempty" b:"B"` +} +type condBool struct { + V bool ",omitempty" +} +type condInt struct { + V int ",omitempty" +} +type condUInt struct { + V uint ",omitempty" +} +type condFloat struct { + V float64 ",omitempty" +} +type condIface struct { + V interface{} ",omitempty" +} +type condPtr struct { + V *bool ",omitempty" +} +type condSlice struct { + V []string ",omitempty" +} +type condMap struct { + V map[string]int ",omitempty" +} +type namedCondStr struct { + V string "myv,omitempty" +} +type condTime struct { + V time.Time ",omitempty" +} +type condStruct struct { + V struct{ A []int } ",omitempty" +} + +type shortInt struct { + V int64 ",minsize" +} +type shortUint struct { + V uint64 ",minsize" +} +type shortIface struct { + V interface{} ",minsize" +} +type shortPtr struct { + V *int64 ",minsize" +} +type shortNonEmptyInt struct { + V int64 ",minsize,omitempty" +} + +type inlineInt struct { + V struct{ A, B int } ",inline" +} +type inlineCantPtr struct { + V *struct{ A, B int } ",inline" +} +type inlineDupName struct { + A int + V struct{ A, B int } ",inline" +} +type inlineMap struct { + A int + M map[string]interface{} ",inline" +} +type inlineMapInt struct { + A int + M map[string]int ",inline" +} +type inlineMapMyM struct { + A int + M MyM ",inline" +} +type inlineDupMap struct { + M1 map[string]interface{} ",inline" + M2 map[string]interface{} ",inline" +} +type inlineBadKeyMap struct { + M map[int]int ",inline" +} + +type getterSetterD bson.D + +func (s getterSetterD) GetBSON() (interface{}, error) { + if len(s) == 0 { + return bson.D{}, nil + } + return bson.D(s[:len(s)-1]), nil +} + +func (s *getterSetterD) SetBSON(raw bson.Raw) error { + var doc bson.D + err := raw.Unmarshal(&doc) + doc = append(doc, bson.DocElem{"suffix", true}) + *s = 
getterSetterD(doc) + return err +} + +type getterSetterInt int + +func (i getterSetterInt) GetBSON() (interface{}, error) { + return bson.D{{"a", int(i)}}, nil +} + +func (i *getterSetterInt) SetBSON(raw bson.Raw) error { + var doc struct{ A int } + err := raw.Unmarshal(&doc) + *i = getterSetterInt(doc.A) + return err +} + +type ifaceType interface { + Hello() +} + +type ifaceSlice []ifaceType + +func (s *ifaceSlice) SetBSON(raw bson.Raw) error { + var ns []int + if err := raw.Unmarshal(&ns); err != nil { + return err + } + *s = make(ifaceSlice, ns[0]) + return nil +} + +func (s ifaceSlice) GetBSON() (interface{}, error) { + return []int{len(s)}, nil +} + +type ( + MyString string + MyBytes []byte + MyBool bool + MyD []bson.DocElem + MyRawD []bson.RawDocElem + MyM map[string]interface{} +) + +var ( + truevar = true + falsevar = false + + int64var = int64(42) + int64ptr = &int64var + intvar = int(42) + intptr = &intvar + + gsintvar = getterSetterInt(42) +) + +func parseURL(s string) *url.URL { + u, err := url.Parse(s) + if err != nil { + panic(err) + } + return u +} + +// That's a pretty fun test. It will dump the first item, generate a zero +// value equivalent to the second one, load the dumped data onto it, and then +// verify that the resulting value is deep-equal to the untouched second value. +// Then, it will do the same in the *opposite* direction! 
+var twoWayCrossItems = []crossTypeItem{ + // int<=>int + {&struct{ I int }{42}, &struct{ I int8 }{42}}, + {&struct{ I int }{42}, &struct{ I int32 }{42}}, + {&struct{ I int }{42}, &struct{ I int64 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I int8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I int32 }{42}, &struct{ I int64 }{42}}, + + // uint<=>uint + {&struct{ I uint }{42}, &struct{ I uint8 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I uint64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I uint64 }{42}}, + + // float32<=>float64 + {&struct{ I float32 }{42}, &struct{ I float64 }{42}}, + + // int<=>uint + {&struct{ I uint }{42}, &struct{ I int }{42}}, + {&struct{ I uint }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint8 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint32 }{42}, &struct{ I int64 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int8 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int32 }{42}}, + {&struct{ I uint64 }{42}, &struct{ I int64 }{42}}, + + // int <=> float + {&struct{ I int }{42}, &struct{ I float64 }{42}}, + + // int <=> bool + {&struct{ I int }{1}, &struct{ I bool }{true}}, + {&struct{ I int }{0}, &struct{ I bool }{false}}, + + // uint <=> float64 + {&struct{ I uint }{42}, &struct{ I float64 }{42}}, + + // uint <=> bool + {&struct{ I uint }{1}, &struct{ I bool }{true}}, + {&struct{ I uint }{0}, &struct{ I bool 
}{false}}, + + // float64 <=> bool + {&struct{ I float64 }{1}, &struct{ I bool }{true}}, + {&struct{ I float64 }{0}, &struct{ I bool }{false}}, + + // string <=> string and string <=> []byte + {&struct{ S []byte }{[]byte("abc")}, &struct{ S string }{"abc"}}, + {&struct{ S []byte }{[]byte("def")}, &struct{ S bson.Symbol }{"def"}}, + {&struct{ S string }{"ghi"}, &struct{ S bson.Symbol }{"ghi"}}, + + // map <=> struct + {&struct { + A struct { + B, C int + } + }{struct{ B, C int }{1, 2}}, + map[string]map[string]int{"a": map[string]int{"b": 1, "c": 2}}}, + + {&struct{ A bson.Symbol }{"abc"}, map[string]string{"a": "abc"}}, + {&struct{ A bson.Symbol }{"abc"}, map[string][]byte{"a": []byte("abc")}}, + {&struct{ A []byte }{[]byte("abc")}, map[string]string{"a": "abc"}}, + {&struct{ A uint }{42}, map[string]int{"a": 42}}, + {&struct{ A uint }{42}, map[string]float64{"a": 42}}, + {&struct{ A uint }{1}, map[string]bool{"a": true}}, + {&struct{ A int }{42}, map[string]uint{"a": 42}}, + {&struct{ A int }{42}, map[string]float64{"a": 42}}, + {&struct{ A int }{1}, map[string]bool{"a": true}}, + {&struct{ A float64 }{42}, map[string]float32{"a": 42}}, + {&struct{ A float64 }{42}, map[string]int{"a": 42}}, + {&struct{ A float64 }{42}, map[string]uint{"a": 42}}, + {&struct{ A float64 }{1}, map[string]bool{"a": true}}, + {&struct{ A bool }{true}, map[string]int{"a": 1}}, + {&struct{ A bool }{true}, map[string]uint{"a": 1}}, + {&struct{ A bool }{true}, map[string]float64{"a": 1}}, + {&struct{ A **byte }{&byteptr}, map[string]byte{"a": 8}}, + + // url.URL <=> string + {&struct{ URL *url.URL }{parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + {&struct{ URL url.URL }{*parseURL("h://e.c/p")}, map[string]string{"url": "h://e.c/p"}}, + + // Slices + {&struct{ S []int }{[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + {&struct{ S *[]int }{&[]int{1, 2, 3}}, map[string][]int{"s": []int{1, 2, 3}}}, + + // Conditionals + {&condBool{true}, map[string]bool{"v": true}}, 
+ {&condBool{}, map[string]bool{}}, + {&condInt{1}, map[string]int{"v": 1}}, + {&condInt{}, map[string]int{}}, + {&condUInt{1}, map[string]uint{"v": 1}}, + {&condUInt{}, map[string]uint{}}, + {&condFloat{}, map[string]int{}}, + {&condStr{"yo"}, map[string]string{"v": "yo"}}, + {&condStr{}, map[string]string{}}, + {&condStrNS{"yo"}, map[string]string{"v": "yo"}}, + {&condStrNS{}, map[string]string{}}, + {&condSlice{[]string{"yo"}}, map[string][]string{"v": []string{"yo"}}}, + {&condSlice{}, map[string][]string{}}, + {&condMap{map[string]int{"k": 1}}, bson.M{"v": bson.M{"k": 1}}}, + {&condMap{}, map[string][]string{}}, + {&condIface{"yo"}, map[string]string{"v": "yo"}}, + {&condIface{""}, map[string]string{"v": ""}}, + {&condIface{}, map[string]string{}}, + {&condPtr{&truevar}, map[string]bool{"v": true}}, + {&condPtr{&falsevar}, map[string]bool{"v": false}}, + {&condPtr{}, map[string]string{}}, + + {&condTime{time.Unix(123456789, 123e6)}, map[string]time.Time{"v": time.Unix(123456789, 123e6)}}, + {&condTime{}, map[string]string{}}, + + {&condStruct{struct{ A []int }{[]int{1}}}, bson.M{"v": bson.M{"a": []interface{}{1}}}}, + {&condStruct{struct{ A []int }{}}, bson.M{}}, + + {&namedCondStr{"yo"}, map[string]string{"myv": "yo"}}, + {&namedCondStr{}, map[string]string{}}, + + {&shortInt{1}, map[string]interface{}{"v": 1}}, + {&shortInt{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortUint{1 << 30}, map[string]interface{}{"v": 1 << 30}}, + {&shortUint{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortIface{int64(1) << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortPtr{int64ptr}, map[string]interface{}{"v": intvar}}, + + {&shortNonEmptyInt{1}, map[string]interface{}{"v": 1}}, + {&shortNonEmptyInt{1 << 31}, map[string]interface{}{"v": int64(1 << 31)}}, + {&shortNonEmptyInt{}, map[string]interface{}{}}, + + {&inlineInt{struct{ A, B int }{1, 2}}, 
map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: map[string]interface{}{"b": 2}}, map[string]interface{}{"a": 1, "b": 2}}, + {&inlineMap{A: 1, M: nil}, map[string]interface{}{"a": 1}}, + {&inlineMapInt{A: 1, M: map[string]int{"b": 2}}, map[string]int{"a": 1, "b": 2}}, + {&inlineMapInt{A: 1, M: nil}, map[string]int{"a": 1}}, + {&inlineMapMyM{A: 1, M: MyM{"b": MyM{"c": 3}}}, map[string]interface{}{"a": 1, "b": map[string]interface{}{"c": 3}}}, + + // []byte <=> MyBytes + {&struct{ B MyBytes }{[]byte("abc")}, map[string]string{"b": "abc"}}, + {&struct{ B MyBytes }{[]byte{}}, map[string]string{"b": ""}}, + {&struct{ B MyBytes }{}, map[string]bool{}}, + {&struct{ B []byte }{[]byte("abc")}, map[string]MyBytes{"b": []byte("abc")}}, + + // bool <=> MyBool + {&struct{ B MyBool }{true}, map[string]bool{"b": true}}, + {&struct{ B MyBool }{}, map[string]bool{"b": false}}, + {&struct{ B MyBool }{}, map[string]string{}}, + {&struct{ B bool }{}, map[string]MyBool{"b": false}}, + + // arrays + {&struct{ V [2]int }{[...]int{1, 2}}, map[string][2]int{"v": [2]int{1, 2}}}, + {&struct{ V [2]byte }{[...]byte{1, 2}}, map[string][2]byte{"v": [2]byte{1, 2}}}, + + // zero time + {&struct{ V time.Time }{}, map[string]interface{}{"v": time.Time{}}}, + + // zero time + 1 second + 1 millisecond; overflows int64 as nanoseconds + {&struct{ V time.Time }{time.Unix(-62135596799, 1e6).Local()}, + map[string]interface{}{"v": time.Unix(-62135596799, 1e6).Local()}}, + + // bson.D <=> []DocElem + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}}, + {&bson.D{{"a", bson.D{{"b", 1}, {"c", 2}}}}, &MyD{{"a", MyD{{"b", 1}, {"c", 2}}}}}, + {&struct{ V MyD }{MyD{{"a", 1}}}, &bson.D{{"v", bson.D{{"a", 1}}}}}, + + // bson.RawD <=> []RawDocElem + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + {&bson.RawD{{"a", bson.Raw{0x08, []byte{0x01}}}}, &MyRawD{{"a", bson.Raw{0x08, []byte{0x01}}}}}, + + // bson.M <=> 
map + {bson.M{"a": bson.M{"b": 1, "c": 2}}, MyM{"a": MyM{"b": 1, "c": 2}}}, + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[string]interface{}{"a": map[string]interface{}{"b": 1, "c": 2}}}, + + // bson.M <=> map[MyString] + {bson.M{"a": bson.M{"b": 1, "c": 2}}, map[MyString]interface{}{"a": map[MyString]interface{}{"b": 1, "c": 2}}}, + + // json.Number <=> int64, float64 + {&struct{ N json.Number }{"5"}, map[string]interface{}{"n": int64(5)}}, + {&struct{ N json.Number }{"5.05"}, map[string]interface{}{"n": 5.05}}, + {&struct{ N json.Number }{"9223372036854776000"}, map[string]interface{}{"n": float64(1 << 63)}}, + + // bson.D <=> non-struct getter/setter + {&bson.D{{"a", 1}}, &getterSetterD{{"a", 1}, {"suffix", true}}}, + {&bson.D{{"a", 42}}, &gsintvar}, + + // Interface slice setter. + {&struct{ V ifaceSlice }{ifaceSlice{nil, nil, nil}}, bson.M{"v": []interface{}{3}}}, +} + +// Same thing, but only one way (obj1 => obj2). +var oneWayCrossItems = []crossTypeItem{ + // map <=> struct + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, map[string]int{"a": 1, "c": 3}}, + + // inline map elides badly typed values + {map[string]interface{}{"a": 1, "b": "2", "c": 3}, &inlineMapInt{A: 1, M: map[string]int{"c": 3}}}, + + // Can't decode int into struct. + {bson.M{"a": bson.M{"b": 2}}, &struct{ A bool }{}}, + + // Would get decoded into a int32 too in the opposite direction. + {&shortIface{int64(1) << 30}, map[string]interface{}{"v": 1 << 30}}, + + // Ensure omitempty on struct with private fields works properly. 
+ {&struct { + V struct{ v time.Time } ",omitempty" + }{}, map[string]interface{}{}}, +} + +func testCrossPair(c *C, dump interface{}, load interface{}) { + c.Logf("Dump: %#v", dump) + c.Logf("Load: %#v", load) + zero := makeZeroDoc(load) + data, err := bson.Marshal(dump) + c.Assert(err, IsNil) + c.Logf("Dumped: %#v", string(data)) + err = bson.Unmarshal(data, zero) + c.Assert(err, IsNil) + c.Logf("Loaded: %#v", zero) + c.Assert(zero, DeepEquals, load) +} + +func (s *S) TestTwoWayCrossPairs(c *C) { + for _, item := range twoWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + testCrossPair(c, item.obj2, item.obj1) + } +} + +func (s *S) TestOneWayCrossPairs(c *C) { + for _, item := range oneWayCrossItems { + testCrossPair(c, item.obj1, item.obj2) + } +} + +// -------------------------------------------------------------------------- +// ObjectId hex representation test. + +func (s *S) TestObjectIdHex(c *C) { + id := bson.ObjectIdHex("4d88e15b60f486e428412dc9") + c.Assert(id.String(), Equals, `ObjectIdHex("4d88e15b60f486e428412dc9")`) + c.Assert(id.Hex(), Equals, "4d88e15b60f486e428412dc9") +} + +func (s *S) TestIsObjectIdHex(c *C) { + test := []struct { + id string + valid bool + }{ + {"4d88e15b60f486e428412dc9", true}, + {"4d88e15b60f486e428412dc", false}, + {"4d88e15b60f486e428412dc9e", false}, + {"4d88e15b60f486e428412dcx", false}, + } + for _, t := range test { + c.Assert(bson.IsObjectIdHex(t.id), Equals, t.valid) + } +} + +// -------------------------------------------------------------------------- +// ObjectId parts extraction tests. 
+ +type objectIdParts struct { + id bson.ObjectId + timestamp int64 + machine []byte + pid uint16 + counter int32 +} + +var objectIds = []objectIdParts{ + objectIdParts{ + bson.ObjectIdHex("4d88e15b60f486e428412dc9"), + 1300816219, + []byte{0x60, 0xf4, 0x86}, + 0xe428, + 4271561, + }, + objectIdParts{ + bson.ObjectIdHex("000000000000000000000000"), + 0, + []byte{0x00, 0x00, 0x00}, + 0x0000, + 0, + }, + objectIdParts{ + bson.ObjectIdHex("00000000aabbccddee000001"), + 0, + []byte{0xaa, 0xbb, 0xcc}, + 0xddee, + 1, + }, +} + +func (s *S) TestObjectIdPartsExtraction(c *C) { + for i, v := range objectIds { + t := time.Unix(v.timestamp, 0) + c.Assert(v.id.Time(), Equals, t, Commentf("#%d Wrong timestamp value", i)) + c.Assert(v.id.Machine(), DeepEquals, v.machine, Commentf("#%d Wrong machine id value", i)) + c.Assert(v.id.Pid(), Equals, v.pid, Commentf("#%d Wrong pid value", i)) + c.Assert(v.id.Counter(), Equals, v.counter, Commentf("#%d Wrong counter value", i)) + } +} + +func (s *S) TestNow(c *C) { + before := time.Now() + time.Sleep(1e6) + now := bson.Now() + time.Sleep(1e6) + after := time.Now() + c.Assert(now.After(before) && now.Before(after), Equals, true, Commentf("now=%s, before=%s, after=%s", now, before, after)) +} + +// -------------------------------------------------------------------------- +// ObjectId generation tests. 
+ +func (s *S) TestNewObjectId(c *C) { + // Generate 10 ids + ids := make([]bson.ObjectId, 10) + for i := 0; i < 10; i++ { + ids[i] = bson.NewObjectId() + } + for i := 1; i < 10; i++ { + prevId := ids[i-1] + id := ids[i] + // Test for uniqueness among all other 9 generated ids + for j, tid := range ids { + if j != i { + c.Assert(id, Not(Equals), tid, Commentf("Generated ObjectId is not unique")) + } + } + // Check that timestamp was incremented and is within 30 seconds of the previous one + secs := id.Time().Sub(prevId.Time()).Seconds() + c.Assert((secs >= 0 && secs <= 30), Equals, true, Commentf("Wrong timestamp in generated ObjectId")) + // Check that machine ids are the same + c.Assert(id.Machine(), DeepEquals, prevId.Machine()) + // Check that pids are the same + c.Assert(id.Pid(), Equals, prevId.Pid()) + // Test for proper increment + delta := int(id.Counter() - prevId.Counter()) + c.Assert(delta, Equals, 1, Commentf("Wrong increment in generated ObjectId")) + } +} + +func (s *S) TestNewObjectIdWithTime(c *C) { + t := time.Unix(12345678, 0) + id := bson.NewObjectIdWithTime(t) + c.Assert(id.Time(), Equals, t) + c.Assert(id.Machine(), DeepEquals, []byte{0x00, 0x00, 0x00}) + c.Assert(int(id.Pid()), Equals, 0) + c.Assert(int(id.Counter()), Equals, 0) +} + +// -------------------------------------------------------------------------- +// ObjectId JSON marshalling. 
+ +type jsonType struct { + Id bson.ObjectId +} + +var jsonIdTests = []struct { + value jsonType + json string + marshal bool + unmarshal bool + error string +}{{ + value: jsonType{Id: bson.ObjectIdHex("4d88e15b60f486e428412dc9")}, + json: `{"Id":"4d88e15b60f486e428412dc9"}`, + marshal: true, + unmarshal: true, +}, { + value: jsonType{}, + json: `{"Id":""}`, + marshal: true, + unmarshal: true, +}, { + value: jsonType{}, + json: `{"Id":null}`, + marshal: false, + unmarshal: true, +}, { + json: `{"Id":"4d88e15b60f486e428412dc9A"}`, + error: `Invalid ObjectId in JSON: "4d88e15b60f486e428412dc9A"`, + marshal: false, + unmarshal: true, +}, { + json: `{"Id":"4d88e15b60f486e428412dcZ"}`, + error: `Invalid ObjectId in JSON: "4d88e15b60f486e428412dcZ" .*`, + marshal: false, + unmarshal: true, +}} + +func (s *S) TestObjectIdJSONMarshaling(c *C) { + for _, test := range jsonIdTests { + if test.marshal { + data, err := json.Marshal(&test.value) + if test.error == "" { + c.Assert(err, IsNil) + c.Assert(string(data), Equals, test.json) + } else { + c.Assert(err, ErrorMatches, test.error) + } + } + + if test.unmarshal { + var value jsonType + err := json.Unmarshal([]byte(test.json), &value) + if test.error == "" { + c.Assert(err, IsNil) + c.Assert(value, DeepEquals, test.value) + } else { + c.Assert(err, ErrorMatches, test.error) + } + } + } +} + +// -------------------------------------------------------------------------- +// Some simple benchmarks. 
+ +type BenchT struct { + A, B, C, D, E, F string +} + +type BenchRawT struct { + A string + B int + C bson.M + D []float64 +} + +func (s *S) BenchmarkUnmarhsalStruct(c *C) { + v := BenchT{A: "A", D: "D", E: "E"} + data, err := bson.Marshal(&v) + if err != nil { + panic(err) + } + c.ResetTimer() + for i := 0; i < c.N; i++ { + err = bson.Unmarshal(data, &v) + } + if err != nil { + panic(err) + } +} + +func (s *S) BenchmarkUnmarhsalMap(c *C) { + m := bson.M{"a": "a", "d": "d", "e": "e"} + data, err := bson.Marshal(&m) + if err != nil { + panic(err) + } + c.ResetTimer() + for i := 0; i < c.N; i++ { + err = bson.Unmarshal(data, &m) + } + if err != nil { + panic(err) + } +} + +func (s *S) BenchmarkUnmarshalRaw(c *C) { + var err error + m := BenchRawT{ + A: "test_string", + B: 123, + C: bson.M{ + "subdoc_int": 12312, + "subdoc_doc": bson.M{"1": 1}, + }, + D: []float64{0.0, 1.3333, -99.9997, 3.1415}, + } + data, err := bson.Marshal(&m) + if err != nil { + panic(err) + } + raw := bson.Raw{} + c.ResetTimer() + for i := 0; i < c.N; i++ { + err = bson.Unmarshal(data, &raw) + } + if err != nil { + panic(err) + } +} diff --git a/vendor/gopkg.in/mgo.v2/bson/decode.go b/vendor/gopkg.in/mgo.v2/bson/decode.go new file mode 100644 index 000000000..bdd2e0287 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/decode.go @@ -0,0 +1,825 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "sync" + "time" +) + +type decoder struct { + in []byte + i int + docType reflect.Type +} + +var typeM = reflect.TypeOf(M{}) + +func newDecoder(in []byte) *decoder { + return &decoder{in, 0, typeM} +} + +// -------------------------------------------------------------------------- +// Some helper functions. + +func corrupted() { + panic("Document is corrupted") +} + +func settableValueOf(i interface{}) reflect.Value { + v := reflect.ValueOf(i) + sv := reflect.New(v.Type()).Elem() + sv.Set(v) + return sv +} + +// -------------------------------------------------------------------------- +// Unmarshaling of documents. 
+ +const ( + setterUnknown = iota + setterNone + setterType + setterAddr +) + +var setterStyles map[reflect.Type]int +var setterIface reflect.Type +var setterMutex sync.RWMutex + +func init() { + var iface Setter + setterIface = reflect.TypeOf(&iface).Elem() + setterStyles = make(map[reflect.Type]int) +} + +func setterStyle(outt reflect.Type) int { + setterMutex.RLock() + style := setterStyles[outt] + setterMutex.RUnlock() + if style == setterUnknown { + setterMutex.Lock() + defer setterMutex.Unlock() + if outt.Implements(setterIface) { + setterStyles[outt] = setterType + } else if reflect.PtrTo(outt).Implements(setterIface) { + setterStyles[outt] = setterAddr + } else { + setterStyles[outt] = setterNone + } + style = setterStyles[outt] + } + return style +} + +func getSetter(outt reflect.Type, out reflect.Value) Setter { + style := setterStyle(outt) + if style == setterNone { + return nil + } + if style == setterAddr { + if !out.CanAddr() { + return nil + } + out = out.Addr() + } else if outt.Kind() == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + return out.Interface().(Setter) +} + +func clearMap(m reflect.Value) { + var none reflect.Value + for _, k := range m.MapKeys() { + m.SetMapIndex(k, none) + } +} + +func (d *decoder) readDocTo(out reflect.Value) { + var elemType reflect.Type + outt := out.Type() + outk := outt.Kind() + + for { + if outk == reflect.Ptr && out.IsNil() { + out.Set(reflect.New(outt.Elem())) + } + if setter := getSetter(outt, out); setter != nil { + var raw Raw + d.readDocTo(reflect.ValueOf(&raw)) + err := setter.SetBSON(raw) + if _, ok := err.(*TypeError); err != nil && !ok { + panic(err) + } + return + } + if outk == reflect.Ptr { + out = out.Elem() + outt = out.Type() + outk = out.Kind() + continue + } + break + } + + var fieldsMap map[string]fieldInfo + var inlineMap reflect.Value + start := d.i + + origout := out + if outk == reflect.Interface { + if d.docType.Kind() == reflect.Map { + mv := 
reflect.MakeMap(d.docType) + out.Set(mv) + out = mv + } else { + dv := reflect.New(d.docType).Elem() + out.Set(dv) + out = dv + } + outt = out.Type() + outk = outt.Kind() + } + + docType := d.docType + keyType := typeString + convertKey := false + switch outk { + case reflect.Map: + keyType = outt.Key() + if keyType.Kind() != reflect.String { + panic("BSON map must have string keys. Got: " + outt.String()) + } + if keyType != typeString { + convertKey = true + } + elemType = outt.Elem() + if elemType == typeIface { + d.docType = outt + } + if out.IsNil() { + out.Set(reflect.MakeMap(out.Type())) + } else if out.Len() > 0 { + clearMap(out) + } + case reflect.Struct: + if outt != typeRaw { + sinfo, err := getStructInfo(out.Type()) + if err != nil { + panic(err) + } + fieldsMap = sinfo.FieldsMap + out.Set(sinfo.Zero) + if sinfo.InlineMap != -1 { + inlineMap = out.Field(sinfo.InlineMap) + if !inlineMap.IsNil() && inlineMap.Len() > 0 { + clearMap(inlineMap) + } + elemType = inlineMap.Type().Elem() + if elemType == typeIface { + d.docType = inlineMap.Type() + } + } + } + case reflect.Slice: + switch outt.Elem() { + case typeDocElem: + origout.Set(d.readDocElems(outt)) + return + case typeRawDocElem: + origout.Set(d.readRawDocElems(outt)) + return + } + fallthrough + default: + panic("Unsupported document type for unmarshalling: " + out.Type().String()) + } + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + + switch outk { + case reflect.Map: + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + k := reflect.ValueOf(name) + if convertKey { + k = k.Convert(keyType) + } + out.SetMapIndex(k, e) + } + case reflect.Struct: + if outt == typeRaw { + d.dropElem(kind) + } else { + if info, ok := fieldsMap[name]; ok { + if info.Inline == nil { + d.readElemTo(out.Field(info.Num), kind) 
+ } else { + d.readElemTo(out.FieldByIndex(info.Inline), kind) + } + } else if inlineMap.IsValid() { + if inlineMap.IsNil() { + inlineMap.Set(reflect.MakeMap(inlineMap.Type())) + } + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + inlineMap.SetMapIndex(reflect.ValueOf(name), e) + } + } else { + d.dropElem(kind) + } + } + case reflect.Slice: + } + + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + d.docType = docType + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{0x03, d.in[start:d.i]})) + } +} + +func (d *decoder) readArrayDocTo(out reflect.Value) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + i := 0 + l := out.Len() + for d.in[d.i] != '\x00' { + if i >= l { + panic("Length mismatch on array field") + } + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + d.readElemTo(out.Index(i), kind) + if d.i >= end { + corrupted() + } + i++ + } + if i != l { + panic("Length mismatch on array field") + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +func (d *decoder) readSliceDoc(t reflect.Type) interface{} { + tmp := make([]reflect.Value, 0, 8) + elemType := t.Elem() + + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + for d.i < end && d.in[d.i] != '\x00' { + d.i++ + } + if d.i >= end { + corrupted() + } + d.i++ + e := reflect.New(elemType).Elem() + if d.readElemTo(e, kind) { + tmp = append(tmp, e) + } + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } + + n := len(tmp) + slice := reflect.MakeSlice(t, n, n) + for i := 0; i != n; i++ { + slice.Index(i).Set(tmp[i]) + } + return slice.Interface() +} + +var typeSlice = reflect.TypeOf([]interface{}{}) +var typeIface = typeSlice.Elem() 
+ +func (d *decoder) readDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]DocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := DocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readRawDocElems(typ reflect.Type) reflect.Value { + docType := d.docType + d.docType = typ + slice := make([]RawDocElem, 0, 8) + d.readDocWith(func(kind byte, name string) { + e := RawDocElem{Name: name} + v := reflect.ValueOf(&e.Value) + if d.readElemTo(v.Elem(), kind) { + slice = append(slice, e) + } + }) + slicev := reflect.New(typ).Elem() + slicev.Set(reflect.ValueOf(slice)) + d.docType = docType + return slicev +} + +func (d *decoder) readDocWith(f func(kind byte, name string)) { + end := int(d.readInt32()) + end += d.i - 4 + if end <= d.i || end > len(d.in) || d.in[end-1] != '\x00' { + corrupted() + } + for d.in[d.i] != '\x00' { + kind := d.readByte() + name := d.readCStr() + if d.i >= end { + corrupted() + } + f(kind, name) + if d.i >= end { + corrupted() + } + } + d.i++ // '\x00' + if d.i != end { + corrupted() + } +} + +// -------------------------------------------------------------------------- +// Unmarshaling of individual elements within a document. + +var blackHole = settableValueOf(struct{}{}) + +func (d *decoder) dropElem(kind byte) { + d.readElemTo(blackHole, kind) +} + +// Attempt to decode an element from the document and put it into out. +// If the types are not compatible, the returned ok value will be +// false and out will be unchanged. +func (d *decoder) readElemTo(out reflect.Value, kind byte) (good bool) { + + start := d.i + + if kind == '\x03' { + // Delegate unmarshaling of documents. 
+ outt := out.Type() + outk := out.Kind() + switch outk { + case reflect.Interface, reflect.Ptr, reflect.Struct, reflect.Map: + d.readDocTo(out) + return true + } + if setterStyle(outt) != setterNone { + d.readDocTo(out) + return true + } + if outk == reflect.Slice { + switch outt.Elem() { + case typeDocElem: + out.Set(d.readDocElems(outt)) + case typeRawDocElem: + out.Set(d.readRawDocElems(outt)) + } + return true + } + d.readDocTo(blackHole) + return true + } + + var in interface{} + + switch kind { + case 0x01: // Float64 + in = d.readFloat64() + case 0x02: // UTF-8 string + in = d.readStr() + case 0x03: // Document + panic("Can't happen. Handled above.") + case 0x04: // Array + outt := out.Type() + if setterStyle(outt) != setterNone { + // Skip the value so its data is handed to the setter below. + d.dropElem(kind) + break + } + for outt.Kind() == reflect.Ptr { + outt = outt.Elem() + } + switch outt.Kind() { + case reflect.Array: + d.readArrayDocTo(out) + return true + case reflect.Slice: + in = d.readSliceDoc(outt) + default: + in = d.readSliceDoc(typeSlice) + } + case 0x05: // Binary + b := d.readBinary() + if b.Kind == 0x00 || b.Kind == 0x02 { + in = b.Data + } else { + in = b + } + case 0x06: // Undefined (obsolete, but still seen in the wild) + in = Undefined + case 0x07: // ObjectId + in = ObjectId(d.readBytes(12)) + case 0x08: // Bool + in = d.readBool() + case 0x09: // Timestamp + // MongoDB handles timestamps as milliseconds. + i := d.readInt64() + if i == -62135596800000 { + in = time.Time{} // In UTC for convenience. 
+ } else { + in = time.Unix(i/1e3, i%1e3*1e6) + } + case 0x0A: // Nil + in = nil + case 0x0B: // RegEx + in = d.readRegEx() + case 0x0C: + in = DBPointer{Namespace: d.readStr(), Id: ObjectId(d.readBytes(12))} + case 0x0D: // JavaScript without scope + in = JavaScript{Code: d.readStr()} + case 0x0E: // Symbol + in = Symbol(d.readStr()) + case 0x0F: // JavaScript with scope + d.i += 4 // Skip length + js := JavaScript{d.readStr(), make(M)} + d.readDocTo(reflect.ValueOf(js.Scope)) + in = js + case 0x10: // Int32 + in = int(d.readInt32()) + case 0x11: // Mongo-specific timestamp + in = MongoTimestamp(d.readInt64()) + case 0x12: // Int64 + in = d.readInt64() + case 0x7F: // Max key + in = MaxKey + case 0xFF: // Min key + in = MinKey + default: + panic(fmt.Sprintf("Unknown element kind (0x%02X)", kind)) + } + + outt := out.Type() + + if outt == typeRaw { + out.Set(reflect.ValueOf(Raw{kind, d.in[start:d.i]})) + return true + } + + if setter := getSetter(outt, out); setter != nil { + err := setter.SetBSON(Raw{kind, d.in[start:d.i]}) + if err == SetZero { + out.Set(reflect.Zero(outt)) + return true + } + if err == nil { + return true + } + if _, ok := err.(*TypeError); !ok { + panic(err) + } + return false + } + + if in == nil { + out.Set(reflect.Zero(outt)) + return true + } + + outk := outt.Kind() + + // Dereference and initialize pointer if necessary. + first := true + for outk == reflect.Ptr { + if !out.IsNil() { + out = out.Elem() + } else { + elem := reflect.New(outt.Elem()) + if first { + // Only set if value is compatible. 
+ first = false + defer func(out, elem reflect.Value) { + if good { + out.Set(elem) + } + }(out, elem) + } else { + out.Set(elem) + } + out = elem + } + outt = out.Type() + outk = outt.Kind() + } + + inv := reflect.ValueOf(in) + if outt == inv.Type() { + out.Set(inv) + return true + } + + switch outk { + case reflect.Interface: + out.Set(inv) + return true + case reflect.String: + switch inv.Kind() { + case reflect.String: + out.SetString(inv.String()) + return true + case reflect.Slice: + if b, ok := in.([]byte); ok { + out.SetString(string(b)) + return true + } + case reflect.Int, reflect.Int64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatInt(inv.Int(), 10)) + return true + } + case reflect.Float64: + if outt == typeJSONNumber { + out.SetString(strconv.FormatFloat(inv.Float(), 'f', -1, 64)) + return true + } + } + case reflect.Slice, reflect.Array: + // Remember, array (0x04) slices are built with the correct + // element type. If we are here, must be a cross BSON kind + // conversion (e.g. 0x05 unmarshalling on string). 
+ if outt.Elem().Kind() != reflect.Uint8 { + break + } + switch inv.Kind() { + case reflect.String: + slice := []byte(inv.String()) + out.Set(reflect.ValueOf(slice)) + return true + case reflect.Slice: + switch outt.Kind() { + case reflect.Array: + reflect.Copy(out, inv) + case reflect.Slice: + out.SetBytes(inv.Bytes()) + } + return true + } + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetInt(inv.Int()) + return true + case reflect.Float32, reflect.Float64: + out.SetInt(int64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetInt(1) + } else { + out.SetInt(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("can't happen: no uint types in BSON (!?)") + } + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + switch inv.Kind() { + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetUint(uint64(inv.Int())) + return true + case reflect.Float32, reflect.Float64: + out.SetUint(uint64(inv.Float())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetUint(1) + } else { + out.SetUint(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. 
No uint types in BSON.") + } + case reflect.Float32, reflect.Float64: + switch inv.Kind() { + case reflect.Float32, reflect.Float64: + out.SetFloat(inv.Float()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetFloat(float64(inv.Int())) + return true + case reflect.Bool: + if inv.Bool() { + out.SetFloat(1) + } else { + out.SetFloat(0) + } + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Bool: + switch inv.Kind() { + case reflect.Bool: + out.SetBool(inv.Bool()) + return true + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + out.SetBool(inv.Int() != 0) + return true + case reflect.Float32, reflect.Float64: + out.SetBool(inv.Float() != 0) + return true + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + panic("Can't happen. No uint types in BSON?") + } + case reflect.Struct: + if outt == typeURL && inv.Kind() == reflect.String { + u, err := url.Parse(inv.String()) + if err != nil { + panic(err) + } + out.Set(reflect.ValueOf(u).Elem()) + return true + } + } + + return false +} + +// -------------------------------------------------------------------------- +// Parsers of basic types. + +func (d *decoder) readRegEx() RegEx { + re := RegEx{} + re.Pattern = d.readCStr() + re.Options = d.readCStr() + return re +} + +func (d *decoder) readBinary() Binary { + l := d.readInt32() + b := Binary{} + b.Kind = d.readByte() + b.Data = d.readBytes(l) + if b.Kind == 0x02 && len(b.Data) >= 4 { + // Weird obsolete format with redundant length. 
+ b.Data = b.Data[4:] + } + return b +} + +func (d *decoder) readStr() string { + l := d.readInt32() + b := d.readBytes(l - 1) + if d.readByte() != '\x00' { + corrupted() + } + return string(b) +} + +func (d *decoder) readCStr() string { + start := d.i + end := start + l := len(d.in) + for ; end != l; end++ { + if d.in[end] == '\x00' { + break + } + } + d.i = end + 1 + if d.i > l { + corrupted() + } + return string(d.in[start:end]) +} + +func (d *decoder) readBool() bool { + if d.readByte() == 1 { + return true + } + return false +} + +func (d *decoder) readFloat64() float64 { + return math.Float64frombits(uint64(d.readInt64())) +} + +func (d *decoder) readInt32() int32 { + b := d.readBytes(4) + return int32((uint32(b[0]) << 0) | + (uint32(b[1]) << 8) | + (uint32(b[2]) << 16) | + (uint32(b[3]) << 24)) +} + +func (d *decoder) readInt64() int64 { + b := d.readBytes(8) + return int64((uint64(b[0]) << 0) | + (uint64(b[1]) << 8) | + (uint64(b[2]) << 16) | + (uint64(b[3]) << 24) | + (uint64(b[4]) << 32) | + (uint64(b[5]) << 40) | + (uint64(b[6]) << 48) | + (uint64(b[7]) << 56)) +} + +func (d *decoder) readByte() byte { + i := d.i + d.i++ + if d.i > len(d.in) { + corrupted() + } + return d.in[i] +} + +func (d *decoder) readBytes(length int32) []byte { + start := d.i + d.i += int(length) + if d.i > len(d.in) { + corrupted() + } + return d.in[start : start+int(length)] +} diff --git a/vendor/gopkg.in/mgo.v2/bson/encode.go b/vendor/gopkg.in/mgo.v2/bson/encode.go new file mode 100644 index 000000000..e1015091b --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bson/encode.go @@ -0,0 +1,503 @@ +// BSON library for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. 
+// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +// gobson - BSON library for Go. + +package bson + +import ( + "encoding/json" + "fmt" + "math" + "net/url" + "reflect" + "strconv" + "time" +) + +// -------------------------------------------------------------------------- +// Some internal infrastructure. 
+ +var ( + typeBinary = reflect.TypeOf(Binary{}) + typeObjectId = reflect.TypeOf(ObjectId("")) + typeDBPointer = reflect.TypeOf(DBPointer{"", ObjectId("")}) + typeSymbol = reflect.TypeOf(Symbol("")) + typeMongoTimestamp = reflect.TypeOf(MongoTimestamp(0)) + typeOrderKey = reflect.TypeOf(MinKey) + typeDocElem = reflect.TypeOf(DocElem{}) + typeRawDocElem = reflect.TypeOf(RawDocElem{}) + typeRaw = reflect.TypeOf(Raw{}) + typeURL = reflect.TypeOf(url.URL{}) + typeTime = reflect.TypeOf(time.Time{}) + typeString = reflect.TypeOf("") + typeJSONNumber = reflect.TypeOf(json.Number("")) +) + +const itoaCacheSize = 32 + +var itoaCache []string + +func init() { + itoaCache = make([]string, itoaCacheSize) + for i := 0; i != itoaCacheSize; i++ { + itoaCache[i] = strconv.Itoa(i) + } +} + +func itoa(i int) string { + if i < itoaCacheSize { + return itoaCache[i] + } + return strconv.Itoa(i) +} + +// -------------------------------------------------------------------------- +// Marshaling of the document value itself. + +type encoder struct { + out []byte +} + +func (e *encoder) addDoc(v reflect.Value) { + for { + if vi, ok := v.Interface().(Getter); ok { + getv, err := vi.GetBSON() + if err != nil { + panic(err) + } + v = reflect.ValueOf(getv) + continue + } + if v.Kind() == reflect.Ptr { + v = v.Elem() + continue + } + break + } + + if v.Type() == typeRaw { + raw := v.Interface().(Raw) + if raw.Kind != 0x03 && raw.Kind != 0x00 { + panic("Attempted to unmarshal Raw kind " + strconv.Itoa(int(raw.Kind)) + " as a document") + } + e.addBytes(raw.Data...) 
+ return + } + + start := e.reserveInt32() + + switch v.Kind() { + case reflect.Map: + e.addMap(v) + case reflect.Struct: + e.addStruct(v) + case reflect.Array, reflect.Slice: + e.addSlice(v) + default: + panic("Can't marshal " + v.Type().String() + " as a BSON document") + } + + e.addBytes(0) + e.setInt32(start, int32(len(e.out)-start)) +} + +func (e *encoder) addMap(v reflect.Value) { + for _, k := range v.MapKeys() { + e.addElem(k.String(), v.MapIndex(k), false) + } +} + +func (e *encoder) addStruct(v reflect.Value) { + sinfo, err := getStructInfo(v.Type()) + if err != nil { + panic(err) + } + var value reflect.Value + if sinfo.InlineMap >= 0 { + m := v.Field(sinfo.InlineMap) + if m.Len() > 0 { + for _, k := range m.MapKeys() { + ks := k.String() + if _, found := sinfo.FieldsMap[ks]; found { + panic(fmt.Sprintf("Can't have key %q in inlined map; conflicts with struct field", ks)) + } + e.addElem(ks, m.MapIndex(k), false) + } + } + } + for _, info := range sinfo.FieldsList { + if info.Inline == nil { + value = v.Field(info.Num) + } else { + value = v.FieldByIndex(info.Inline) + } + if info.OmitEmpty && isZero(value) { + continue + } + e.addElem(info.Key, value, info.MinSize) + } +} + +func isZero(v reflect.Value) bool { + switch v.Kind() { + case reflect.String: + return len(v.String()) == 0 + case reflect.Ptr, reflect.Interface: + return v.IsNil() + case reflect.Slice: + return v.Len() == 0 + case reflect.Map: + return v.Len() == 0 + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + return v.Int() == 0 + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + return v.Uint() == 0 + case reflect.Float32, reflect.Float64: + return v.Float() == 0 + case reflect.Bool: + return !v.Bool() + case reflect.Struct: + vt := v.Type() + if vt == typeTime { + return v.Interface().(time.Time).IsZero() + } + for i := 0; i < v.NumField(); i++ { + if vt.Field(i).PkgPath != "" { + continue // Private field + 
} + if !isZero(v.Field(i)) { + return false + } + } + return true + } + return false +} + +func (e *encoder) addSlice(v reflect.Value) { + vi := v.Interface() + if d, ok := vi.(D); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if d, ok := vi.(RawD); ok { + for _, elem := range d { + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + l := v.Len() + et := v.Type().Elem() + if et == typeDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(DocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + if et == typeRawDocElem { + for i := 0; i < l; i++ { + elem := v.Index(i).Interface().(RawDocElem) + e.addElem(elem.Name, reflect.ValueOf(elem.Value), false) + } + return + } + for i := 0; i < l; i++ { + e.addElem(itoa(i), v.Index(i), false) + } +} + +// -------------------------------------------------------------------------- +// Marshaling of elements in a document. + +func (e *encoder) addElemName(kind byte, name string) { + e.addBytes(kind) + e.addBytes([]byte(name)...) + e.addBytes(0) +} + +func (e *encoder) addElem(name string, v reflect.Value, minSize bool) { + + if !v.IsValid() { + e.addElemName('\x0A', name) + return + } + + if getter, ok := v.Interface().(Getter); ok { + getv, err := getter.GetBSON() + if err != nil { + panic(err) + } + e.addElem(name, reflect.ValueOf(getv), minSize) + return + } + + switch v.Kind() { + + case reflect.Interface: + e.addElem(name, v.Elem(), minSize) + + case reflect.Ptr: + e.addElem(name, v.Elem(), minSize) + + case reflect.String: + s := v.String() + switch v.Type() { + case typeObjectId: + if len(s) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s)) + ")") + } + e.addElemName('\x07', name) + e.addBytes([]byte(s)...) 
+ case typeSymbol: + e.addElemName('\x0E', name) + e.addStr(s) + case typeJSONNumber: + n := v.Interface().(json.Number) + if i, err := n.Int64(); err == nil { + e.addElemName('\x12', name) + e.addInt64(i) + } else if f, err := n.Float64(); err == nil { + e.addElemName('\x01', name) + e.addFloat64(f) + } else { + panic("failed to convert json.Number to a number: " + s) + } + default: + e.addElemName('\x02', name) + e.addStr(s) + } + + case reflect.Float32, reflect.Float64: + e.addElemName('\x01', name) + e.addFloat64(v.Float()) + + case reflect.Uint, reflect.Uint8, reflect.Uint16, reflect.Uint32, reflect.Uint64, reflect.Uintptr: + u := v.Uint() + if int64(u) < 0 { + panic("BSON has no uint64 type, and value is too large to fit correctly in an int64") + } else if u <= math.MaxInt32 && (minSize || v.Kind() <= reflect.Uint32) { + e.addElemName('\x10', name) + e.addInt32(int32(u)) + } else { + e.addElemName('\x12', name) + e.addInt64(int64(u)) + } + + case reflect.Int, reflect.Int8, reflect.Int16, reflect.Int32, reflect.Int64: + switch v.Type() { + case typeMongoTimestamp: + e.addElemName('\x11', name) + e.addInt64(v.Int()) + + case typeOrderKey: + if v.Int() == int64(MaxKey) { + e.addElemName('\x7F', name) + } else { + e.addElemName('\xFF', name) + } + + default: + i := v.Int() + if (minSize || v.Type().Kind() != reflect.Int64) && i >= math.MinInt32 && i <= math.MaxInt32 { + // It fits into an int32, encode as such. 
+ e.addElemName('\x10', name) + e.addInt32(int32(i)) + } else { + e.addElemName('\x12', name) + e.addInt64(i) + } + } + + case reflect.Bool: + e.addElemName('\x08', name) + if v.Bool() { + e.addBytes(1) + } else { + e.addBytes(0) + } + + case reflect.Map: + e.addElemName('\x03', name) + e.addDoc(v) + + case reflect.Slice: + vt := v.Type() + et := vt.Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + e.addBinary('\x00', v.Bytes()) + } else if et == typeDocElem || et == typeRawDocElem { + e.addElemName('\x03', name) + e.addDoc(v) + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Array: + et := v.Type().Elem() + if et.Kind() == reflect.Uint8 { + e.addElemName('\x05', name) + if v.CanAddr() { + e.addBinary('\x00', v.Slice(0, v.Len()).Interface().([]byte)) + } else { + n := v.Len() + e.addInt32(int32(n)) + e.addBytes('\x00') + for i := 0; i < n; i++ { + el := v.Index(i) + e.addBytes(byte(el.Uint())) + } + } + } else { + e.addElemName('\x04', name) + e.addDoc(v) + } + + case reflect.Struct: + switch s := v.Interface().(type) { + + case Raw: + kind := s.Kind + if kind == 0x00 { + kind = 0x03 + } + e.addElemName(kind, name) + e.addBytes(s.Data...) + + case Binary: + e.addElemName('\x05', name) + e.addBinary(s.Kind, s.Data) + + case DBPointer: + e.addElemName('\x0C', name) + e.addStr(s.Namespace) + if len(s.Id) != 12 { + panic("ObjectIDs must be exactly 12 bytes long (got " + + strconv.Itoa(len(s.Id)) + ")") + } + e.addBytes([]byte(s.Id)...) + + case RegEx: + e.addElemName('\x0B', name) + e.addCStr(s.Pattern) + e.addCStr(s.Options) + + case JavaScript: + if s.Scope == nil { + e.addElemName('\x0D', name) + e.addStr(s.Code) + } else { + e.addElemName('\x0F', name) + start := e.reserveInt32() + e.addStr(s.Code) + e.addDoc(reflect.ValueOf(s.Scope)) + e.setInt32(start, int32(len(e.out)-start)) + } + + case time.Time: + // MongoDB handles timestamps as milliseconds. 
+ e.addElemName('\x09', name) + e.addInt64(s.Unix()*1000 + int64(s.Nanosecond()/1e6)) + + case url.URL: + e.addElemName('\x02', name) + e.addStr(s.String()) + + case undefined: + e.addElemName('\x06', name) + + default: + e.addElemName('\x03', name) + e.addDoc(v) + } + + default: + panic("Can't marshal " + v.Type().String() + " in a BSON document") + } +} + +// -------------------------------------------------------------------------- +// Marshaling of base types. + +func (e *encoder) addBinary(subtype byte, v []byte) { + if subtype == 0x02 { + // Wonder how that brilliant idea came to life. Obsolete, luckily. + e.addInt32(int32(len(v) + 4)) + e.addBytes(subtype) + e.addInt32(int32(len(v))) + } else { + e.addInt32(int32(len(v))) + e.addBytes(subtype) + } + e.addBytes(v...) +} + +func (e *encoder) addStr(v string) { + e.addInt32(int32(len(v) + 1)) + e.addCStr(v) +} + +func (e *encoder) addCStr(v string) { + e.addBytes([]byte(v)...) + e.addBytes(0) +} + +func (e *encoder) reserveInt32() (pos int) { + pos = len(e.out) + e.addBytes(0, 0, 0, 0) + return pos +} + +func (e *encoder) setInt32(pos int, v int32) { + e.out[pos+0] = byte(v) + e.out[pos+1] = byte(v >> 8) + e.out[pos+2] = byte(v >> 16) + e.out[pos+3] = byte(v >> 24) +} + +func (e *encoder) addInt32(v int32) { + u := uint32(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24)) +} + +func (e *encoder) addInt64(v int64) { + u := uint64(v) + e.addBytes(byte(u), byte(u>>8), byte(u>>16), byte(u>>24), + byte(u>>32), byte(u>>40), byte(u>>48), byte(u>>56)) +} + +func (e *encoder) addFloat64(v float64) { + e.addInt64(int64(math.Float64bits(v))) +} + +func (e *encoder) addBytes(v ...byte) { + e.out = append(e.out, v...) 
+} diff --git a/vendor/gopkg.in/mgo.v2/bulk.go b/vendor/gopkg.in/mgo.v2/bulk.go new file mode 100644 index 000000000..23f450853 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bulk.go @@ -0,0 +1,71 @@ +package mgo + +// Bulk represents an operation that can be prepared with several +// orthogonal changes before being delivered to the server. +// +// WARNING: This API is still experimental. +// +// Relevant documentation: +// +// http://blog.mongodb.org/post/84922794768/mongodbs-new-bulk-api +// +type Bulk struct { + c *Collection + ordered bool + inserts []interface{} +} + +// BulkError holds an error returned from running a Bulk operation. +// +// TODO: This is private for the moment, until we understand exactly how +// to report these multi-errors in a useful and convenient way. +type bulkError struct { + err error +} + +// BulkResult holds the results for a bulk operation. +type BulkResult struct { + // Be conservative while we understand exactly how to report these + // results in a useful and convenient way, and also how to emulate + // them with prior servers. + private bool +} + +func (e *bulkError) Error() string { + return e.err.Error() +} + +// Bulk returns a value to prepare the execution of a bulk operation. +// +// WARNING: This API is still experimental. +// +func (c *Collection) Bulk() *Bulk { + return &Bulk{c: c, ordered: true} +} + +// Unordered puts the bulk operation in unordered mode. +// +// In unordered mode the indvidual operations may be sent +// out of order, which means latter operations may proceed +// even if prior ones have failed. +func (b *Bulk) Unordered() { + b.ordered = false +} + +// Insert queues up the provided documents for insertion. +func (b *Bulk) Insert(docs ...interface{}) { + b.inserts = append(b.inserts, docs...) +} + +// Run runs all the operations queued up. 
+func (b *Bulk) Run() (*BulkResult, error) { + op := &insertOp{b.c.FullName, b.inserts, 0} + if !b.ordered { + op.flags = 1 // ContinueOnError + } + _, err := b.c.writeQuery(op) + if err != nil { + return nil, &bulkError{err} + } + return &BulkResult{}, nil +} diff --git a/vendor/gopkg.in/mgo.v2/bulk_test.go b/vendor/gopkg.in/mgo.v2/bulk_test.go new file mode 100644 index 000000000..d231d59d0 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/bulk_test.go @@ -0,0 +1,131 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2014 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + . 
"gopkg.in/check.v1" + "gopkg.in/mgo.v2" +) + +func (s *S) TestBulkInsert(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + bulk := coll.Bulk() + bulk.Insert(M{"n": 1}) + bulk.Insert(M{"n": 2}, M{"n": 3}) + r, err := bulk.Run() + c.Assert(err, IsNil) + c.Assert(r, FitsTypeOf, &mgo.BulkResult{}) + + type doc struct{ N int } + var res []doc + err = coll.Find(nil).Sort("n").All(&res) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []doc{{1}, {2}, {3}}) +} + +func (s *S) TestBulkInsertError(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + bulk := coll.Bulk() + bulk.Insert(M{"_id": 1}, M{"_id": 2}, M{"_id": 2}, M{"_id": 3}) + _, err = bulk.Run() + c.Assert(err, ErrorMatches, ".*duplicate key.*") + + type doc struct { + N int `_id` + } + var res []doc + err = coll.Find(nil).Sort("_id").All(&res) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []doc{{1}, {2}}) +} + +func (s *S) TestBulkInsertErrorUnordered(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + bulk := coll.Bulk() + bulk.Unordered() + bulk.Insert(M{"_id": 1}, M{"_id": 2}, M{"_id": 2}, M{"_id": 3}) + _, err = bulk.Run() + c.Assert(err, ErrorMatches, ".*duplicate key.*") + + type doc struct { + N int `_id` + } + var res []doc + err = coll.Find(nil).Sort("_id").All(&res) + c.Assert(err, IsNil) + c.Assert(res, DeepEquals, []doc{{1}, {2}, {3}}) +} + +func (s *S) TestBulkInsertErrorUnorderedSplitBatch(c *C) { + // The server has a batch limit of 1000 documents when using write commands. + // This artificial limit did not exist with the old wire protocol, so to + // avoid compatibility issues the implementation internally split batches + // into the proper size and delivers them one by one. 
This test ensures that + // the behavior of unordered (that is, continue on error) remains correct + // when errors happen and there are batches left. + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + bulk := coll.Bulk() + bulk.Unordered() + + const total = 4096 + type doc struct { + Id int `_id` + } + docs := make([]interface{}, total) + for i := 0; i < total; i++ { + docs[i] = doc{i} + } + docs[1] = doc{0} + bulk.Insert(docs...) + _, err = bulk.Run() + c.Assert(err, ErrorMatches, ".*duplicate key.*") + + n, err := coll.Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, total-1) + + var res doc + err = coll.FindId(1500).One(&res) + c.Assert(err, IsNil) + c.Assert(res.Id, Equals, 1500) +} diff --git a/vendor/gopkg.in/mgo.v2/cluster.go b/vendor/gopkg.in/mgo.v2/cluster.go new file mode 100644 index 000000000..9ea0cb9c1 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/cluster.go @@ -0,0 +1,632 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "fmt" + "net" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +// --------------------------------------------------------------------------- +// Mongo cluster encapsulation. +// +// A cluster enables the communication with one or more servers participating +// in a mongo cluster. This works with individual servers, a replica set, +// a replica pair, one or multiple mongos routers, etc. + +type mongoCluster struct { + sync.RWMutex + serverSynced sync.Cond + userSeeds []string + dynaSeeds []string + servers mongoServers + masters mongoServers + references int + syncing bool + direct bool + failFast bool + syncCount uint + setName string + cachedIndex map[string]bool + sync chan bool + dial dialer +} + +func newCluster(userSeeds []string, direct, failFast bool, dial dialer, setName string) *mongoCluster { + cluster := &mongoCluster{ + userSeeds: userSeeds, + references: 1, + direct: direct, + failFast: failFast, + dial: dial, + setName: setName, + } + cluster.serverSynced.L = cluster.RWMutex.RLocker() + cluster.sync = make(chan bool, 1) + stats.cluster(+1) + go cluster.syncServersLoop() + return cluster +} + +// Acquire increases the reference count for the cluster. +func (cluster *mongoCluster) Acquire() { + cluster.Lock() + cluster.references++ + debugf("Cluster %p acquired (refs=%d)", cluster, cluster.references) + cluster.Unlock() +} + +// Release decreases the reference count for the cluster. 
Once +// it reaches zero, all servers will be closed. +func (cluster *mongoCluster) Release() { + cluster.Lock() + if cluster.references == 0 { + panic("cluster.Release() with references == 0") + } + cluster.references-- + debugf("Cluster %p released (refs=%d)", cluster, cluster.references) + if cluster.references == 0 { + for _, server := range cluster.servers.Slice() { + server.Close() + } + // Wake up the sync loop so it can die. + cluster.syncServers() + stats.cluster(-1) + } + cluster.Unlock() +} + +func (cluster *mongoCluster) LiveServers() (servers []string) { + cluster.RLock() + for _, serv := range cluster.servers.Slice() { + servers = append(servers, serv.Addr) + } + cluster.RUnlock() + return servers +} + +func (cluster *mongoCluster) removeServer(server *mongoServer) { + cluster.Lock() + cluster.masters.Remove(server) + other := cluster.servers.Remove(server) + cluster.Unlock() + if other != nil { + other.Close() + log("Removed server ", server.Addr, " from cluster.") + } + server.Close() +} + +type isMasterResult struct { + IsMaster bool + Secondary bool + Primary string + Hosts []string + Passives []string + Tags bson.D + Msg string + SetName string `bson:"setName"` + MaxWireVersion int `bson:"maxWireVersion"` +} + +func (cluster *mongoCluster) isMaster(socket *mongoSocket, result *isMasterResult) error { + // Monotonic let's it talk to a slave and still hold the socket. + session := newSession(Monotonic, cluster, 10*time.Second) + session.setSocket(socket) + err := session.Run("ismaster", result) + session.Close() + return err +} + +type possibleTimeout interface { + Timeout() bool +} + +var syncSocketTimeout = 5 * time.Second + +func (cluster *mongoCluster) syncServer(server *mongoServer) (info *mongoServerInfo, hosts []string, err error) { + var syncTimeout time.Duration + if raceDetector { + // This variable is only ever touched by tests. 
+ globalMutex.Lock() + syncTimeout = syncSocketTimeout + globalMutex.Unlock() + } else { + syncTimeout = syncSocketTimeout + } + + addr := server.Addr + log("SYNC Processing ", addr, "...") + + // Retry a few times to avoid knocking a server down for a hiccup. + var result isMasterResult + var tryerr error + for retry := 0; ; retry++ { + if retry == 3 || retry == 1 && cluster.failFast { + return nil, nil, tryerr + } + if retry > 0 { + // Don't abuse the server needlessly if there's something actually wrong. + if err, ok := tryerr.(possibleTimeout); ok && err.Timeout() { + // Give a chance for waiters to timeout as well. + cluster.serverSynced.Broadcast() + } + time.Sleep(syncShortDelay) + } + + // It's not clear what would be a good timeout here. Is it + // better to wait longer or to retry? + socket, _, err := server.AcquireSocket(0, syncTimeout) + if err != nil { + tryerr = err + logf("SYNC Failed to get socket to %s: %v", addr, err) + continue + } + err = cluster.isMaster(socket, &result) + socket.Release() + if err != nil { + tryerr = err + logf("SYNC Command 'ismaster' to %s failed: %v", addr, err) + continue + } + debugf("SYNC Result of 'ismaster' from %s: %#v", addr, result) + break + } + + if cluster.setName != "" && result.SetName != cluster.setName { + logf("SYNC Server %s is not a member of replica set %q", addr, cluster.setName) + return nil, nil, fmt.Errorf("server %s is not a member of replica set %q", addr, cluster.setName) + } + + if result.IsMaster { + debugf("SYNC %s is a master.", addr) + if !server.info.Master { + // Made an incorrect assumption above, so fix stats. + stats.conn(-1, false) + stats.conn(+1, true) + } + } else if result.Secondary { + debugf("SYNC %s is a slave.", addr) + } else if cluster.direct { + logf("SYNC %s in unknown state. Pretending it's a slave due to direct connection.", addr) + } else { + logf("SYNC %s is neither a master nor a slave.", addr) + // Let stats track it as whatever was known before. 
+ return nil, nil, errors.New(addr + " is not a master nor slave") + } + + info = &mongoServerInfo{ + Master: result.IsMaster, + Mongos: result.Msg == "isdbgrid", + Tags: result.Tags, + SetName: result.SetName, + MaxWireVersion: result.MaxWireVersion, + } + + hosts = make([]string, 0, 1+len(result.Hosts)+len(result.Passives)) + if result.Primary != "" { + // First in the list to speed up master discovery. + hosts = append(hosts, result.Primary) + } + hosts = append(hosts, result.Hosts...) + hosts = append(hosts, result.Passives...) + + debugf("SYNC %s knows about the following peers: %#v", addr, hosts) + return info, hosts, nil +} + +type syncKind bool + +const ( + completeSync syncKind = true + partialSync syncKind = false +) + +func (cluster *mongoCluster) addServer(server *mongoServer, info *mongoServerInfo, syncKind syncKind) { + cluster.Lock() + current := cluster.servers.Search(server.ResolvedAddr) + if current == nil { + if syncKind == partialSync { + cluster.Unlock() + server.Close() + log("SYNC Discarding unknown server ", server.Addr, " due to partial sync.") + return + } + cluster.servers.Add(server) + if info.Master { + cluster.masters.Add(server) + log("SYNC Adding ", server.Addr, " to cluster as a master.") + } else { + log("SYNC Adding ", server.Addr, " to cluster as a slave.") + } + } else { + if server != current { + panic("addServer attempting to add duplicated server") + } + if server.Info().Master != info.Master { + if info.Master { + log("SYNC Server ", server.Addr, " is now a master.") + cluster.masters.Add(server) + } else { + log("SYNC Server ", server.Addr, " is now a slave.") + cluster.masters.Remove(server) + } + } + } + server.SetInfo(info) + debugf("SYNC Broadcasting availability of server %s", server.Addr) + cluster.serverSynced.Broadcast() + cluster.Unlock() +} + +func (cluster *mongoCluster) getKnownAddrs() []string { + cluster.RLock() + max := len(cluster.userSeeds) + len(cluster.dynaSeeds) + cluster.servers.Len() + seen := 
make(map[string]bool, max) + known := make([]string, 0, max) + + add := func(addr string) { + if _, found := seen[addr]; !found { + seen[addr] = true + known = append(known, addr) + } + } + + for _, addr := range cluster.userSeeds { + add(addr) + } + for _, addr := range cluster.dynaSeeds { + add(addr) + } + for _, serv := range cluster.servers.Slice() { + add(serv.Addr) + } + cluster.RUnlock() + + return known +} + +// syncServers injects a value into the cluster.sync channel to force +// an iteration of the syncServersLoop function. +func (cluster *mongoCluster) syncServers() { + select { + case cluster.sync <- true: + default: + } +} + +// How long to wait for a checkup of the cluster topology if nothing +// else kicks a synchronization before that. +const syncServersDelay = 30 * time.Second +const syncShortDelay = 500 * time.Millisecond + +// syncServersLoop loops while the cluster is alive to keep its idea of +// the server topology up-to-date. It must be called just once from +// newCluster. The loop iterates once syncServersDelay has passed, or +// if somebody injects a value into the cluster.sync channel to force a +// synchronization. A loop iteration will contact all servers in +// parallel, ask them about known peers and their own role within the +// cluster, and then attempt to do the same with all the peers +// retrieved. +func (cluster *mongoCluster) syncServersLoop() { + for { + debugf("SYNC Cluster %p is starting a sync loop iteration.", cluster) + + cluster.Lock() + if cluster.references == 0 { + cluster.Unlock() + break + } + cluster.references++ // Keep alive while syncing. + direct := cluster.direct + cluster.Unlock() + + cluster.syncServersIteration(direct) + + // We just synchronized, so consume any outstanding requests. + select { + case <-cluster.sync: + default: + } + + cluster.Release() + + // Hold off before allowing another sync. No point in + // burning CPU looking for down servers. 
+ if !cluster.failFast { + time.Sleep(syncShortDelay) + } + + cluster.Lock() + if cluster.references == 0 { + cluster.Unlock() + break + } + cluster.syncCount++ + // Poke all waiters so they have a chance to timeout or + // restart syncing if they wish to. + cluster.serverSynced.Broadcast() + // Check if we have to restart immediately either way. + restart := !direct && cluster.masters.Empty() || cluster.servers.Empty() + cluster.Unlock() + + if restart { + log("SYNC No masters found. Will synchronize again.") + time.Sleep(syncShortDelay) + continue + } + + debugf("SYNC Cluster %p waiting for next requested or scheduled sync.", cluster) + + // Hold off until somebody explicitly requests a synchronization + // or it's time to check for a cluster topology change again. + select { + case <-cluster.sync: + case <-time.After(syncServersDelay): + } + } + debugf("SYNC Cluster %p is stopping its sync loop.", cluster) +} + +func (cluster *mongoCluster) server(addr string, tcpaddr *net.TCPAddr) *mongoServer { + cluster.RLock() + server := cluster.servers.Search(tcpaddr.String()) + cluster.RUnlock() + if server != nil { + return server + } + return newServer(addr, tcpaddr, cluster.sync, cluster.dial) +} + +func resolveAddr(addr string) (*net.TCPAddr, error) { + // This hack allows having a timeout on resolution. 
+ conn, err := net.DialTimeout("udp4", addr, 10*time.Second) + if err != nil { + log("SYNC Failed to resolve server address: ", addr) + return nil, errors.New("failed to resolve server address: " + addr) + } + tcpaddr := (*net.TCPAddr)(conn.RemoteAddr().(*net.UDPAddr)) + conn.Close() + if tcpaddr.String() != addr { + debug("SYNC Address ", addr, " resolved as ", tcpaddr.String()) + } + return tcpaddr, nil +} + +type pendingAdd struct { + server *mongoServer + info *mongoServerInfo +} + +func (cluster *mongoCluster) syncServersIteration(direct bool) { + log("SYNC Starting full topology synchronization...") + + var wg sync.WaitGroup + var m sync.Mutex + notYetAdded := make(map[string]pendingAdd) + addIfFound := make(map[string]bool) + seen := make(map[string]bool) + syncKind := partialSync + + var spawnSync func(addr string, byMaster bool) + spawnSync = func(addr string, byMaster bool) { + wg.Add(1) + go func() { + defer wg.Done() + + tcpaddr, err := resolveAddr(addr) + if err != nil { + log("SYNC Failed to start sync of ", addr, ": ", err.Error()) + return + } + resolvedAddr := tcpaddr.String() + + m.Lock() + if byMaster { + if pending, ok := notYetAdded[resolvedAddr]; ok { + delete(notYetAdded, resolvedAddr) + m.Unlock() + cluster.addServer(pending.server, pending.info, completeSync) + return + } + addIfFound[resolvedAddr] = true + } + if seen[resolvedAddr] { + m.Unlock() + return + } + seen[resolvedAddr] = true + m.Unlock() + + server := cluster.server(addr, tcpaddr) + info, hosts, err := cluster.syncServer(server) + if err != nil { + cluster.removeServer(server) + return + } + + m.Lock() + add := direct || info.Master || addIfFound[resolvedAddr] + if add { + syncKind = completeSync + } else { + notYetAdded[resolvedAddr] = pendingAdd{server, info} + } + m.Unlock() + if add { + cluster.addServer(server, info, completeSync) + } + if !direct { + for _, addr := range hosts { + spawnSync(addr, info.Master) + } + } + }() + } + + knownAddrs := cluster.getKnownAddrs() + 
for _, addr := range knownAddrs { + spawnSync(addr, false) + } + wg.Wait() + + if syncKind == completeSync { + logf("SYNC Synchronization was complete (got data from primary).") + for _, pending := range notYetAdded { + cluster.removeServer(pending.server) + } + } else { + logf("SYNC Synchronization was partial (cannot talk to primary).") + for _, pending := range notYetAdded { + cluster.addServer(pending.server, pending.info, partialSync) + } + } + + cluster.Lock() + ml := cluster.masters.Len() + logf("SYNC Synchronization completed: %d master(s) and %d slave(s) alive.", ml, cluster.servers.Len()-ml) + + // Update dynamic seeds, but only if we have any good servers. Otherwise, + // leave them alone for better chances of a successful sync in the future. + if syncKind == completeSync { + dynaSeeds := make([]string, cluster.servers.Len()) + for i, server := range cluster.servers.Slice() { + dynaSeeds[i] = server.Addr + } + cluster.dynaSeeds = dynaSeeds + debugf("SYNC New dynamic seeds: %#v\n", dynaSeeds) + } + cluster.Unlock() +} + +// AcquireSocket returns a socket to a server in the cluster. If slaveOk is +// true, it will attempt to return a socket to a slave server. If it is +// false, the socket will necessarily be to a master server. +func (cluster *mongoCluster) AcquireSocket(slaveOk bool, syncTimeout time.Duration, socketTimeout time.Duration, serverTags []bson.D, poolLimit int) (s *mongoSocket, err error) { + var started time.Time + var syncCount uint + warnedLimit := false + for { + cluster.RLock() + for { + ml := cluster.masters.Len() + sl := cluster.servers.Len() + debugf("Cluster has %d known masters and %d known slaves.", ml, sl-ml) + if ml > 0 || slaveOk && sl > 0 { + break + } + if started.IsZero() { + // Initialize after fast path above. 
+ started = time.Now() + syncCount = cluster.syncCount + } else if syncTimeout != 0 && started.Before(time.Now().Add(-syncTimeout)) || cluster.failFast && cluster.syncCount != syncCount { + cluster.RUnlock() + return nil, errors.New("no reachable servers") + } + log("Waiting for servers to synchronize...") + cluster.syncServers() + + // Remember: this will release and reacquire the lock. + cluster.serverSynced.Wait() + } + + var server *mongoServer + if slaveOk { + server = cluster.servers.BestFit(serverTags) + } else { + server = cluster.masters.BestFit(nil) + } + cluster.RUnlock() + + if server == nil { + // Must have failed the requested tags. Sleep to avoid spinning. + time.Sleep(1e8) + continue + } + + s, abended, err := server.AcquireSocket(poolLimit, socketTimeout) + if err == errPoolLimit { + if !warnedLimit { + warnedLimit = true + log("WARNING: Per-server connection limit reached.") + } + time.Sleep(100 * time.Millisecond) + continue + } + if err != nil { + cluster.removeServer(server) + cluster.syncServers() + continue + } + if abended && !slaveOk { + var result isMasterResult + err := cluster.isMaster(s, &result) + if err != nil || !result.IsMaster { + logf("Cannot confirm server %s as master (%v)", server.Addr, err) + s.Release() + cluster.syncServers() + time.Sleep(100 * time.Millisecond) + continue + } + } + return s, nil + } + panic("unreached") +} + +func (cluster *mongoCluster) CacheIndex(cacheKey string, exists bool) { + cluster.Lock() + if cluster.cachedIndex == nil { + cluster.cachedIndex = make(map[string]bool) + } + if exists { + cluster.cachedIndex[cacheKey] = true + } else { + delete(cluster.cachedIndex, cacheKey) + } + cluster.Unlock() +} + +func (cluster *mongoCluster) HasCachedIndex(cacheKey string) (result bool) { + cluster.RLock() + if cluster.cachedIndex != nil { + result = cluster.cachedIndex[cacheKey] + } + cluster.RUnlock() + return +} + +func (cluster *mongoCluster) ResetIndexCache() { + cluster.Lock() + cluster.cachedIndex = 
make(map[string]bool) + cluster.Unlock() +} diff --git a/vendor/gopkg.in/mgo.v2/cluster_test.go b/vendor/gopkg.in/mgo.v2/cluster_test.go new file mode 100644 index 000000000..0b0b09a4b --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/cluster_test.go @@ -0,0 +1,1657 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "fmt" + "io" + "net" + "strings" + "sync" + "time" + + . 
"gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +func (s *S) TestNewSession(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Do a dummy operation to wait for connection. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Tweak safety and query settings to ensure other has copied those. + session.SetSafe(nil) + session.SetBatch(-1) + other := session.New() + defer other.Close() + session.SetSafe(&mgo.Safe{}) + + // Clone was copied while session was unsafe, so no errors. + otherColl := other.DB("mydb").C("mycoll") + err = otherColl.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Original session was made safe again. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, NotNil) + + // With New(), each session has its own socket now. + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 2) + + // Ensure query parameters were cloned. + err = otherColl.Insert(M{"_id": 2}) + c.Assert(err, IsNil) + + // Ping the database to ensure the nonce has been received already. + c.Assert(other.Ping(), IsNil) + + mgo.ResetStats() + + iter := otherColl.Find(M{}).Iter() + c.Assert(err, IsNil) + + m := M{} + ok := iter.Next(m) + c.Assert(ok, Equals, true) + err = iter.Close() + c.Assert(err, IsNil) + + // If Batch(-1) is in effect, a single document must have been received. + stats = mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 1) +} + +func (s *S) TestCloneSession(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Do a dummy operation to wait for connection. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Tweak safety and query settings to ensure clone is copying those. 
+ session.SetSafe(nil) + session.SetBatch(-1) + clone := session.Clone() + defer clone.Close() + session.SetSafe(&mgo.Safe{}) + + // Clone was copied while session was unsafe, so no errors. + cloneColl := clone.DB("mydb").C("mycoll") + err = cloneColl.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Original session was made safe again. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, NotNil) + + // With Clone(), same socket is shared between sessions now. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + c.Assert(stats.SocketRefs, Equals, 2) + + // Refreshing one of them should let the original socket go, + // while preserving the safety settings. + clone.Refresh() + err = cloneColl.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Must have used another connection now. + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 2) + c.Assert(stats.SocketRefs, Equals, 2) + + // Ensure query parameters were cloned. + err = cloneColl.Insert(M{"_id": 2}) + c.Assert(err, IsNil) + + // Ping the database to ensure the nonce has been received already. + c.Assert(clone.Ping(), IsNil) + + mgo.ResetStats() + + iter := cloneColl.Find(M{}).Iter() + c.Assert(err, IsNil) + + m := M{} + ok := iter.Next(m) + c.Assert(ok, Equals, true) + err = iter.Close() + c.Assert(err, IsNil) + + // If Batch(-1) is in effect, a single document must have been received. 
+ stats = mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 1) +} + +func (s *S) TestSetModeStrong(c *C) { + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + session.SetMode(mgo.Strong, false) + + c.Assert(session.Mode(), Equals, mgo.Strong) + + result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 1) + + session.SetMode(mgo.Strong, true) + + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSetModeMonotonic(c *C) { + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + + c.Assert(session.Mode(), Equals, mgo.Monotonic) + + result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + result = M{} + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) + + // Wait since the sync also uses sockets. 
+ for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 2) + + session.SetMode(mgo.Monotonic, true) + + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSetModeMonotonicAfterStrong(c *C) { + // Test that a strong session shifting to a monotonic + // one preserves the socket untouched. + + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + // Insert something to force a connection to the master. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + session.SetMode(mgo.Monotonic, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + // Master socket should still be reserved. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + + // Confirm it's the master even though it's Monotonic by now. + result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) +} + +func (s *S) TestSetModeStrongAfterMonotonic(c *C) { + // Test that shifting from Monotonic to Strong while + // using a slave socket will keep the socket reserved + // until the master socket is necessary, so that no + // switch over occurs unless it's actually necessary. + + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + + // Ensure we're talking to a slave, and reserve the socket. 
+ result := M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + // Switch to a Strong session. + session.SetMode(mgo.Strong, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + // Slave socket should still be reserved. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + + // But any operation will switch it to the master. + result = M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) +} + +func (s *S) TestSetModeMonotonicWriteOnIteration(c *C) { + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, false) + + c.Assert(session.Mode(), Equals, mgo.Monotonic) + + coll1 := session.DB("mydb").C("mycoll1") + coll2 := session.DB("mydb").C("mycoll2") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll1.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + // Release master so we can grab a slave again. + session.Refresh() + + // Wait until synchronization is done. + for { + n, err := coll1.Count() + c.Assert(err, IsNil) + if n == len(ns) { + break + } + } + + iter := coll1.Find(nil).Batch(2).Iter() + i := 0 + m := M{} + for iter.Next(&m) { + i++ + if i > 3 { + err := coll2.Insert(M{"n": 47 + i}) + c.Assert(err, IsNil) + } + } + c.Assert(i, Equals, len(ns)) +} + +func (s *S) TestSetModeEventual(c *C) { + // Must necessarily connect to a slave, otherwise the + // master connection will be available first. 
+ session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Eventual, false) + + c.Assert(session.Mode(), Equals, mgo.Eventual) + + result := M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + result = M{} + err = session.Run("ismaster", &result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSetModeEventualAfterStrong(c *C) { + // Test that a strong session shifting to an eventual + // one preserves the socket untouched. + + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + // Insert something to force a connection to the master. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + session.SetMode(mgo.Eventual, false) + + // Wait since the sync also uses sockets. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + // Master socket should still be reserved. + stats := mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 1) + + // Confirm it's the master even though it's Eventual by now. 
+ result := M{} + cmd := session.DB("admin").C("$cmd") + err = cmd.Find(M{"ismaster": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result["ismaster"], Equals, true) + + session.SetMode(mgo.Eventual, true) + + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestPrimaryShutdownStrong(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Kill the master. + host := result.Host + s.Stop(host) + + // This must fail, since the connection was broken. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // With strong consistency, it fails again until reset. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + session.Refresh() + + // Now we should be able to talk to the new master. + // Increase the timeout since this may take quite a while. + session.SetSyncTimeout(3 * time.Minute) + + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(result.Host, Not(Equals), host) + + // Insert some data to confirm it's indeed a master. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 42}) + c.Assert(err, IsNil) +} + +func (s *S) TestPrimaryHiccup(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Establish a few extra sessions to create spare sockets to + // the master. This increases a bit the chances of getting an + // incorrect cached socket. 
+ var sessions []*mgo.Session + for i := 0; i < 20; i++ { + sessions = append(sessions, session.Copy()) + err = sessions[len(sessions)-1].Run("serverStatus", result) + c.Assert(err, IsNil) + } + for i := range sessions { + sessions[i].Close() + } + + // Kill the master, but bring it back immediatelly. + host := result.Host + s.Stop(host) + s.StartAll() + + // This must fail, since the connection was broken. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // With strong consistency, it fails again until reset. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + session.Refresh() + + // Now we should be able to talk to the new master. + // Increase the timeout since this may take quite a while. + session.SetSyncTimeout(3 * time.Minute) + + // Insert some data to confirm it's indeed a master. + err = session.DB("mydb").C("mycoll").Insert(M{"n": 42}) + c.Assert(err, IsNil) +} + +func (s *S) TestPrimaryShutdownMonotonic(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + // Insert something to force a switch to the master. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Wait a bit for this to be synchronized to slaves. + time.Sleep(3 * time.Second) + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Kill the master. + host := result.Host + s.Stop(host) + + // This must fail, since the connection was broken. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // With monotonic consistency, it fails again until reset. + err = session.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + session.Refresh() + + // Now we should be able to talk to the new master. 
+ err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(result.Host, Not(Equals), host) +} + +func (s *S) TestPrimaryShutdownMonotonicWithSlave(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + ssresult := &struct{ Host string }{} + imresult := &struct{ IsMaster bool }{} + + // Figure the master while still using the strong session. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + master := ssresult.Host + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + // Create new monotonic session with an explicit address to ensure + // a slave is synchronized before the master, otherwise a connection + // with the master may be used below for lack of other options. + var addr string + switch { + case strings.HasSuffix(ssresult.Host, ":40021"): + addr = "localhost:40022" + case strings.HasSuffix(ssresult.Host, ":40022"): + addr = "localhost:40021" + case strings.HasSuffix(ssresult.Host, ":40023"): + addr = "localhost:40021" + default: + c.Fatal("Unknown host: ", ssresult.Host) + } + + session, err = mgo.Dial(addr) + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + // Check the address of the socket associated with the monotonic session. + c.Log("Running serverStatus and isMaster with monotonic session") + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + slave := ssresult.Host + c.Assert(imresult.IsMaster, Equals, false, Commentf("%s is not a slave", slave)) + + c.Assert(master, Not(Equals), slave) + + // Kill the master. + s.Stop(master) + + // Session must still be good, since we were talking to a slave. 
+ err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + + c.Assert(ssresult.Host, Equals, slave, + Commentf("Monotonic session moved from %s to %s", slave, ssresult.Host)) + + // If we try to insert something, it'll have to hold until the new + // master is available to move the connection, and work correctly. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Must now be talking to the new master. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + // ... which is not the old one, since it's still dead. + c.Assert(ssresult.Host, Not(Equals), master) +} + +func (s *S) TestPrimaryShutdownEventual(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + master := result.Host + + session.SetMode(mgo.Eventual, true) + + // Should connect to the master when needed. + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Wait a bit for this to be synchronized to slaves. + time.Sleep(3 * time.Second) + + // Kill the master. + s.Stop(master) + + // Should still work, with the new master now. 
+ coll = session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(result.Host, Not(Equals), master) +} + +func (s *S) TestPreserveSocketCountOnSync(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + stats := mgo.GetStats() + for stats.MasterConns+stats.SlaveConns != 3 { + stats = mgo.GetStats() + c.Log("Waiting for all connections to be established...") + time.Sleep(5e8) + } + + c.Assert(stats.SocketsAlive, Equals, 3) + + // Kill the master (with rs1, 'a' is always the master). + s.Stop("localhost:40011") + + // Wait for the logic to run for a bit and bring it back. + startedAll := make(chan bool) + go func() { + time.Sleep(5e9) + s.StartAll() + startedAll <- true + }() + + // Do not allow the test to return before the goroutine above is done. + defer func() { + <-startedAll + }() + + // Do an action to kick the resync logic in, and also to + // wait until the cluster recognizes the server is back. + result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, true) + + for i := 0; i != 20; i++ { + stats = mgo.GetStats() + if stats.SocketsAlive == 3 { + break + } + c.Logf("Waiting for 3 sockets alive, have %d", stats.SocketsAlive) + time.Sleep(5e8) + } + + // Ensure the number of sockets is preserved after syncing. + stats = mgo.GetStats() + c.Assert(stats.SocketsAlive, Equals, 3) + c.Assert(stats.SocketsInUse, Equals, 1) + c.Assert(stats.SocketRefs, Equals, 1) +} + +// Connect to the master of a deployment with a single server, +// run an insert, and then ensure the insert worked and that a +// single connection was established. +func (s *S) TestTopologySyncWithSingleMaster(c *C) { + // Use hostname here rather than IP, to make things trickier. 
+ session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + + // One connection used for discovery. Master socket recycled for + // insert. Socket is reserved after insert. + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 0) + c.Assert(stats.SocketsInUse, Equals, 1) + + // Refresh session and socket must be released. + session.Refresh() + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestTopologySyncWithSlaveSeed(c *C) { + // That's supposed to be a slave. Must run discovery + // and find out master to insert successfully. + session, err := mgo.Dial("localhost:40012") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, true) + + // One connection to each during discovery. Master + // socket recycled for insert. + stats := mgo.GetStats() + c.Assert(stats.MasterConns, Equals, 1) + c.Assert(stats.SlaveConns, Equals, 2) + + // Only one socket reference alive, in the master socket owned + // by the above session. + c.Assert(stats.SocketsInUse, Equals, 1) + + // Refresh it, and it must be gone. + session.Refresh() + stats = mgo.GetStats() + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestSyncTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + s.Stop("localhost:40001") + + timeout := 3 * time.Second + session.SetSyncTimeout(timeout) + started := time.Now() + + // Do something. 
+ result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, ErrorMatches, "no reachable servers") + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) +} + +func (s *S) TestDialWithTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + timeout := 2 * time.Second + started := time.Now() + + // 40009 isn't used by the test servers. + session, err := mgo.DialWithTimeout("localhost:40009", timeout) + if session != nil { + session.Close() + } + c.Assert(err, ErrorMatches, "no reachable servers") + c.Assert(session, IsNil) + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) +} + +func (s *S) TestSocketTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + s.Freeze("localhost:40001") + + timeout := 3 * time.Second + session.SetSocketTimeout(timeout) + started := time.Now() + + // Do something. 
+ result := struct{ Ok bool }{} + err = session.Run("getLastError", &result) + c.Assert(err, ErrorMatches, ".*: i/o timeout") + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-timeout*2)), Equals, true) +} + +func (s *S) TestSocketTimeoutOnDial(c *C) { + if *fast { + c.Skip("-fast") + } + + timeout := 1 * time.Second + + defer mgo.HackSyncSocketTimeout(timeout)() + + s.Freeze("localhost:40001") + + started := time.Now() + + session, err := mgo.DialWithTimeout("localhost:40001", timeout) + c.Assert(err, ErrorMatches, "no reachable servers") + c.Assert(session, IsNil) + + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + c.Assert(started.After(time.Now().Add(-20*time.Second)), Equals, true) +} + +func (s *S) TestSocketTimeoutOnInactiveSocket(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + timeout := 2 * time.Second + session.SetSocketTimeout(timeout) + + // Do something that relies on the timeout and works. + c.Assert(session.Ping(), IsNil) + + // Freeze and wait for the timeout to go by. + s.Freeze("localhost:40001") + time.Sleep(timeout + 500*time.Millisecond) + s.Thaw("localhost:40001") + + // Do something again. The timeout above should not have killed + // the socket as there was nothing to be done. 
+ c.Assert(session.Ping(), IsNil) +} + +func (s *S) TestDialWithReplicaSetName(c *C) { + seedLists := [][]string{ + // rs1 primary and rs2 primary + []string{"localhost:40011", "localhost:40021"}, + // rs1 primary and rs2 secondary + []string{"localhost:40011", "localhost:40022"}, + // rs1 secondary and rs2 primary + []string{"localhost:40012", "localhost:40021"}, + // rs1 secondary and rs2 secondary + []string{"localhost:40012", "localhost:40022"}, + } + + rs2Members := []string{":40021", ":40022", ":40023"} + + verifySyncedServers := func(session *mgo.Session, numServers int) { + // wait for the server(s) to be synced + for len(session.LiveServers()) != numServers { + c.Log("Waiting for cluster sync to finish...") + time.Sleep(5e8) + } + + // ensure none of the rs2 set members are communicated with + for _, addr := range session.LiveServers() { + for _, rs2Member := range rs2Members { + c.Assert(strings.HasSuffix(addr, rs2Member), Equals, false) + } + } + } + + // only communication with rs1 members is expected + for _, seedList := range seedLists { + info := mgo.DialInfo{ + Addrs: seedList, + Timeout: 5 * time.Second, + ReplicaSetName: "rs1", + } + + session, err := mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + verifySyncedServers(session, 3) + session.Close() + + info.Direct = true + session, err = mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + verifySyncedServers(session, 1) + session.Close() + + connectionUrl := fmt.Sprintf("mongodb://%v/?replicaSet=rs1", strings.Join(seedList, ",")) + session, err = mgo.Dial(connectionUrl) + c.Assert(err, IsNil) + verifySyncedServers(session, 3) + session.Close() + + connectionUrl += "&connect=direct" + session, err = mgo.Dial(connectionUrl) + c.Assert(err, IsNil) + verifySyncedServers(session, 1) + session.Close() + } + +} + +func (s *S) TestDirect(c *C) { + session, err := mgo.Dial("localhost:40012?connect=direct") + c.Assert(err, IsNil) + defer session.Close() + + // We know that server is a slave. 
+ session.SetMode(mgo.Monotonic, true) + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40012"), Equals, true) + + stats := mgo.GetStats() + c.Assert(stats.SocketsAlive, Equals, 1) + c.Assert(stats.SocketsInUse, Equals, 1) + c.Assert(stats.SocketRefs, Equals, 1) + + // We've got no master, so it'll timeout. + session.SetSyncTimeout(5e8 * time.Nanosecond) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"test": 1}) + c.Assert(err, ErrorMatches, "no reachable servers") + + // Writing to the local database is okay. + coll = session.DB("local").C("mycoll") + defer coll.RemoveAll(nil) + id := bson.NewObjectId() + err = coll.Insert(M{"_id": id}) + c.Assert(err, IsNil) + + // Data was stored in the right server. + n, err := coll.Find(M{"_id": id}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) + + // Server hasn't changed. + result.Host = "" + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40012"), Equals, true) +} + +func (s *S) TestDirectToUnknownStateMember(c *C) { + session, err := mgo.Dial("localhost:40041?connect=direct") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Monotonic, true) + + result := &struct{ Host string }{} + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40041"), Equals, true) + + // We've got no master, so it'll timeout. + session.SetSyncTimeout(5e8 * time.Nanosecond) + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"test": 1}) + c.Assert(err, ErrorMatches, "no reachable servers") + + // Slave is still reachable. 
+ result.Host = "" + err = session.Run("serverStatus", result) + c.Assert(err, IsNil) + c.Assert(strings.HasSuffix(result.Host, ":40041"), Equals, true) +} + +func (s *S) TestFailFast(c *C) { + info := mgo.DialInfo{ + Addrs: []string{"localhost:99999"}, + Timeout: 5 * time.Second, + FailFast: true, + } + + started := time.Now() + + _, err := mgo.DialWithInfo(&info) + c.Assert(err, ErrorMatches, "no reachable servers") + + c.Assert(started.After(time.Now().Add(-time.Second)), Equals, true) +} + +type OpCounters struct { + Insert int + Query int + Update int + Delete int + GetMore int + Command int +} + +func getOpCounters(server string) (c *OpCounters, err error) { + session, err := mgo.Dial(server + "?connect=direct") + if err != nil { + return nil, err + } + defer session.Close() + session.SetMode(mgo.Monotonic, true) + result := struct{ OpCounters }{} + err = session.Run("serverStatus", &result) + return &result.OpCounters, err +} + +func (s *S) TestMonotonicSlaveOkFlagWithMongos(c *C) { + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + ssresult := &struct{ Host string }{} + imresult := &struct{ IsMaster bool }{} + + // Figure the master while still using the strong session. + err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + master := ssresult.Host + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + // Collect op counters for everyone. 
+ opc21a, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22a, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23a, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + // Do a SlaveOk query through MongoS + + mongos, err := mgo.Dial("localhost:40202") + c.Assert(err, IsNil) + defer mongos.Close() + + mongos.SetMode(mgo.Monotonic, true) + + coll := mongos.DB("mydb").C("mycoll") + result := &struct{}{} + for i := 0; i != 5; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + // Collect op counters for everyone again. + opc21b, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22b, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23b, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + masterPort := master[strings.Index(master, ":")+1:] + + var masterDelta, slaveDelta int + switch masterPort { + case "40021": + masterDelta = opc21b.Query - opc21a.Query + slaveDelta = (opc22b.Query - opc22a.Query) + (opc23b.Query - opc23a.Query) + case "40022": + masterDelta = opc22b.Query - opc22a.Query + slaveDelta = (opc21b.Query - opc21a.Query) + (opc23b.Query - opc23a.Query) + case "40023": + masterDelta = opc23b.Query - opc23a.Query + slaveDelta = (opc21b.Query - opc21a.Query) + (opc22b.Query - opc22a.Query) + default: + c.Fatal("Uh?") + } + + c.Check(masterDelta, Equals, 0) // Just the counting itself. + c.Check(slaveDelta, Equals, 5) // The counting for both, plus 5 queries above. +} + +func (s *S) TestRemovalOfClusterMember(c *C) { + if *fast { + c.Skip("-fast") + } + + master, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer master.Close() + + // Wait for cluster to fully sync up. 
+ for i := 0; i < 10; i++ { + if len(master.LiveServers()) == 3 { + break + } + time.Sleep(5e8) + } + if len(master.LiveServers()) != 3 { + c.Fatalf("Test started with bad cluster state: %v", master.LiveServers()) + } + + result := &struct { + IsMaster bool + Me string + }{} + slave := master.Copy() + slave.SetMode(mgo.Monotonic, true) // Monotonic can hold a non-master socket persistently. + err = slave.Run("isMaster", result) + c.Assert(err, IsNil) + c.Assert(result.IsMaster, Equals, false) + slaveAddr := result.Me + + defer func() { + master.Refresh() + master.Run(bson.D{{"$eval", `rs.add("` + slaveAddr + `")`}}, nil) + master.Close() + slave.Close() + }() + + c.Logf("========== Removing slave: %s ==========", slaveAddr) + + master.Run(bson.D{{"$eval", `rs.remove("` + slaveAddr + `")`}}, nil) + + master.Refresh() + + // Give the cluster a moment to catch up by doing a roundtrip to the master. + err = master.Ping() + c.Assert(err, IsNil) + + time.Sleep(3e9) + + // This must fail since the slave has been taken off the cluster. + err = slave.Ping() + c.Assert(err, NotNil) + + for i := 0; i < 15; i++ { + if len(master.LiveServers()) == 2 { + break + } + time.Sleep(time.Second) + } + live := master.LiveServers() + if len(live) != 2 { + c.Errorf("Removed server still considered live: %#s", live) + } + + c.Log("========== Test succeeded. ==========") +} + +func (s *S) TestPoolLimitSimple(c *C) { + for test := 0; test < 2; test++ { + var session *mgo.Session + var err error + if test == 0 { + session, err = mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + session.SetPoolLimit(1) + } else { + session, err = mgo.Dial("localhost:40001?maxPoolSize=1") + c.Assert(err, IsNil) + } + defer session.Close() + + // Put one socket in use. + c.Assert(session.Ping(), IsNil) + + done := make(chan time.Duration) + + // Now block trying to get another one due to the pool limit. 
+ go func() { + copy := session.Copy() + defer copy.Close() + started := time.Now() + c.Check(copy.Ping(), IsNil) + done <- time.Now().Sub(started) + }() + + time.Sleep(300 * time.Millisecond) + + // Put the one socket back in the pool, freeing it for the copy. + session.Refresh() + delay := <-done + c.Assert(delay > 300*time.Millisecond, Equals, true, Commentf("Delay: %s", delay)) + } +} + +func (s *S) TestPoolLimitMany(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + stats := mgo.GetStats() + for stats.MasterConns+stats.SlaveConns != 3 { + stats = mgo.GetStats() + c.Log("Waiting for all connections to be established...") + time.Sleep(500 * time.Millisecond) + } + c.Assert(stats.SocketsAlive, Equals, 3) + + const poolLimit = 64 + session.SetPoolLimit(poolLimit) + + // Consume the whole limit for the master. + var master []*mgo.Session + for i := 0; i < poolLimit; i++ { + s := session.Copy() + defer s.Close() + c.Assert(s.Ping(), IsNil) + master = append(master, s) + } + + before := time.Now() + go func() { + time.Sleep(3e9) + master[0].Refresh() + }() + + // Then, a single ping must block, since it would need another + // connection to the master, over the limit. Once the goroutine + // above releases its socket, it should move on. 
+ session.Ping() + delay := time.Now().Sub(before) + c.Assert(delay > 3e9, Equals, true) + c.Assert(delay < 6e9, Equals, true) +} + +func (s *S) TestSetModeEventualIterBug(c *C) { + session1, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session1.Close() + + session1.SetMode(mgo.Eventual, false) + + coll1 := session1.DB("mydb").C("mycoll") + + const N = 100 + for i := 0; i < N; i++ { + err = coll1.Insert(M{"_id": i}) + c.Assert(err, IsNil) + } + + c.Logf("Waiting until secondary syncs") + for { + n, err := coll1.Count() + c.Assert(err, IsNil) + if n == N { + c.Logf("Found all") + break + } + } + + session2, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session2.Close() + + session2.SetMode(mgo.Eventual, false) + + coll2 := session2.DB("mydb").C("mycoll") + + i := 0 + iter := coll2.Find(nil).Batch(10).Iter() + var result struct{} + for iter.Next(&result) { + i++ + } + c.Assert(iter.Close(), Equals, nil) + c.Assert(i, Equals, N) +} + +func (s *S) TestCustomDialOld(c *C) { + dials := make(chan bool, 16) + dial := func(addr net.Addr) (net.Conn, error) { + tcpaddr, ok := addr.(*net.TCPAddr) + if !ok { + return nil, fmt.Errorf("unexpected address type: %T", addr) + } + dials <- true + return net.DialTCP("tcp", nil, tcpaddr) + } + info := mgo.DialInfo{ + Addrs: []string{"localhost:40012"}, + Dial: dial, + } + + // Use hostname here rather than IP, to make things trickier. 
+ session, err := mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + defer session.Close() + + const N = 3 + for i := 0; i < N; i++ { + select { + case <-dials: + case <-time.After(5 * time.Second): + c.Fatalf("expected %d dials, got %d", N, i) + } + } + select { + case <-dials: + c.Fatalf("got more dials than expected") + case <-time.After(100 * time.Millisecond): + } +} + +func (s *S) TestCustomDialNew(c *C) { + dials := make(chan bool, 16) + dial := func(addr *mgo.ServerAddr) (net.Conn, error) { + dials <- true + if addr.TCPAddr().Port == 40012 { + c.Check(addr.String(), Equals, "localhost:40012") + } + return net.DialTCP("tcp", nil, addr.TCPAddr()) + } + info := mgo.DialInfo{ + Addrs: []string{"localhost:40012"}, + DialServer: dial, + } + + // Use hostname here rather than IP, to make things trickier. + session, err := mgo.DialWithInfo(&info) + c.Assert(err, IsNil) + defer session.Close() + + const N = 3 + for i := 0; i < N; i++ { + select { + case <-dials: + case <-time.After(5 * time.Second): + c.Fatalf("expected %d dials, got %d", N, i) + } + } + select { + case <-dials: + c.Fatalf("got more dials than expected") + case <-time.After(100 * time.Millisecond): + } +} + +func (s *S) TestPrimaryShutdownOnAuthShard(c *C) { + if *fast { + c.Skip("-fast") + } + + // Dial the shard. + session, err := mgo.Dial("localhost:40203") + c.Assert(err, IsNil) + defer session.Close() + + // Login and insert something to make it more realistic. + session.DB("admin").Login("root", "rapadura") + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(bson.M{"n": 1}) + c.Assert(err, IsNil) + + // Dial the replica set to figure the master out. + rs, err := mgo.Dial("root:rapadura@localhost:40031") + c.Assert(err, IsNil) + defer rs.Close() + + // With strong consistency, this will open a socket to the master. + result := &struct{ Host string }{} + err = rs.Run("serverStatus", result) + c.Assert(err, IsNil) + + // Kill the master. 
+ host := result.Host + s.Stop(host) + + // This must fail, since the connection was broken. + err = rs.Run("serverStatus", result) + c.Assert(err, Equals, io.EOF) + + // This won't work because the master just died. + err = coll.Insert(bson.M{"n": 2}) + c.Assert(err, NotNil) + + // Refresh session and wait for re-election. + session.Refresh() + for i := 0; i < 60; i++ { + err = coll.Insert(bson.M{"n": 3}) + if err == nil { + break + } + c.Logf("Waiting for replica set to elect a new master. Last error: %v", err) + time.Sleep(500 * time.Millisecond) + } + c.Assert(err, IsNil) + + count, err := coll.Count() + c.Assert(count > 1, Equals, true) +} + +func (s *S) TestNearestSecondary(c *C) { + defer mgo.HackPingDelay(3 * time.Second)() + + rs1a := "127.0.0.1:40011" + rs1b := "127.0.0.1:40012" + rs1c := "127.0.0.1:40013" + s.Freeze(rs1b) + + session, err := mgo.Dial(rs1a) + c.Assert(err, IsNil) + defer session.Close() + + // Wait for the sync up to run through the first couple of servers. + for len(session.LiveServers()) != 2 { + c.Log("Waiting for two servers to be alive...") + time.Sleep(100 * time.Millisecond) + } + + // Extra delay to ensure the third server gets penalized. + time.Sleep(500 * time.Millisecond) + + // Release third server. + s.Thaw(rs1b) + + // Wait for it to come up. + for len(session.LiveServers()) != 3 { + c.Log("Waiting for all servers to be alive...") + time.Sleep(100 * time.Millisecond) + } + + session.SetMode(mgo.Monotonic, true) + var result struct{ Host string } + + // See which slave picks the line, several times to avoid chance. + for i := 0; i < 10; i++ { + session.Refresh() + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, hostPort(rs1c)) + } + + if *fast { + // Don't hold back for several seconds. + return + } + + // Now hold the other server for long enough to penalize it. + s.Freeze(rs1c) + time.Sleep(5 * time.Second) + s.Thaw(rs1c) + + // Wait for the ping to be processed. 
+ time.Sleep(500 * time.Millisecond) + + // Repeating the test should now pick the former server consistently. + for i := 0; i < 10; i++ { + session.Refresh() + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, hostPort(rs1b)) + } +} + +func (s *S) TestConnectCloseConcurrency(c *C) { + restore := mgo.HackPingDelay(500 * time.Millisecond) + defer restore() + var wg sync.WaitGroup + const n = 500 + wg.Add(n) + for i := 0; i < n; i++ { + go func() { + defer wg.Done() + session, err := mgo.Dial("localhost:40001") + if err != nil { + c.Fatal(err) + } + time.Sleep(1) + session.Close() + }() + } + wg.Wait() +} + +func (s *S) TestSelectServers(c *C) { + if !s.versionAtLeast(2, 2) { + c.Skip("read preferences introduced in 2.2") + } + + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + session.SetMode(mgo.Eventual, true) + + var result struct{ Host string } + + session.Refresh() + session.SelectServers(bson.D{{"rs1", "b"}}) + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, "40012") + + session.Refresh() + session.SelectServers(bson.D{{"rs1", "c"}}) + err = session.Run("serverStatus", &result) + c.Assert(err, IsNil) + c.Assert(hostPort(result.Host), Equals, "40013") +} + +func (s *S) TestSelectServersWithMongos(c *C) { + if !s.versionAtLeast(2, 2) { + c.Skip("read preferences introduced in 2.2") + } + + session, err := mgo.Dial("localhost:40021") + c.Assert(err, IsNil) + defer session.Close() + + ssresult := &struct{ Host string }{} + imresult := &struct{ IsMaster bool }{} + + // Figure the master while still using the strong session. 
+ err = session.Run("serverStatus", ssresult) + c.Assert(err, IsNil) + err = session.Run("isMaster", imresult) + c.Assert(err, IsNil) + master := ssresult.Host + c.Assert(imresult.IsMaster, Equals, true, Commentf("%s is not the master", master)) + + var slave1, slave2 string + switch hostPort(master) { + case "40021": + slave1, slave2 = "b", "c" + case "40022": + slave1, slave2 = "a", "c" + case "40023": + slave1, slave2 = "a", "b" + } + + // Collect op counters for everyone. + opc21a, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22a, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23a, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + // Do a SlaveOk query through MongoS + mongos, err := mgo.Dial("localhost:40202") + c.Assert(err, IsNil) + defer mongos.Close() + + mongos.SetMode(mgo.Monotonic, true) + + mongos.Refresh() + mongos.SelectServers(bson.D{{"rs2", slave1}}) + coll := mongos.DB("mydb").C("mycoll") + result := &struct{}{} + for i := 0; i != 5; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + mongos.Refresh() + mongos.SelectServers(bson.D{{"rs2", slave2}}) + coll = mongos.DB("mydb").C("mycoll") + for i := 0; i != 7; i++ { + err := coll.Find(nil).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + } + + // Collect op counters for everyone again. 
+ opc21b, err := getOpCounters("localhost:40021") + c.Assert(err, IsNil) + opc22b, err := getOpCounters("localhost:40022") + c.Assert(err, IsNil) + opc23b, err := getOpCounters("localhost:40023") + c.Assert(err, IsNil) + + switch hostPort(master) { + case "40021": + c.Check(opc21b.Query-opc21a.Query, Equals, 0) + c.Check(opc22b.Query-opc22a.Query, Equals, 5) + c.Check(opc23b.Query-opc23a.Query, Equals, 7) + case "40022": + c.Check(opc21b.Query-opc21a.Query, Equals, 5) + c.Check(opc22b.Query-opc22a.Query, Equals, 0) + c.Check(opc23b.Query-opc23a.Query, Equals, 7) + case "40023": + c.Check(opc21b.Query-opc21a.Query, Equals, 5) + c.Check(opc22b.Query-opc22a.Query, Equals, 7) + c.Check(opc23b.Query-opc23a.Query, Equals, 0) + default: + c.Fatal("Uh?") + } +} diff --git a/vendor/gopkg.in/mgo.v2/doc.go b/vendor/gopkg.in/mgo.v2/doc.go new file mode 100644 index 000000000..9316c5554 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/doc.go @@ -0,0 +1,31 @@ +// Package mgo offers a rich MongoDB driver for Go. +// +// Details about the mgo project (pronounced as "mango") are found +// in its web page: +// +// http://labix.org/mgo +// +// Usage of the driver revolves around the concept of sessions. To +// get started, obtain a session using the Dial function: +// +// session, err := mgo.Dial(url) +// +// This will establish one or more connections with the cluster of +// servers defined by the url parameter. From then on, the cluster +// may be queried with multiple consistency rules (see SetMode) and +// documents retrieved with statements such as: +// +// c := session.DB(database).C(collection) +// err := c.Find(query).One(&result) +// +// New sessions are typically created by calling session.Copy on the +// initial session obtained at dial time. These new sessions will share +// the same cluster information and connection cache, and may be easily +// handed into other methods and functions for organizing logic. 
+// Every session created must have its Close method called at the end +// of its life time, so its resources may be put back in the pool or +// collected, depending on the case. +// +// For more details, see the documentation for the types and methods. +// +package mgo diff --git a/vendor/gopkg.in/mgo.v2/export_test.go b/vendor/gopkg.in/mgo.v2/export_test.go new file mode 100644 index 000000000..690f84d38 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/export_test.go @@ -0,0 +1,33 @@ +package mgo + +import ( + "time" +) + +func HackPingDelay(newDelay time.Duration) (restore func()) { + globalMutex.Lock() + defer globalMutex.Unlock() + + oldDelay := pingDelay + restore = func() { + globalMutex.Lock() + pingDelay = oldDelay + globalMutex.Unlock() + } + pingDelay = newDelay + return +} + +func HackSyncSocketTimeout(newTimeout time.Duration) (restore func()) { + globalMutex.Lock() + defer globalMutex.Unlock() + + oldTimeout := syncSocketTimeout + restore = func() { + globalMutex.Lock() + syncSocketTimeout = oldTimeout + globalMutex.Unlock() + } + syncSocketTimeout = newTimeout + return +} diff --git a/vendor/gopkg.in/mgo.v2/gridfs.go b/vendor/gopkg.in/mgo.v2/gridfs.go new file mode 100644 index 000000000..54b3dd50e --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/gridfs.go @@ -0,0 +1,755 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. 
+// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "hash" + "io" + "os" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type GridFS struct { + Files *Collection + Chunks *Collection +} + +type gfsFileMode int + +const ( + gfsClosed gfsFileMode = 0 + gfsReading gfsFileMode = 1 + gfsWriting gfsFileMode = 2 +) + +type GridFile struct { + m sync.Mutex + c sync.Cond + gfs *GridFS + mode gfsFileMode + err error + + chunk int + offset int64 + + wpending int + wbuf []byte + wsum hash.Hash + + rbuf []byte + rcache *gfsCachedChunk + + doc gfsFile +} + +type gfsFile struct { + Id interface{} "_id" + ChunkSize int "chunkSize" + UploadDate time.Time "uploadDate" + Length int64 ",minsize" + MD5 string + Filename string ",omitempty" + ContentType string "contentType,omitempty" + Metadata *bson.Raw ",omitempty" +} + +type gfsChunk struct { + Id interface{} "_id" + FilesId interface{} "files_id" + N int + Data []byte +} + +type gfsCachedChunk struct { + wait sync.Mutex + n int + data []byte + err error +} + +func newGridFS(db *Database, prefix string) *GridFS { + return &GridFS{db.C(prefix + ".files"), db.C(prefix + ".chunks")} +} + +func (gfs *GridFS) newFile() *GridFile { 
+ file := &GridFile{gfs: gfs} + file.c.L = &file.m + //runtime.SetFinalizer(file, finalizeFile) + return file +} + +func finalizeFile(file *GridFile) { + file.Close() +} + +// Create creates a new file with the provided name in the GridFS. If the file +// name already exists, a new version will be inserted with an up-to-date +// uploadDate that will cause it to be atomically visible to the Open and +// OpenId methods. If the file name is not important, an empty name may be +// provided and the file Id used instead. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// A simple example inserting a new file: +// +// func check(err error) { +// if err != nil { +// panic(err.String()) +// } +// } +// file, err := db.GridFS("fs").Create("myfile.txt") +// check(err) +// n, err := file.Write([]byte("Hello world!")) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes written\n", n) +// +// The io.Writer interface is implemented by *GridFile and may be used to +// help on the file creation. For example: +// +// file, err := db.GridFS("fs").Create("myfile.txt") +// check(err) +// messages, err := os.Open("/var/log/messages") +// check(err) +// defer messages.Close() +// err = io.Copy(file, messages) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) Create(name string) (file *GridFile, err error) { + file = gfs.newFile() + file.mode = gfsWriting + file.wsum = md5.New() + file.doc = gfsFile{Id: bson.NewObjectId(), ChunkSize: 255 * 1024, Filename: name} + return +} + +// OpenId returns the file with the provided id, for reading. +// If the file isn't found, err will be set to mgo.ErrNotFound. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. 
+// +// The following example will print the first 8192 bytes from the file: +// +// func check(err error) { +// if err != nil { +// panic(err.String()) +// } +// } +// file, err := db.GridFS("fs").OpenId(objid) +// check(err) +// b := make([]byte, 8192) +// n, err := file.Read(b) +// check(err) +// fmt.Println(string(b)) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes read\n", n) +// +// The io.Reader interface is implemented by *GridFile and may be used to +// deal with it. As an example, the following snippet will dump the whole +// file into the standard output: +// +// file, err := db.GridFS("fs").OpenId(objid) +// check(err) +// err = io.Copy(os.Stdout, file) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) OpenId(id interface{}) (file *GridFile, err error) { + var doc gfsFile + err = gfs.Files.Find(bson.M{"_id": id}).One(&doc) + if err != nil { + return + } + file = gfs.newFile() + file.mode = gfsReading + file.doc = doc + return +} + +// Open returns the most recently uploaded file with the provided +// name, for reading. If the file isn't found, err will be set +// to mgo.ErrNotFound. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. +// +// The following example will print the first 8192 bytes from the file: +// +// file, err := db.GridFS("fs").Open("myfile.txt") +// check(err) +// b := make([]byte, 8192) +// n, err := file.Read(b) +// check(err) +// fmt.Println(string(b)) +// check(err) +// err = file.Close() +// check(err) +// fmt.Printf("%d bytes read\n", n) +// +// The io.Reader interface is implemented by *GridFile and may be used to +// deal with it. 
As an example, the following snippet will dump the whole +// file into the standard output: +// +// file, err := db.GridFS("fs").Open("myfile.txt") +// check(err) +// err = io.Copy(os.Stdout, file) +// check(err) +// err = file.Close() +// check(err) +// +func (gfs *GridFS) Open(name string) (file *GridFile, err error) { + var doc gfsFile + err = gfs.Files.Find(bson.M{"filename": name}).Sort("-uploadDate").One(&doc) + if err != nil { + return + } + file = gfs.newFile() + file.mode = gfsReading + file.doc = doc + return +} + +// OpenNext opens the next file from iter for reading, sets *file to it, +// and returns true on the success case. If no more documents are available +// on iter or an error occurred, *file is set to nil and the result is false. +// Errors will be available via iter.Err(). +// +// The iter parameter must be an iterator on the GridFS files collection. +// Using the GridFS.Find method is an easy way to obtain such an iterator, +// but any iterator on the collection will work. +// +// If the provided *file is non-nil, OpenNext will close it before attempting +// to iterate to the next element. This means that in a loop one only +// has to worry about closing files when breaking out of the loop early +// (break, return, or panic). +// +// For example: +// +// gfs := db.GridFS("fs") +// query := gfs.Find(nil).Sort("filename") +// iter := query.Iter() +// var f *mgo.GridFile +// for gfs.OpenNext(iter, &f) { +// fmt.Printf("Filename: %s\n", f.Name()) +// } +// if iter.Close() != nil { +// panic(iter.Close()) +// } +// +func (gfs *GridFS) OpenNext(iter *Iter, file **GridFile) bool { + if *file != nil { + // Ignoring the error here shouldn't be a big deal + // as we're reading the file and the loop iteration + // for this file is finished. 
+ _ = (*file).Close() + } + var doc gfsFile + if !iter.Next(&doc) { + *file = nil + return false + } + f := gfs.newFile() + f.mode = gfsReading + f.doc = doc + *file = f + return true +} + +// Find runs query on GridFS's files collection and returns +// the resulting Query. +// +// This logic: +// +// gfs := db.GridFS("fs") +// iter := gfs.Find(nil).Iter() +// +// Is equivalent to: +// +// files := db.C("fs" + ".files") +// iter := files.Find(nil).Iter() +// +func (gfs *GridFS) Find(query interface{}) *Query { + return gfs.Files.Find(query) +} + +// RemoveId deletes the file with the provided id from the GridFS. +func (gfs *GridFS) RemoveId(id interface{}) error { + err := gfs.Files.Remove(bson.M{"_id": id}) + if err != nil { + return err + } + _, err = gfs.Chunks.RemoveAll(bson.D{{"files_id", id}}) + return err +} + +type gfsDocId struct { + Id interface{} "_id" +} + +// Remove deletes all files with the provided name from the GridFS. +func (gfs *GridFS) Remove(name string) (err error) { + iter := gfs.Files.Find(bson.M{"filename": name}).Select(bson.M{"_id": 1}).Iter() + var doc gfsDocId + for iter.Next(&doc) { + if e := gfs.RemoveId(doc.Id); e != nil { + err = e + } + } + if err == nil { + err = iter.Close() + } + return err +} + +func (file *GridFile) assertMode(mode gfsFileMode) { + switch file.mode { + case mode: + return + case gfsWriting: + panic("GridFile is open for writing") + case gfsReading: + panic("GridFile is open for reading") + case gfsClosed: + panic("GridFile is closed") + default: + panic("internal error: missing GridFile mode") + } +} + +// SetChunkSize sets size of saved chunks. Once the file is written to, it +// will be split in blocks of that size and each block saved into an +// independent chunk document. The default chunk size is 256kb. +// +// It is a runtime error to call this function once the file has started +// being written to. 
+func (file *GridFile) SetChunkSize(bytes int) {
+	file.assertMode(gfsWriting)
+	debugf("GridFile %p: setting chunk size to %d", file, bytes)
+	file.m.Lock()
+	file.doc.ChunkSize = bytes
+	file.m.Unlock()
+}
+
+// Id returns the current file Id.
+func (file *GridFile) Id() interface{} {
+	return file.doc.Id
+}
+
+// SetId changes the current file Id.
+//
+// It is a runtime error to call this function once the file has started
+// being written to, or when the file is not open for writing.
+func (file *GridFile) SetId(id interface{}) {
+	file.assertMode(gfsWriting)
+	file.m.Lock()
+	file.doc.Id = id
+	file.m.Unlock()
+}
+
+// Name returns the optional file name. An empty string will be returned
+// in case it is unset.
+func (file *GridFile) Name() string {
+	return file.doc.Filename
+}
+
+// SetName changes the optional file name. An empty string may be used to
+// unset it.
+//
+// It is a runtime error to call this function when the file is not open
+// for writing.
+func (file *GridFile) SetName(name string) {
+	file.assertMode(gfsWriting)
+	file.m.Lock()
+	file.doc.Filename = name
+	file.m.Unlock()
+}
+
+// ContentType returns the optional file content type. An empty string will be
+// returned in case it is unset.
+func (file *GridFile) ContentType() string {
+	return file.doc.ContentType
+}
+
+// SetContentType changes the optional file content type. An empty string may be
+// used to unset it.
+//
+// It is a runtime error to call this function when the file is not open
+// for writing.
+func (file *GridFile) SetContentType(ctype string) {
+	file.assertMode(gfsWriting)
+	file.m.Lock()
+	file.doc.ContentType = ctype
+	file.m.Unlock()
+}
+
+// GetMeta unmarshals the optional "metadata" field associated with the
+// file into the result parameter. The meaning of keys under that field
+// is user-defined.
For example: +// +// result := struct{ INode int }{} +// err = file.GetMeta(&result) +// if err != nil { +// panic(err.String()) +// } +// fmt.Printf("inode: %d\n", result.INode) +// +func (file *GridFile) GetMeta(result interface{}) (err error) { + file.m.Lock() + if file.doc.Metadata != nil { + err = bson.Unmarshal(file.doc.Metadata.Data, result) + } + file.m.Unlock() + return +} + +// SetMeta changes the optional "metadata" field associated with the +// file. The meaning of keys under that field is user-defined. +// For example: +// +// file.SetMeta(bson.M{"inode": inode}) +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetMeta(metadata interface{}) { + file.assertMode(gfsWriting) + data, err := bson.Marshal(metadata) + file.m.Lock() + if err != nil && file.err == nil { + file.err = err + } else { + file.doc.Metadata = &bson.Raw{Data: data} + } + file.m.Unlock() +} + +// Size returns the file size in bytes. +func (file *GridFile) Size() (bytes int64) { + file.m.Lock() + bytes = file.doc.Length + file.m.Unlock() + return +} + +// MD5 returns the file MD5 as a hex-encoded string. +func (file *GridFile) MD5() (md5 string) { + return file.doc.MD5 +} + +// UploadDate returns the file upload time. +func (file *GridFile) UploadDate() time.Time { + return file.doc.UploadDate +} + +// SetUploadDate changes the file upload time. +// +// It is a runtime error to call this function when the file is not open +// for writing. +func (file *GridFile) SetUploadDate(t time.Time) { + file.assertMode(gfsWriting) + file.m.Lock() + file.doc.UploadDate = t + file.m.Unlock() +} + +// Close flushes any pending changes in case the file is being written +// to, waits for any background operations to finish, and closes the file. +// +// It's important to Close files whether they are being written to +// or read from, and to check the err result to ensure the operation +// completed successfully. 
+func (file *GridFile) Close() (err error) {
+	file.m.Lock()
+	defer file.m.Unlock()
+	if file.mode == gfsWriting {
+		if len(file.wbuf) > 0 && file.err == nil {
+			file.insertChunk(file.wbuf)
+			file.wbuf = file.wbuf[0:0]
+		}
+		file.completeWrite()
+	} else if file.mode == gfsReading && file.rcache != nil {
+		file.rcache.wait.Lock()
+		file.rcache = nil
+	}
+	file.mode = gfsClosed
+	debugf("GridFile %p: closed", file)
+	return file.err
+}
+
+func (file *GridFile) completeWrite() {
+	for file.wpending > 0 {
+		debugf("GridFile %p: waiting for %d pending chunks to complete file write", file, file.wpending)
+		file.c.Wait()
+	}
+	if file.err == nil {
+		hexsum := hex.EncodeToString(file.wsum.Sum(nil))
+		if file.doc.UploadDate.IsZero() {
+			file.doc.UploadDate = bson.Now()
+		}
+		file.doc.MD5 = hexsum
+		file.err = file.gfs.Files.Insert(file.doc)
+		file.gfs.Chunks.EnsureIndexKey("files_id", "n")
+	}
+	if file.err != nil {
+		file.gfs.Chunks.RemoveAll(bson.D{{"files_id", file.doc.Id}})
+	}
+}
+
+// Abort cancels an in-progress write, preventing the file from being
+// automatically created and ensuring previously written chunks are
+// removed when the file is closed.
+//
+// It is a runtime error to call Abort when the file was not opened
+// for writing.
+func (file *GridFile) Abort() {
+	if file.mode != gfsWriting {
+		panic("file.Abort must be called on file opened for writing")
+	}
+	file.err = errors.New("write aborted")
+}
+
+// Write writes the provided data to the file and returns the
+// number of bytes written and an error in case something
+// wrong happened.
+//
+// The file will internally cache the data so that all but the last
+// chunk sent to the database have the size defined by SetChunkSize.
+// This also means that errors may be deferred until a future call
+// to Write or Close.
+//
+// The parameters and behavior of this function turn the file
+// into an io.Writer.
+func (file *GridFile) Write(data []byte) (n int, err error) { + file.assertMode(gfsWriting) + file.m.Lock() + debugf("GridFile %p: writing %d bytes", file, len(data)) + defer file.m.Unlock() + + if file.err != nil { + return 0, file.err + } + + n = len(data) + file.doc.Length += int64(n) + chunkSize := file.doc.ChunkSize + + if len(file.wbuf)+len(data) < chunkSize { + file.wbuf = append(file.wbuf, data...) + return + } + + // First, flush file.wbuf complementing with data. + if len(file.wbuf) > 0 { + missing := chunkSize - len(file.wbuf) + if missing > len(data) { + missing = len(data) + } + file.wbuf = append(file.wbuf, data[:missing]...) + data = data[missing:] + file.insertChunk(file.wbuf) + file.wbuf = file.wbuf[0:0] + } + + // Then, flush all chunks from data without copying. + for len(data) > chunkSize { + size := chunkSize + if size > len(data) { + size = len(data) + } + file.insertChunk(data[:size]) + data = data[size:] + } + + // And append the rest for a future call. + file.wbuf = append(file.wbuf, data...) + + return n, file.err +} + +func (file *GridFile) insertChunk(data []byte) { + n := file.chunk + file.chunk++ + debugf("GridFile %p: adding to checksum: %q", file, string(data)) + file.wsum.Write(data) + + for file.doc.ChunkSize*file.wpending >= 1024*1024 { + // Hold on.. we got a MB pending. + file.c.Wait() + if file.err != nil { + return + } + } + + file.wpending++ + + debugf("GridFile %p: inserting chunk %d with %d bytes", file, n, len(data)) + + // We may not own the memory of data, so rather than + // simply copying it, we'll marshal the document ahead of time. 
+	data, err := bson.Marshal(gfsChunk{bson.NewObjectId(), file.doc.Id, n, data})
+	if err != nil {
+		file.err = err
+		return
+	}
+
+	go func() {
+		err := file.gfs.Chunks.Insert(bson.Raw{Data: data})
+		file.m.Lock()
+		file.wpending--
+		if err != nil && file.err == nil {
+			file.err = err
+		}
+		file.c.Broadcast()
+		file.m.Unlock()
+	}()
+}
+
+// Seek sets the offset for the next Read or Write on file to
+// offset, interpreted according to whence: 0 means relative to
+// the origin of the file, 1 means relative to the current offset,
+// and 2 means relative to the end. It returns the new offset and
+// an error, if any.
+func (file *GridFile) Seek(offset int64, whence int) (pos int64, err error) {
+	file.m.Lock()
+	debugf("GridFile %p: seeking for %d (whence=%d)", file, offset, whence)
+	defer file.m.Unlock()
+	switch whence {
+	case os.SEEK_SET:
+	case os.SEEK_CUR:
+		offset += file.offset
+	case os.SEEK_END:
+		offset += file.doc.Length
+	default:
+		panic("unsupported whence value")
+	}
+	if offset > file.doc.Length {
+		return file.offset, errors.New("seek past end of file")
+	}
+	if offset == file.doc.Length {
+		// If we're seeking to the end of the file,
+		// no need to read anything. This enables
+		// a client to find the size of the file using only the
+		// io.ReadSeeker interface with low overhead.
+		file.offset = offset
+		return file.offset, nil
+	}
+	chunk := int(offset / int64(file.doc.ChunkSize))
+	if chunk+1 == file.chunk && offset >= file.offset {
+		file.rbuf = file.rbuf[int(offset-file.offset):]
+		file.offset = offset
+		return file.offset, nil
+	}
+	file.offset = offset
+	file.chunk = chunk
+	file.rbuf = nil
+	file.rbuf, err = file.getChunk()
+	if err == nil {
+		file.rbuf = file.rbuf[int(file.offset-int64(chunk)*int64(file.doc.ChunkSize)):]
+	}
+	return file.offset, err
+}
+
+// Read reads into b the next available data from the file and
+// returns the number of bytes written and an error in case
+// something wrong happened.
At the end of the file, n will +// be zero and err will be set to os.EOF. +// +// The parameters and behavior of this function turn the file +// into an io.Reader. +func (file *GridFile) Read(b []byte) (n int, err error) { + file.assertMode(gfsReading) + file.m.Lock() + debugf("GridFile %p: reading at offset %d into buffer of length %d", file, file.offset, len(b)) + defer file.m.Unlock() + if file.offset == file.doc.Length { + return 0, io.EOF + } + for err == nil { + i := copy(b, file.rbuf) + n += i + file.offset += int64(i) + file.rbuf = file.rbuf[i:] + if i == len(b) || file.offset == file.doc.Length { + break + } + b = b[i:] + file.rbuf, err = file.getChunk() + } + return n, err +} + +func (file *GridFile) getChunk() (data []byte, err error) { + cache := file.rcache + file.rcache = nil + if cache != nil && cache.n == file.chunk { + debugf("GridFile %p: Getting chunk %d from cache", file, file.chunk) + cache.wait.Lock() + data, err = cache.data, cache.err + } else { + debugf("GridFile %p: Fetching chunk %d", file, file.chunk) + var doc gfsChunk + err = file.gfs.Chunks.Find(bson.D{{"files_id", file.doc.Id}, {"n", file.chunk}}).One(&doc) + data = doc.Data + } + file.chunk++ + if int64(file.chunk)*int64(file.doc.ChunkSize) < file.doc.Length { + // Read the next one in background. + cache = &gfsCachedChunk{n: file.chunk} + cache.wait.Lock() + debugf("GridFile %p: Scheduling chunk %d for background caching", file, file.chunk) + // Clone the session to avoid having it closed in between. 
+ chunks := file.gfs.Chunks + session := chunks.Database.Session.Clone() + go func(id interface{}, n int) { + defer session.Close() + chunks = chunks.With(session) + var doc gfsChunk + cache.err = chunks.Find(bson.D{{"files_id", id}, {"n", n}}).One(&doc) + cache.data = doc.Data + cache.wait.Unlock() + }(file.doc.Id, file.chunk) + file.rcache = cache + } + debugf("Returning err: %#v", err) + return +} diff --git a/vendor/gopkg.in/mgo.v2/gridfs_test.go b/vendor/gopkg.in/mgo.v2/gridfs_test.go new file mode 100644 index 000000000..5a6ed5559 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/gridfs_test.go @@ -0,0 +1,708 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "io" + "os" + "time" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +func (s *S) TestGridFSCreate(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + before := bson.Now() + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + + n, err := file.Write([]byte("some data")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 9) + + err = file.Close() + c.Assert(err, IsNil) + + after := bson.Now() + + // Check the file information. + result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + fileId, ok := result["_id"].(bson.ObjectId) + c.Assert(ok, Equals, true) + c.Assert(fileId.Valid(), Equals, true) + result["_id"] = "" + + ud, ok := result["uploadDate"].(time.Time) + c.Assert(ok, Equals, true) + c.Assert(ud.After(before) && ud.Before(after), Equals, true) + result["uploadDate"] = "" + + expected := M{ + "_id": "", + "length": 9, + "chunkSize": 255 * 1024, + "uploadDate": "", + "md5": "1e50210a0202497fb79bc38b6ade6c34", + } + c.Assert(result, DeepEquals, expected) + + // Check the chunk. 
+ result = M{} + err = db.C("fs.chunks").Find(nil).One(result) + c.Assert(err, IsNil) + + chunkId, ok := result["_id"].(bson.ObjectId) + c.Assert(ok, Equals, true) + c.Assert(chunkId.Valid(), Equals, true) + result["_id"] = "" + + expected = M{ + "_id": "", + "files_id": fileId, + "n": 0, + "data": []byte("some data"), + } + c.Assert(result, DeepEquals, expected) + + // Check that an index was created. + indexes, err := db.C("fs.chunks").Indexes() + c.Assert(err, IsNil) + c.Assert(len(indexes), Equals, 2) + c.Assert(indexes[1].Key, DeepEquals, []string{"files_id", "n"}) +} + +func (s *S) TestGridFSFileDetails(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile1.txt") + c.Assert(err, IsNil) + + n, err := file.Write([]byte("some")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) + + c.Assert(file.Size(), Equals, int64(4)) + + n, err = file.Write([]byte(" data")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 5) + + c.Assert(file.Size(), Equals, int64(9)) + + id, _ := file.Id().(bson.ObjectId) + c.Assert(id.Valid(), Equals, true) + c.Assert(file.Name(), Equals, "myfile1.txt") + c.Assert(file.ContentType(), Equals, "") + + var info interface{} + err = file.GetMeta(&info) + c.Assert(err, IsNil) + c.Assert(info, IsNil) + + file.SetId("myid") + file.SetName("myfile2.txt") + file.SetContentType("text/plain") + file.SetMeta(M{"any": "thing"}) + + c.Assert(file.Id(), Equals, "myid") + c.Assert(file.Name(), Equals, "myfile2.txt") + c.Assert(file.ContentType(), Equals, "text/plain") + + err = file.GetMeta(&info) + c.Assert(err, IsNil) + c.Assert(info, DeepEquals, bson.M{"any": "thing"}) + + err = file.Close() + c.Assert(err, IsNil) + + c.Assert(file.MD5(), Equals, "1e50210a0202497fb79bc38b6ade6c34") + + ud := file.UploadDate() + now := time.Now() + c.Assert(ud.Before(now), Equals, true) + c.Assert(ud.After(now.Add(-3*time.Second)), Equals, 
true) + + result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + result["uploadDate"] = "" + + expected := M{ + "_id": "myid", + "length": 9, + "chunkSize": 255 * 1024, + "uploadDate": "", + "md5": "1e50210a0202497fb79bc38b6ade6c34", + "filename": "myfile2.txt", + "contentType": "text/plain", + "metadata": M{"any": "thing"}, + } + c.Assert(result, DeepEquals, expected) +} + +func (s *S) TestGridFSSetUploadDate(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + + t := time.Date(2014, 1, 1, 1, 1, 1, 0, time.Local) + file.SetUploadDate(t) + + err = file.Close() + c.Assert(err, IsNil) + + // Check the file information. + result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + ud := result["uploadDate"].(time.Time) + if !ud.Equal(t) { + c.Fatalf("want upload date %s, got %s", t, ud) + } +} + +func (s *S) TestGridFSCreateWithChunking(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("") + c.Assert(err, IsNil) + + file.SetChunkSize(5) + + // Smaller than the chunk size. + n, err := file.Write([]byte("abc")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + + // Boundary in the middle. + n, err = file.Write([]byte("defg")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) + + // Boundary at the end. + n, err = file.Write([]byte("hij")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + + // Larger than the chunk size, with 3 chunks. + n, err = file.Write([]byte("klmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 12) + + err = file.Close() + c.Assert(err, IsNil) + + // Check the file information. 
+ result := M{} + err = db.C("fs.files").Find(nil).One(result) + c.Assert(err, IsNil) + + fileId, _ := result["_id"].(bson.ObjectId) + c.Assert(fileId.Valid(), Equals, true) + result["_id"] = "" + result["uploadDate"] = "" + + expected := M{ + "_id": "", + "length": 22, + "chunkSize": 5, + "uploadDate": "", + "md5": "44a66044834cbe55040089cabfc102d5", + } + c.Assert(result, DeepEquals, expected) + + // Check the chunks. + iter := db.C("fs.chunks").Find(nil).Sort("n").Iter() + dataChunks := []string{"abcde", "fghij", "klmno", "pqrst", "uv"} + for i := 0; ; i++ { + result = M{} + if !iter.Next(result) { + if i != 5 { + c.Fatalf("Expected 5 chunks, got %d", i) + } + break + } + c.Assert(iter.Close(), IsNil) + + result["_id"] = "" + + expected = M{ + "_id": "", + "files_id": fileId, + "n": i, + "data": []byte(dataChunks[i]), + } + c.Assert(result, DeepEquals, expected) + } +} + +func (s *S) TestGridFSAbort(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + + file.SetChunkSize(5) + + n, err := file.Write([]byte("some data")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 9) + + var count int + for i := 0; i < 10; i++ { + count, err = db.C("fs.chunks").Count() + if count > 0 || err != nil { + break + } + } + c.Assert(err, IsNil) + c.Assert(count, Equals, 1) + + file.Abort() + + err = file.Close() + c.Assert(err, ErrorMatches, "write aborted") + + count, err = db.C("fs.chunks").Count() + c.Assert(err, IsNil) + c.Assert(count, Equals, 0) +} + +func (s *S) TestGridFSCloseConflict(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + db.C("fs.files").EnsureIndex(mgo.Index{Key: []string{"filename"}, Unique: true}) + + // For a closing-time conflict + err = db.C("fs.files").Insert(M{"filename": "foo.txt"}) + c.Assert(err, IsNil) + + gfs := 
db.GridFS("fs") + file, err := gfs.Create("foo.txt") + c.Assert(err, IsNil) + + _, err = file.Write([]byte("some data")) + c.Assert(err, IsNil) + + err = file.Close() + c.Assert(mgo.IsDup(err), Equals, true) + + count, err := db.C("fs.chunks").Count() + c.Assert(err, IsNil) + c.Assert(count, Equals, 0) +} + +func (s *S) TestGridFSOpenNotFound(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.OpenId("non-existent") + c.Assert(err == mgo.ErrNotFound, Equals, true) + c.Assert(file, IsNil) + + file, err = gfs.Open("non-existent") + c.Assert(err == mgo.ErrNotFound, Equals, true) + c.Assert(file, IsNil) +} + +func (s *S) TestGridFSReadAll(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + id := file.Id() + + file.SetChunkSize(5) + + n, err := file.Write([]byte("abcdefghijklmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 22) + + err = file.Close() + c.Assert(err, IsNil) + + file, err = gfs.OpenId(id) + c.Assert(err, IsNil) + + b := make([]byte, 30) + n, err = file.Read(b) + c.Assert(n, Equals, 22) + c.Assert(err, IsNil) + + n, err = file.Read(b) + c.Assert(n, Equals, 0) + c.Assert(err == io.EOF, Equals, true) + + err = file.Close() + c.Assert(err, IsNil) +} + +func (s *S) TestGridFSReadChunking(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("") + c.Assert(err, IsNil) + + id := file.Id() + + file.SetChunkSize(5) + + n, err := file.Write([]byte("abcdefghijklmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 22) + + err = file.Close() + c.Assert(err, IsNil) + + file, err = gfs.OpenId(id) + c.Assert(err, IsNil) + + b := make([]byte, 30) + + 
// Smaller than the chunk size. + n, err = file.Read(b[:3]) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + c.Assert(b[:3], DeepEquals, []byte("abc")) + + // Boundary in the middle. + n, err = file.Read(b[:4]) + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) + c.Assert(b[:4], DeepEquals, []byte("defg")) + + // Boundary at the end. + n, err = file.Read(b[:3]) + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) + c.Assert(b[:3], DeepEquals, []byte("hij")) + + // Larger than the chunk size, with 3 chunks. + n, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(n, Equals, 12) + c.Assert(b[:12], DeepEquals, []byte("klmnopqrstuv")) + + n, err = file.Read(b) + c.Assert(n, Equals, 0) + c.Assert(err == io.EOF, Equals, true) + + err = file.Close() + c.Assert(err, IsNil) +} + +func (s *S) TestGridFSOpen(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + file.Close() + + file, err = gfs.Open("myfile.txt") + c.Assert(err, IsNil) + defer file.Close() + + var b [1]byte + + _, err = file.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "2") +} + +func (s *S) TestGridFSSeek(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + file, err := gfs.Create("") + c.Assert(err, IsNil) + id := file.Id() + + file.SetChunkSize(5) + + n, err := file.Write([]byte("abcdefghijklmnopqrstuv")) + c.Assert(err, IsNil) + c.Assert(n, Equals, 22) + + err = file.Close() + c.Assert(err, IsNil) + + b := make([]byte, 5) + + file, err = gfs.OpenId(id) + c.Assert(err, IsNil) + + o, err := file.Seek(3, os.SEEK_SET) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(3)) + _, err = file.Read(b) 
+ c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("defgh")) + + o, err = file.Seek(5, os.SEEK_CUR) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(13)) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("nopqr")) + + o, err = file.Seek(0, os.SEEK_END) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(22)) + n, err = file.Read(b) + c.Assert(err, Equals, io.EOF) + c.Assert(n, Equals, 0) + + o, err = file.Seek(-10, os.SEEK_END) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(12)) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("mnopq")) + + o, err = file.Seek(8, os.SEEK_SET) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(8)) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("ijklm")) + + // Trivial seek forward within same chunk. Already + // got the data, shouldn't touch the database. + sent := mgo.GetStats().SentOps + o, err = file.Seek(1, os.SEEK_CUR) + c.Assert(err, IsNil) + c.Assert(o, Equals, int64(14)) + c.Assert(mgo.GetStats().SentOps, Equals, sent) + _, err = file.Read(b) + c.Assert(err, IsNil) + c.Assert(b, DeepEquals, []byte("opqrs")) + + // Try seeking past end of file. 
+ file.Seek(3, os.SEEK_SET) + o, err = file.Seek(23, os.SEEK_SET) + c.Assert(err, ErrorMatches, "seek past end of file") + c.Assert(o, Equals, int64(3)) +} + +func (s *S) TestGridFSRemoveId(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + id := file.Id() + file.Close() + + err = gfs.RemoveId(id) + c.Assert(err, IsNil) + + file, err = gfs.Open("myfile.txt") + c.Assert(err, IsNil) + defer file.Close() + + var b [1]byte + + _, err = file.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "1") + + n, err := db.C("fs.chunks").Find(M{"files_id": id}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 0) +} + +func (s *S) TestGridFSRemove(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + file.Close() + + err = gfs.Remove("myfile.txt") + c.Assert(err, IsNil) + + _, err = gfs.Open("myfile.txt") + c.Assert(err == mgo.ErrNotFound, Equals, true) + + n, err := db.C("fs.chunks").Find(nil).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 0) +} + +func (s *S) TestGridFSOpenNext(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + + gfs := db.GridFS("fs") + + file, err := gfs.Create("myfile1.txt") + c.Assert(err, IsNil) + file.Write([]byte{'1'}) + file.Close() + + file, err = gfs.Create("myfile2.txt") + c.Assert(err, IsNil) + file.Write([]byte{'2'}) + file.Close() + + var f 
*mgo.GridFile + var b [1]byte + + iter := gfs.Find(nil).Sort("-filename").Iter() + + ok := gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, true) + c.Check(f.Name(), Equals, "myfile2.txt") + + _, err = f.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "2") + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, true) + c.Check(f.Name(), Equals, "myfile1.txt") + + _, err = f.Read(b[:]) + c.Assert(err, IsNil) + c.Assert(string(b[:]), Equals, "1") + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + c.Assert(f, IsNil) + + // Do it again with a more restrictive query to make sure + // it's actually taken into account. + iter = gfs.Find(bson.M{"filename": "myfile1.txt"}).Iter() + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, true) + c.Check(f.Name(), Equals, "myfile1.txt") + + ok = gfs.OpenNext(iter, &f) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + c.Assert(f, IsNil) +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.c b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.c new file mode 100644 index 000000000..8be0bc459 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.c @@ -0,0 +1,77 @@ +// +build !windows + +#include <stdlib.h> +#include <stdio.h> +#include <string.h> +#include <sasl/sasl.h> + +static int mgo_sasl_simple(void *context, int id, const char **result, unsigned int *len) +{ + if (!result) { + return SASL_BADPARAM; + } + switch (id) { + case SASL_CB_USER: + *result = (char *)context; + break; + case SASL_CB_AUTHNAME: + *result = (char *)context; + break; + case SASL_CB_LANGUAGE: + *result = NULL; + break; + default: + return SASL_BADPARAM; + } + if (len) { + *len = *result ?
strlen(*result) : 0; + } + return SASL_OK; +} + +typedef int (*callback)(void); + +static int mgo_sasl_secret(sasl_conn_t *conn, void *context, int id, sasl_secret_t **result) +{ + if (!conn || !result || id != SASL_CB_PASS) { + return SASL_BADPARAM; + } + *result = (sasl_secret_t *)context; + return SASL_OK; +} + +sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password) +{ + sasl_callback_t *cb = malloc(4 * sizeof(sasl_callback_t)); + int n = 0; + + size_t len = strlen(password); + sasl_secret_t *secret = (sasl_secret_t*)malloc(sizeof(sasl_secret_t) + len); + if (!secret) { + free(cb); + return NULL; + } + strcpy((char *)secret->data, password); + secret->len = len; + + cb[n].id = SASL_CB_PASS; + cb[n].proc = (callback)&mgo_sasl_secret; + cb[n].context = secret; + n++; + + cb[n].id = SASL_CB_USER; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_AUTHNAME; + cb[n].proc = (callback)&mgo_sasl_simple; + cb[n].context = (char*)username; + n++; + + cb[n].id = SASL_CB_LIST_END; + cb[n].proc = NULL; + cb[n].context = NULL; + + return cb; +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.go b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.go new file mode 100644 index 000000000..8375dddf8 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl.go @@ -0,0 +1,138 @@ +// Package sasl is an implementation detail of the mgo package. +// +// This package is not meant to be used by itself. 
+// + +// +build !windows + +package sasl + +// #cgo LDFLAGS: -lsasl2 +// +// struct sasl_conn {}; +// +// #include <stdlib.h> +// #include <sasl/sasl.h> +// +// sasl_callback_t *mgo_sasl_callbacks(const char *username, const char *password); +// +import "C" + +import ( + "fmt" + "strings" + "sync" + "unsafe" +) + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +type saslSession struct { + conn *C.sasl_conn_t + step int + mech string + + cstrings []*C.char + callbacks *C.sasl_callback_t +} + +var initError error +var initOnce sync.Once + +func initSASL() { + rc := C.sasl_client_init(nil) + if rc != C.SASL_OK { + initError = saslError(rc, nil, "cannot initialize SASL library") + } +} + +func New(username, password, mechanism, service, host string) (saslStepper, error) { + initOnce.Do(initSASL) + if initError != nil { + return nil, initError + } + + ss := &saslSession{mech: mechanism} + if service == "" { + service = "mongodb" + } + if i := strings.Index(host, ":"); i >= 0 { + host = host[:i] + } + ss.callbacks = C.mgo_sasl_callbacks(ss.cstr(username), ss.cstr(password)) + rc := C.sasl_client_new(ss.cstr(service), ss.cstr(host), nil, nil, ss.callbacks, 0, &ss.conn) + if rc != C.SASL_OK { + ss.Close() + return nil, saslError(rc, nil, "cannot create new SASL client") + } + return ss, nil +} + +func (ss *saslSession) cstr(s string) *C.char { + cstr := C.CString(s) + ss.cstrings = append(ss.cstrings, cstr) + return cstr +} + +func (ss *saslSession) Close() { + for _, cstr := range ss.cstrings { + C.free(unsafe.Pointer(cstr)) + } + ss.cstrings = nil + + if ss.callbacks != nil { + C.free(unsafe.Pointer(ss.callbacks)) + } + + // The documentation of SASL dispose makes it clear that this should only + // be done when the connection is done, not when the authentication phase + // is done, because an encryption layer may have been negotiated.
+ // Even then, we'll do this for now, because it's simpler and prevents + // keeping track of this state for every socket. If it breaks, we'll fix it. + C.sasl_dispose(&ss.conn) +} + +func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { + ss.step++ + if ss.step > 10 { + return nil, false, fmt.Errorf("too many SASL steps without authentication") + } + var cclientData *C.char + var cclientDataLen C.uint + var rc C.int + if ss.step == 1 { + var mechanism *C.char // ignored - must match cred + rc = C.sasl_client_start(ss.conn, ss.cstr(ss.mech), nil, &cclientData, &cclientDataLen, &mechanism) + } else { + var cserverData *C.char + var cserverDataLen C.uint + if len(serverData) > 0 { + cserverData = (*C.char)(unsafe.Pointer(&serverData[0])) + cserverDataLen = C.uint(len(serverData)) + } + rc = C.sasl_client_step(ss.conn, cserverData, cserverDataLen, nil, &cclientData, &cclientDataLen) + } + if cclientData != nil && cclientDataLen > 0 { + clientData = C.GoBytes(unsafe.Pointer(cclientData), C.int(cclientDataLen)) + } + if rc == C.SASL_OK { + return clientData, true, nil + } + if rc == C.SASL_CONTINUE { + return clientData, false, nil + } + return nil, false, saslError(rc, ss.conn, "cannot establish SASL session") +} + +func saslError(rc C.int, conn *C.sasl_conn_t, msg string) error { + var detail string + if conn == nil { + detail = C.GoString(C.sasl_errstring(rc, nil, nil)) + } else { + detail = C.GoString(C.sasl_errdetail(conn)) + } + return fmt.Errorf(msg + ": " + detail) +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.c b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.c new file mode 100644 index 000000000..dd6a88ab6 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.c @@ -0,0 +1,118 @@ +#include "sasl_windows.h" + +static const LPSTR SSPI_PACKAGE_NAME = "kerberos"; + +SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle *cred_handle, char *username, char *password, char 
*domain) +{ + SEC_WINNT_AUTH_IDENTITY auth_identity; + SECURITY_INTEGER ignored; + + auth_identity.Flags = SEC_WINNT_AUTH_IDENTITY_ANSI; + auth_identity.User = (LPSTR) username; + auth_identity.UserLength = strlen(username); + auth_identity.Password = (LPSTR) password; + auth_identity.PasswordLength = strlen(password); + auth_identity.Domain = (LPSTR) domain; + auth_identity.DomainLength = strlen(domain); + return call_sspi_acquire_credentials_handle(NULL, SSPI_PACKAGE_NAME, SECPKG_CRED_OUTBOUND, NULL, &auth_identity, NULL, NULL, cred_handle, &ignored); +} + +int sspi_step(CredHandle *cred_handle, int has_context, CtxtHandle *context, PVOID *buffer, ULONG *buffer_length, char *target) +{ + SecBufferDesc inbuf; + SecBuffer in_bufs[1]; + SecBufferDesc outbuf; + SecBuffer out_bufs[1]; + + if (has_context > 0) { + // If we already have a context, we now have data to send. + // Put this data in an inbuf. + inbuf.ulVersion = SECBUFFER_VERSION; + inbuf.cBuffers = 1; + inbuf.pBuffers = in_bufs; + in_bufs[0].pvBuffer = *buffer; + in_bufs[0].cbBuffer = *buffer_length; + in_bufs[0].BufferType = SECBUFFER_TOKEN; + } + + outbuf.ulVersion = SECBUFFER_VERSION; + outbuf.cBuffers = 1; + outbuf.pBuffers = out_bufs; + out_bufs[0].pvBuffer = NULL; + out_bufs[0].cbBuffer = 0; + out_bufs[0].BufferType = SECBUFFER_TOKEN; + + ULONG context_attr = 0; + + int ret = call_sspi_initialize_security_context(cred_handle, + has_context > 0 ? context : NULL, + (LPSTR) target, + ISC_REQ_ALLOCATE_MEMORY | ISC_REQ_MUTUAL_AUTH, + 0, + SECURITY_NETWORK_DREP, + has_context > 0 ? 
&inbuf : NULL, + 0, + context, + &outbuf, + &context_attr, + NULL); + + *buffer = malloc(out_bufs[0].cbBuffer); + *buffer_length = out_bufs[0].cbBuffer; + memcpy(*buffer, out_bufs[0].pvBuffer, *buffer_length); + + return ret; +} + +int sspi_send_client_authz_id(CtxtHandle *context, PVOID *buffer, ULONG *buffer_length, char *user_plus_realm) +{ + SecPkgContext_Sizes sizes; + SECURITY_STATUS status = call_sspi_query_context_attributes(context, SECPKG_ATTR_SIZES, &sizes); + + if (status != SEC_E_OK) { + return status; + } + + size_t user_plus_realm_length = strlen(user_plus_realm); + int msgSize = 4 + user_plus_realm_length; + char *msg = malloc((sizes.cbSecurityTrailer + msgSize + sizes.cbBlockSize) * sizeof(char)); + msg[sizes.cbSecurityTrailer + 0] = 1; + msg[sizes.cbSecurityTrailer + 1] = 0; + msg[sizes.cbSecurityTrailer + 2] = 0; + msg[sizes.cbSecurityTrailer + 3] = 0; + memcpy(&msg[sizes.cbSecurityTrailer + 4], user_plus_realm, user_plus_realm_length); + + SecBuffer wrapBufs[3]; + SecBufferDesc wrapBufDesc; + wrapBufDesc.cBuffers = 3; + wrapBufDesc.pBuffers = wrapBufs; + wrapBufDesc.ulVersion = SECBUFFER_VERSION; + + wrapBufs[0].cbBuffer = sizes.cbSecurityTrailer; + wrapBufs[0].BufferType = SECBUFFER_TOKEN; + wrapBufs[0].pvBuffer = msg; + + wrapBufs[1].cbBuffer = msgSize; + wrapBufs[1].BufferType = SECBUFFER_DATA; + wrapBufs[1].pvBuffer = msg + sizes.cbSecurityTrailer; + + wrapBufs[2].cbBuffer = sizes.cbBlockSize; + wrapBufs[2].BufferType = SECBUFFER_PADDING; + wrapBufs[2].pvBuffer = msg + sizes.cbSecurityTrailer + msgSize; + + status = call_sspi_encrypt_message(context, SECQOP_WRAP_NO_ENCRYPT, &wrapBufDesc, 0); + if (status != SEC_E_OK) { + free(msg); + return status; + } + + *buffer_length = wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer + wrapBufs[2].cbBuffer; + *buffer = malloc(*buffer_length); + + memcpy(*buffer, wrapBufs[0].pvBuffer, wrapBufs[0].cbBuffer); + memcpy(*buffer + wrapBufs[0].cbBuffer, wrapBufs[1].pvBuffer, wrapBufs[1].cbBuffer); + memcpy(*buffer + 
wrapBufs[0].cbBuffer + wrapBufs[1].cbBuffer, wrapBufs[2].pvBuffer, wrapBufs[2].cbBuffer); + + free(msg); + return SEC_E_OK; +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go new file mode 100644 index 000000000..3302cfe05 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.go @@ -0,0 +1,140 @@ +package sasl + +// #include "sasl_windows.h" +import "C" + +import ( + "fmt" + "strings" + "sync" + "unsafe" +) + +type saslStepper interface { + Step(serverData []byte) (clientData []byte, done bool, err error) + Close() +} + +type saslSession struct { + // Credentials + mech string + service string + host string + userPlusRealm string + target string + domain string + + // Internal state + authComplete bool + errored bool + step int + + // C internal state + credHandle C.CredHandle + context C.CtxtHandle + hasContext C.int + + // Keep track of pointers we need to explicitly free + stringsToFree []*C.char +} + +var initError error +var initOnce sync.Once + +func initSSPI() { + rc := C.load_secur32_dll() + if rc != 0 { + initError = fmt.Errorf("Error loading libraries: %v", rc) + } +} + +func New(username, password, mechanism, service, host string) (saslStepper, error) { + initOnce.Do(initSSPI) + ss := &saslSession{mech: mechanism, hasContext: 0, userPlusRealm: username} + if service == "" { + service = "mongodb" + } + if i := strings.Index(host, ":"); i >= 0 { + host = host[:i] + } + ss.service = service + ss.host = host + + usernameComponents := strings.Split(username, "@") + if len(usernameComponents) < 2 { + return nil, fmt.Errorf("Username '%v' doesn't contain a realm!", username) + } + user := usernameComponents[0] + ss.domain = usernameComponents[1] + ss.target = fmt.Sprintf("%s/%s", ss.service, ss.host) + + var status C.SECURITY_STATUS + // Step 0: call AcquireCredentialsHandle to get a nice SSPI CredHandle + if len(password) > 0 { + status = 
C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), ss.cstr(password), ss.cstr(ss.domain)) + } else { + status = C.sspi_acquire_credentials_handle(&ss.credHandle, ss.cstr(user), nil, ss.cstr(ss.domain)) + } + if status != C.SEC_E_OK { + ss.errored = true + return nil, fmt.Errorf("Couldn't create new SSPI client, error code %v", status) + } + return ss, nil +} + +func (ss *saslSession) cstr(s string) *C.char { + cstr := C.CString(s) + ss.stringsToFree = append(ss.stringsToFree, cstr) + return cstr +} + +func (ss *saslSession) Close() { + for _, cstr := range ss.stringsToFree { + C.free(unsafe.Pointer(cstr)) + } +} + +func (ss *saslSession) Step(serverData []byte) (clientData []byte, done bool, err error) { + ss.step++ + if ss.step > 10 { + return nil, false, fmt.Errorf("too many SSPI steps without authentication") + } + var buffer C.PVOID + var bufferLength C.ULONG + if len(serverData) > 0 { + buffer = (C.PVOID)(unsafe.Pointer(&serverData[0])) + bufferLength = C.ULONG(len(serverData)) + } + var status C.int + if ss.authComplete { + // Step 3: last bit of magic to use the correct server credentials + status = C.sspi_send_client_authz_id(&ss.context, &buffer, &bufferLength, ss.cstr(ss.userPlusRealm)) + } else { + // Step 1 + Step 2: set up security context with the server and TGT + status = C.sspi_step(&ss.credHandle, ss.hasContext, &ss.context, &buffer, &bufferLength, ss.cstr(ss.target)) + } + if buffer != C.PVOID(nil) { + defer C.free(unsafe.Pointer(buffer)) + } + if status != C.SEC_E_OK && status != C.SEC_I_CONTINUE_NEEDED { + ss.errored = true + return nil, false, ss.handleSSPIErrorCode(status) + } + + clientData = C.GoBytes(unsafe.Pointer(buffer), C.int(bufferLength)) + if status == C.SEC_E_OK { + ss.authComplete = true + return clientData, true, nil + } else { + ss.hasContext = 1 + return clientData, false, nil + } +} + +func (ss *saslSession) handleSSPIErrorCode(code C.int) error { + switch { + case code == C.SEC_E_TARGET_UNKNOWN: + return 
fmt.Errorf("Target %v@%v not found", ss.target, ss.domain) + } + return fmt.Errorf("Unknown error doing step %v, error code %v", ss.step, code) +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.h b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.h new file mode 100644 index 000000000..94321b208 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sasl_windows.h @@ -0,0 +1,7 @@ +#include + +#include "sspi_windows.h" + +SECURITY_STATUS SEC_ENTRY sspi_acquire_credentials_handle(CredHandle* cred_handle, char* username, char* password, char* domain); +int sspi_step(CredHandle* cred_handle, int has_context, CtxtHandle* context, PVOID* buffer, ULONG* buffer_length, char* target); +int sspi_send_client_authz_id(CtxtHandle* context, PVOID* buffer, ULONG* buffer_length, char* user_plus_realm); diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.c b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.c new file mode 100644 index 000000000..63f9a6f86 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.c @@ -0,0 +1,96 @@ +// Code adapted from the NodeJS kerberos library: +// +// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.c +// +// Under the terms of the Apache License, Version 2.0: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +#include + +#include "sspi_windows.h" + +static HINSTANCE sspi_secur32_dll = NULL; + +int load_secur32_dll() +{ + sspi_secur32_dll = LoadLibrary("secur32.dll"); + if (sspi_secur32_dll == NULL) { + return GetLastError(); + } + return 0; +} + +SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + encryptMessage_fn pfn_encryptMessage = (encryptMessage_fn) GetProcAddress(sspi_secur32_dll, "EncryptMessage"); + if (!pfn_encryptMessage) { + return -2; + } + return (*pfn_encryptMessage)(phContext, fQOP, pMessage, MessageSeqNo); 
+} + +SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle( + LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse, + void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument, + PCredHandle phCredential, PTimeStamp ptsExpiry) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + acquireCredentialsHandle_fn pfn_acquireCredentialsHandle; +#ifdef _UNICODE + pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleW"); +#else + pfn_acquireCredentialsHandle = (acquireCredentialsHandle_fn) GetProcAddress(sspi_secur32_dll, "AcquireCredentialsHandleA"); +#endif + if (!pfn_acquireCredentialsHandle) { + return -2; + } + return (*pfn_acquireCredentialsHandle)( + pszPrincipal, pszPackage, fCredentialUse, pvLogonId, pAuthData, + pGetKeyFn, pvGetKeyArgument, phCredential, ptsExpiry); +} + +SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context( + PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName, + unsigned long fContextReq, unsigned long Reserved1, unsigned long TargetDataRep, + PSecBufferDesc pInput, unsigned long Reserved2, PCtxtHandle phNewContext, + PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + initializeSecurityContext_fn pfn_initializeSecurityContext; +#ifdef _UNICODE + pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextW"); +#else + pfn_initializeSecurityContext = (initializeSecurityContext_fn) GetProcAddress(sspi_secur32_dll, "InitializeSecurityContextA"); +#endif + if (!pfn_initializeSecurityContext) { + return -2; + } + return (*pfn_initializeSecurityContext)( + phCredential, phContext, pszTargetName, fContextReq, Reserved1, TargetDataRep, + pInput, Reserved2, phNewContext, pOutput, pfContextAttr, ptsExpiry); +} + +SECURITY_STATUS SEC_ENTRY 
call_sspi_query_context_attributes(PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer) +{ + if (sspi_secur32_dll == NULL) { + return -1; + } + queryContextAttributes_fn pfn_queryContextAttributes; +#ifdef _UNICODE + pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesW"); +#else + pfn_queryContextAttributes = (queryContextAttributes_fn) GetProcAddress(sspi_secur32_dll, "QueryContextAttributesA"); +#endif + if (!pfn_queryContextAttributes) { + return -2; + } + return (*pfn_queryContextAttributes)(phContext, ulAttribute, pBuffer); +} diff --git a/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.h b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.h new file mode 100644 index 000000000..d28327031 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/sasl/sspi_windows.h @@ -0,0 +1,70 @@ +// Code adapted from the NodeJS kerberos library: +// +// https://github.com/christkv/kerberos/tree/master/lib/win32/kerberos_sspi.h +// +// Under the terms of the Apache License, Version 2.0: +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +#ifndef SSPI_WINDOWS_H +#define SSPI_WINDOWS_H + +#define SECURITY_WIN32 1 + +#include +#include + +int load_secur32_dll(); + +SECURITY_STATUS SEC_ENTRY call_sspi_encrypt_message(PCtxtHandle phContext, unsigned long fQOP, PSecBufferDesc pMessage, unsigned long MessageSeqNo); + +typedef DWORD (WINAPI *encryptMessage_fn)(PCtxtHandle phContext, ULONG fQOP, PSecBufferDesc pMessage, ULONG MessageSeqNo); + +SECURITY_STATUS SEC_ENTRY call_sspi_acquire_credentials_handle( + LPSTR pszPrincipal, // Name of principal + LPSTR pszPackage, // Name of package + unsigned long fCredentialUse, // Flags indicating use + void *pvLogonId, // Pointer to logon ID + void *pAuthData, // Package specific data + SEC_GET_KEY_FN pGetKeyFn, // Pointer to GetKey() func + void *pvGetKeyArgument, // Value to pass to GetKey() + PCredHandle phCredential, // (out) Cred Handle + PTimeStamp ptsExpiry // 
(out) Lifetime (optional) +); + +typedef DWORD (WINAPI *acquireCredentialsHandle_fn)( + LPSTR pszPrincipal, LPSTR pszPackage, unsigned long fCredentialUse, + void *pvLogonId, void *pAuthData, SEC_GET_KEY_FN pGetKeyFn, void *pvGetKeyArgument, + PCredHandle phCredential, PTimeStamp ptsExpiry +); + +SECURITY_STATUS SEC_ENTRY call_sspi_initialize_security_context( + PCredHandle phCredential, // Cred to base context + PCtxtHandle phContext, // Existing context (OPT) + LPSTR pszTargetName, // Name of target + unsigned long fContextReq, // Context Requirements + unsigned long Reserved1, // Reserved, MBZ + unsigned long TargetDataRep, // Data rep of target + PSecBufferDesc pInput, // Input Buffers + unsigned long Reserved2, // Reserved, MBZ + PCtxtHandle phNewContext, // (out) New Context handle + PSecBufferDesc pOutput, // (inout) Output Buffers + unsigned long *pfContextAttr, // (out) Context attrs + PTimeStamp ptsExpiry // (out) Life span (OPT) +); + +typedef DWORD (WINAPI *initializeSecurityContext_fn)( + PCredHandle phCredential, PCtxtHandle phContext, LPSTR pszTargetName, unsigned long fContextReq, + unsigned long Reserved1, unsigned long TargetDataRep, PSecBufferDesc pInput, unsigned long Reserved2, + PCtxtHandle phNewContext, PSecBufferDesc pOutput, unsigned long *pfContextAttr, PTimeStamp ptsExpiry); + +SECURITY_STATUS SEC_ENTRY call_sspi_query_context_attributes( + PCtxtHandle phContext, // Context to query + unsigned long ulAttribute, // Attribute to query + void *pBuffer // Buffer for attributes +); + +typedef DWORD (WINAPI *queryContextAttributes_fn)( + PCtxtHandle phContext, unsigned long ulAttribute, void *pBuffer); + +#endif // SSPI_WINDOWS_H diff --git a/vendor/gopkg.in/mgo.v2/internal/scram/scram.go b/vendor/gopkg.in/mgo.v2/internal/scram/scram.go new file mode 100644 index 000000000..80cda9135 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/scram/scram.go @@ -0,0 +1,266 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2014 - Gustavo Niemeyer 
+// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +// Pacakage scram implements a SCRAM-{SHA-1,etc} client per RFC5802. +// +// http://tools.ietf.org/html/rfc5802 +// +package scram + +import ( + "bytes" + "crypto/hmac" + "crypto/rand" + "encoding/base64" + "fmt" + "hash" + "strconv" + "strings" +) + +// Client implements a SCRAM-* client (SCRAM-SHA-1, SCRAM-SHA-256, etc). 
+// +// A Client may be used within a SASL conversation with logic resembling: +// +// var in []byte +// var client = scram.NewClient(sha1.New, user, pass) +// for client.Step(in) { +// out := client.Out() +// // send out to server +// in := serverOut +// } +// if client.Err() != nil { +// // auth failed +// } +// +type Client struct { + newHash func() hash.Hash + + user string + pass string + step int + out bytes.Buffer + err error + + clientNonce []byte + serverNonce []byte + saltedPass []byte + authMsg bytes.Buffer +} + +// NewClient returns a new SCRAM-* client with the provided hash algorithm. +// +// For SCRAM-SHA-1, for example, use: +// +// client := scram.NewClient(sha1.New, user, pass) +// +func NewClient(newHash func() hash.Hash, user, pass string) *Client { + c := &Client{ + newHash: newHash, + user: user, + pass: pass, + } + c.out.Grow(256) + c.authMsg.Grow(256) + return c +} + +// Out returns the data to be sent to the server in the current step. +func (c *Client) Out() []byte { + if c.out.Len() == 0 { + return nil + } + return c.out.Bytes() +} + +// Err returns the error that ocurred, or nil if there were no errors. +func (c *Client) Err() error { + return c.err +} + +// SetNonce sets the client nonce to the provided value. +// If not set, the nonce is generated automatically out of crypto/rand on the first step. +func (c *Client) SetNonce(nonce []byte) { + c.clientNonce = nonce +} + +var escaper = strings.NewReplacer("=", "=3D", ",", "=2C") + +// Step processes the incoming data from the server and makes the +// next round of data for the server available via Client.Out. +// Step returns false if there are no errors and more data is +// still expected. 
+func (c *Client) Step(in []byte) bool { + c.out.Reset() + if c.step > 2 || c.err != nil { + return false + } + c.step++ + switch c.step { + case 1: + c.err = c.step1(in) + case 2: + c.err = c.step2(in) + case 3: + c.err = c.step3(in) + } + return c.step > 2 || c.err != nil +} + +func (c *Client) step1(in []byte) error { + if len(c.clientNonce) == 0 { + const nonceLen = 6 + buf := make([]byte, nonceLen + b64.EncodedLen(nonceLen)) + if _, err := rand.Read(buf[:nonceLen]); err != nil { + return fmt.Errorf("cannot read random SCRAM-SHA-1 nonce from operating system: %v", err) + } + c.clientNonce = buf[nonceLen:] + b64.Encode(c.clientNonce, buf[:nonceLen]) + } + c.authMsg.WriteString("n=") + escaper.WriteString(&c.authMsg, c.user) + c.authMsg.WriteString(",r=") + c.authMsg.Write(c.clientNonce) + + c.out.WriteString("n,,") + c.out.Write(c.authMsg.Bytes()) + return nil +} + +var b64 = base64.StdEncoding + +func (c *Client) step2(in []byte) error { + c.authMsg.WriteByte(',') + c.authMsg.Write(in) + + fields := bytes.Split(in, []byte(",")) + if len(fields) != 3 { + return fmt.Errorf("expected 3 fields in first SCRAM-SHA-1 server message, got %d: %q", len(fields), in) + } + if !bytes.HasPrefix(fields[0], []byte("r=")) || len(fields[0]) < 2 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 nonce: %q", fields[0]) + } + if !bytes.HasPrefix(fields[1], []byte("s=")) || len(fields[1]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 salt: %q", fields[1]) + } + if !bytes.HasPrefix(fields[2], []byte("i=")) || len(fields[2]) < 6 { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2]) + } + + c.serverNonce = fields[0][2:] + if !bytes.HasPrefix(c.serverNonce, c.clientNonce) { + return fmt.Errorf("server SCRAM-SHA-1 nonce is not prefixed by client nonce: got %q, want %q+\"...\"", c.serverNonce, c.clientNonce) + } + + salt := make([]byte, b64.DecodedLen(len(fields[1][2:]))) + n, err := b64.Decode(salt, fields[1][2:]) + if err != 
nil { + return fmt.Errorf("cannot decode SCRAM-SHA-1 salt sent by server: %q", fields[1]) + } + salt = salt[:n] + iterCount, err := strconv.Atoi(string(fields[2][2:])) + if err != nil { + return fmt.Errorf("server sent an invalid SCRAM-SHA-1 iteration count: %q", fields[2]) + } + c.saltPassword(salt, iterCount) + + c.authMsg.WriteString(",c=biws,r=") + c.authMsg.Write(c.serverNonce) + + c.out.WriteString("c=biws,r=") + c.out.Write(c.serverNonce) + c.out.WriteString(",p=") + c.out.Write(c.clientProof()) + return nil +} + +func (c *Client) step3(in []byte) error { + var isv, ise bool + var fields = bytes.Split(in, []byte(",")) + if len(fields) == 1 { + isv = bytes.HasPrefix(fields[0], []byte("v=")) + ise = bytes.HasPrefix(fields[0], []byte("e=")) + } + if ise { + return fmt.Errorf("SCRAM-SHA-1 authentication error: %s", fields[0][2:]) + } else if !isv { + return fmt.Errorf("unsupported SCRAM-SHA-1 final message from server: %q", in) + } + if !bytes.Equal(c.serverSignature(), fields[0][2:]) { + return fmt.Errorf("cannot authenticate SCRAM-SHA-1 server signature: %q", fields[0][2:]) + } + return nil +} + +func (c *Client) saltPassword(salt []byte, iterCount int) { + mac := hmac.New(c.newHash, []byte(c.pass)) + mac.Write(salt) + mac.Write([]byte{0, 0, 0, 1}) + ui := mac.Sum(nil) + hi := make([]byte, len(ui)) + copy(hi, ui) + for i := 1; i < iterCount; i++ { + mac.Reset() + mac.Write(ui) + mac.Sum(ui[:0]) + for j, b := range ui { + hi[j] ^= b + } + } + c.saltedPass = hi +} + +func (c *Client) clientProof() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Client Key")) + clientKey := mac.Sum(nil) + hash := c.newHash() + hash.Write(clientKey) + storedKey := hash.Sum(nil) + mac = hmac.New(c.newHash, storedKey) + mac.Write(c.authMsg.Bytes()) + clientProof := mac.Sum(nil) + for i, b := range clientKey { + clientProof[i] ^= b + } + clientProof64 := make([]byte, b64.EncodedLen(len(clientProof))) + b64.Encode(clientProof64, clientProof) + return 
clientProof64 +} + +func (c *Client) serverSignature() []byte { + mac := hmac.New(c.newHash, c.saltedPass) + mac.Write([]byte("Server Key")) + serverKey := mac.Sum(nil) + + mac = hmac.New(c.newHash, serverKey) + mac.Write(c.authMsg.Bytes()) + serverSignature := mac.Sum(nil) + + encoded := make([]byte, b64.EncodedLen(len(serverSignature))) + b64.Encode(encoded, serverSignature) + return encoded +} diff --git a/vendor/gopkg.in/mgo.v2/internal/scram/scram_test.go b/vendor/gopkg.in/mgo.v2/internal/scram/scram_test.go new file mode 100644 index 000000000..9c20fdfc4 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/internal/scram/scram_test.go @@ -0,0 +1,67 @@ +package scram_test + +import ( + "crypto/sha1" + "testing" + + . "gopkg.in/check.v1" + "gopkg.in/mgo.v2/internal/scram" + "strings" +) + +var _ = Suite(&S{}) + +func Test(t *testing.T) { TestingT(t) } + +type S struct{} + +var tests = [][]string{{ + "U: user pencil", + "N: fyko+d2lbbFgONRv9qkxdawL", + "C: n,,n=user,r=fyko+d2lbbFgONRv9qkxdawL", + "S: r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,s=QSXCR+Q6sek8bf92,i=4096", + "C: c=biws,r=fyko+d2lbbFgONRv9qkxdawL3rfcNHYJY1ZVvWVs7j,p=v0X8v3Bz2T0CJGbJQyF0X+HI4Ts=", + "S: v=rmF9pqV8S7suAoZWja4dJRkFsKQ=", +}, { + "U: root fe8c89e308ec08763df36333cbf5d3a2", + "N: OTcxNDk5NjM2MzE5", + "C: n,,n=root,r=OTcxNDk5NjM2MzE5", + "S: r=OTcxNDk5NjM2MzE581Ra3provgG0iDsMkDiIAlrh4532dDLp,s=XRDkVrFC9JuL7/F4tG0acQ==,i=10000", + "C: c=biws,r=OTcxNDk5NjM2MzE581Ra3provgG0iDsMkDiIAlrh4532dDLp,p=6y1jp9R7ETyouTXS9fW9k5UHdBc=", + "S: v=LBnd9dUJRxdqZiEq91NKP3z/bHA=", +}} + +func (s *S) TestExamples(c *C) { + for _, steps := range tests { + if len(steps) < 2 || len(steps[0]) < 3 || !strings.HasPrefix(steps[0], "U: ") { + c.Fatalf("Invalid test: %#v", steps) + } + auth := strings.Fields(steps[0][3:]) + client := scram.NewClient(sha1.New, auth[0], auth[1]) + first, done := true, false + c.Logf("-----") + c.Logf("%s", steps[0]) + for _, step := range steps[1:] { + c.Logf("%s", step) + switch step[:3] { + 
case "N: ": + client.SetNonce([]byte(step[3:])) + case "C: ": + if first { + first = false + done = client.Step(nil) + } + c.Assert(done, Equals, false) + c.Assert(client.Err(), IsNil) + c.Assert(string(client.Out()), Equals, step[3:]) + case "S: ": + first = false + done = client.Step([]byte(step[3:])) + default: + panic("invalid test line: " + step) + } + } + c.Assert(done, Equals, true) + c.Assert(client.Err(), IsNil) + } +} diff --git a/vendor/gopkg.in/mgo.v2/log.go b/vendor/gopkg.in/mgo.v2/log.go new file mode 100644 index 000000000..53eb4237b --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/log.go @@ -0,0 +1,133 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "fmt" + "sync" +) + +// --------------------------------------------------------------------------- +// Logging integration. + +// Avoid importing the log type information unnecessarily. There's a small cost +// associated with using an interface rather than the type. Depending on how +// often the logger is plugged in, it would be worth using the type instead. +type log_Logger interface { + Output(calldepth int, s string) error +} + +var ( + globalLogger log_Logger + globalDebug bool + globalMutex sync.Mutex +) + +// RACE WARNING: There are known data races when logging, which are manually +// silenced when the race detector is in use. These data races won't be +// observed in typical use, because logging is supposed to be set up once when +// the application starts. Having raceDetector as a constant, the compiler +// should elide the locks altogether in actual use. + +// Specify the *log.Logger object where log messages should be sent to. +func SetLogger(logger log_Logger) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + globalLogger = logger +} + +// Enable the delivery of debug messages to the logger. Only meaningful +// if a logger is also set. 
+func SetDebug(debug bool) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + globalDebug = debug +} + +func log(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprint(v...)) + } +} + +func logln(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprintln(v...)) + } +} + +func logf(format string, v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalLogger != nil { + globalLogger.Output(2, fmt.Sprintf(format, v...)) + } +} + +func debug(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprint(v...)) + } +} + +func debugln(v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprintln(v...)) + } +} + +func debugf(format string, v ...interface{}) { + if raceDetector { + globalMutex.Lock() + defer globalMutex.Unlock() + } + if globalDebug && globalLogger != nil { + globalLogger.Output(2, fmt.Sprintf(format, v...)) + } +} diff --git a/vendor/gopkg.in/mgo.v2/queue.go b/vendor/gopkg.in/mgo.v2/queue.go new file mode 100644 index 000000000..e9245de70 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/queue.go @@ -0,0 +1,91 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. 
Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +type queue struct { + elems []interface{} + nelems, popi, pushi int +} + +func (q *queue) Len() int { + return q.nelems +} + +func (q *queue) Push(elem interface{}) { + //debugf("Pushing(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) + if q.nelems == len(q.elems) { + q.expand() + } + q.elems[q.pushi] = elem + q.nelems++ + q.pushi = (q.pushi + 1) % len(q.elems) + //debugf(" Pushed(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) +} + +func (q *queue) Pop() (elem interface{}) { + //debugf("Popping(pushi=%d popi=%d cap=%d)\n", + // q.pushi, q.popi, len(q.elems)) + if q.nelems == 0 { + return nil + } + elem = q.elems[q.popi] + q.elems[q.popi] = nil // Help GC. 
+ q.nelems-- + q.popi = (q.popi + 1) % len(q.elems) + //debugf(" Popped(pushi=%d popi=%d cap=%d): %#v\n", + // q.pushi, q.popi, len(q.elems), elem) + return elem +} + +func (q *queue) expand() { + curcap := len(q.elems) + var newcap int + if curcap == 0 { + newcap = 8 + } else if curcap < 1024 { + newcap = curcap * 2 + } else { + newcap = curcap + (curcap / 4) + } + elems := make([]interface{}, newcap) + + if q.popi == 0 { + copy(elems, q.elems) + q.pushi = curcap + } else { + newpopi := newcap - (curcap - q.popi) + copy(elems, q.elems[:q.popi]) + copy(elems[newpopi:], q.elems[q.popi:]) + q.popi = newpopi + } + for i := range q.elems { + q.elems[i] = nil // Help GC. + } + q.elems = elems +} diff --git a/vendor/gopkg.in/mgo.v2/queue_test.go b/vendor/gopkg.in/mgo.v2/queue_test.go new file mode 100644 index 000000000..bd0ab550f --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/queue_test.go @@ -0,0 +1,101 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + . "gopkg.in/check.v1" +) + +type QS struct{} + +var _ = Suite(&QS{}) + +func (s *QS) TestSequentialGrowth(c *C) { + q := queue{} + n := 2048 + for i := 0; i != n; i++ { + q.Push(i) + } + for i := 0; i != n; i++ { + c.Assert(q.Pop(), Equals, i) + } +} + +var queueTestLists = [][]int{ + // {0, 1, 2, 3, 4, 5, 6, 7, 8, 9} + {0, 1, 2, 3, 4, 5, 6, 7, 8, 9}, + + // {8, 9, 10, 11, ... 2, 3, 4, 5, 6, 7} + {0, 1, 2, 3, 4, 5, 6, 7, -1, -1, 8, 9, 10, 11}, + + // {8, 9, 10, 11, ... 
2, 3, 4, 5, 6, 7} + {0, 1, 2, 3, -1, -1, 4, 5, 6, 7, 8, 9, 10, 11}, + + // {0, 1, 2, 3, 4, 5, 6, 7, 8} + {0, 1, 2, 3, 4, 5, 6, 7, 8, + -1, -1, -1, -1, -1, -1, -1, -1, -1, + 0, 1, 2, 3, 4, 5, 6, 7, 8}, +} + +func (s *QS) TestQueueTestLists(c *C) { + test := []int{} + testi := 0 + reset := func() { + test = test[0:0] + testi = 0 + } + push := func(i int) { + test = append(test, i) + } + pop := func() (i int) { + if testi == len(test) { + return -1 + } + i = test[testi] + testi++ + return + } + + for _, list := range queueTestLists { + reset() + q := queue{} + for _, n := range list { + if n == -1 { + c.Assert(q.Pop(), Equals, pop(), Commentf("With list %#v", list)) + } else { + q.Push(n) + push(n) + } + } + + for n := pop(); n != -1; n = pop() { + c.Assert(q.Pop(), Equals, n, Commentf("With list %#v", list)) + } + + c.Assert(q.Pop(), Equals, nil, Commentf("With list %#v", list)) + } +} diff --git a/vendor/gopkg.in/mgo.v2/raceoff.go b/vendor/gopkg.in/mgo.v2/raceoff.go new file mode 100644 index 000000000..e60b14144 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/raceoff.go @@ -0,0 +1,5 @@ +// +build !race + +package mgo + +const raceDetector = false diff --git a/vendor/gopkg.in/mgo.v2/raceon.go b/vendor/gopkg.in/mgo.v2/raceon.go new file mode 100644 index 000000000..737b08ece --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/raceon.go @@ -0,0 +1,5 @@ +// +build race + +package mgo + +const raceDetector = true diff --git a/vendor/gopkg.in/mgo.v2/saslimpl.go b/vendor/gopkg.in/mgo.v2/saslimpl.go new file mode 100644 index 000000000..0d25f25cb --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/saslimpl.go @@ -0,0 +1,11 @@ +//+build sasl + +package mgo + +import ( + "gopkg.in/mgo.v2/internal/sasl" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return sasl.New(cred.Username, cred.Password, cred.Mechanism, cred.Service, host) +} diff --git a/vendor/gopkg.in/mgo.v2/saslstub.go b/vendor/gopkg.in/mgo.v2/saslstub.go new file mode 100644 index 000000000..6e9e30986 --- 
/dev/null +++ b/vendor/gopkg.in/mgo.v2/saslstub.go @@ -0,0 +1,11 @@ +//+build !sasl + +package mgo + +import ( + "fmt" +) + +func saslNew(cred Credential, host string) (saslStepper, error) { + return nil, fmt.Errorf("SASL support not enabled during build (-tags sasl)") +} diff --git a/vendor/gopkg.in/mgo.v2/server.go b/vendor/gopkg.in/mgo.v2/server.go new file mode 100644 index 000000000..d5086a290 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/server.go @@ -0,0 +1,448 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package mgo + +import ( + "errors" + "net" + "sort" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +// --------------------------------------------------------------------------- +// Mongo server encapsulation. + +type mongoServer struct { + sync.RWMutex + Addr string + ResolvedAddr string + tcpaddr *net.TCPAddr + unusedSockets []*mongoSocket + liveSockets []*mongoSocket + closed bool + abended bool + sync chan bool + dial dialer + pingValue time.Duration + pingIndex int + pingCount uint32 + pingWindow [6]time.Duration + info *mongoServerInfo +} + +type dialer struct { + old func(addr net.Addr) (net.Conn, error) + new func(addr *ServerAddr) (net.Conn, error) +} + +func (dial dialer) isSet() bool { + return dial.old != nil || dial.new != nil +} + +type mongoServerInfo struct { + Master bool + Mongos bool + Tags bson.D + MaxWireVersion int + SetName string +} + +var defaultServerInfo mongoServerInfo + +func newServer(addr string, tcpaddr *net.TCPAddr, sync chan bool, dial dialer) *mongoServer { + server := &mongoServer{ + Addr: addr, + ResolvedAddr: tcpaddr.String(), + tcpaddr: tcpaddr, + sync: sync, + dial: dial, + info: &defaultServerInfo, + } + // Once so the server gets a ping value, then loop in background. + server.pinger(false) + go server.pinger(true) + return server +} + +var errPoolLimit = errors.New("per-server connection limit reached") +var errServerClosed = errors.New("server was closed") + +// AcquireSocket returns a socket for communicating with the server. +// This will attempt to reuse an old connection, if one is available. Otherwise, +// it will establish a new one. The returned socket is owned by the call site, +// and will return to the cache when the socket has its Release method called +// the same number of times as AcquireSocket + Acquire were called for it. +// If the poolLimit argument is greater than zero and the number of sockets in +// use in this server is greater than the provided limit, errPoolLimit is +// returned. 
+func (server *mongoServer) AcquireSocket(poolLimit int, timeout time.Duration) (socket *mongoSocket, abended bool, err error) { + for { + server.Lock() + abended = server.abended + if server.closed { + server.Unlock() + return nil, abended, errServerClosed + } + n := len(server.unusedSockets) + if poolLimit > 0 && len(server.liveSockets)-n >= poolLimit { + server.Unlock() + return nil, false, errPoolLimit + } + if n > 0 { + socket = server.unusedSockets[n-1] + server.unusedSockets[n-1] = nil // Help GC. + server.unusedSockets = server.unusedSockets[:n-1] + info := server.info + server.Unlock() + err = socket.InitialAcquire(info, timeout) + if err != nil { + continue + } + } else { + server.Unlock() + socket, err = server.Connect(timeout) + if err == nil { + server.Lock() + // We've waited for the Connect, see if we got + // closed in the meantime + if server.closed { + server.Unlock() + socket.Release() + socket.Close() + return nil, abended, errServerClosed + } + server.liveSockets = append(server.liveSockets, socket) + server.Unlock() + } + } + return + } + panic("unreachable") +} + +// Connect establishes a new connection to the server. This should +// generally be done through server.AcquireSocket(). +func (server *mongoServer) Connect(timeout time.Duration) (*mongoSocket, error) { + server.RLock() + master := server.info.Master + dial := server.dial + server.RUnlock() + + logf("Establishing new connection to %s (timeout=%s)...", server.Addr, timeout) + var conn net.Conn + var err error + switch { + case !dial.isSet(): + // Cannot do this because it lacks timeout support. 
:-( + //conn, err = net.DialTCP("tcp", nil, server.tcpaddr) + conn, err = net.DialTimeout("tcp", server.ResolvedAddr, timeout) + case dial.old != nil: + conn, err = dial.old(server.tcpaddr) + case dial.new != nil: + conn, err = dial.new(&ServerAddr{server.Addr, server.tcpaddr}) + default: + panic("dialer is set, but both dial.old and dial.new are nil") + } + if err != nil { + logf("Connection to %s failed: %v", server.Addr, err.Error()) + return nil, err + } + logf("Connection to %s established.", server.Addr) + + stats.conn(+1, master) + return newSocket(server, conn, timeout), nil +} + +// Close forces closing all sockets that are alive, whether +// they're currently in use or not. +func (server *mongoServer) Close() { + server.Lock() + server.closed = true + liveSockets := server.liveSockets + unusedSockets := server.unusedSockets + server.liveSockets = nil + server.unusedSockets = nil + server.Unlock() + logf("Connections to %s closing (%d live sockets).", server.Addr, len(liveSockets)) + for i, s := range liveSockets { + s.Close() + liveSockets[i] = nil + } + for i := range unusedSockets { + unusedSockets[i] = nil + } +} + +// RecycleSocket puts socket back into the unused cache. +func (server *mongoServer) RecycleSocket(socket *mongoSocket) { + server.Lock() + if !server.closed { + server.unusedSockets = append(server.unusedSockets, socket) + } + server.Unlock() +} + +func removeSocket(sockets []*mongoSocket, socket *mongoSocket) []*mongoSocket { + for i, s := range sockets { + if s == socket { + copy(sockets[i:], sockets[i+1:]) + n := len(sockets) - 1 + sockets[n] = nil + sockets = sockets[:n] + break + } + } + return sockets +} + +// AbendSocket notifies the server that the given socket has terminated +// abnormally, and thus should be discarded rather than cached. 
+func (server *mongoServer) AbendSocket(socket *mongoSocket) { + server.Lock() + server.abended = true + if server.closed { + server.Unlock() + return + } + server.liveSockets = removeSocket(server.liveSockets, socket) + server.unusedSockets = removeSocket(server.unusedSockets, socket) + server.Unlock() + // Maybe just a timeout, but suggest a cluster sync up just in case. + select { + case server.sync <- true: + default: + } +} + +func (server *mongoServer) SetInfo(info *mongoServerInfo) { + server.Lock() + server.info = info + server.Unlock() +} + +func (server *mongoServer) Info() *mongoServerInfo { + server.Lock() + info := server.info + server.Unlock() + return info +} + +func (server *mongoServer) hasTags(serverTags []bson.D) bool { +NextTagSet: + for _, tags := range serverTags { + NextReqTag: + for _, req := range tags { + for _, has := range server.info.Tags { + if req.Name == has.Name { + if req.Value == has.Value { + continue NextReqTag + } + continue NextTagSet + } + } + continue NextTagSet + } + return true + } + return false +} + +var pingDelay = 5 * time.Second + +func (server *mongoServer) pinger(loop bool) { + var delay time.Duration + if raceDetector { + // This variable is only ever touched by tests. 
+ globalMutex.Lock() + delay = pingDelay + globalMutex.Unlock() + } else { + delay = pingDelay + } + op := queryOp{ + collection: "admin.$cmd", + query: bson.D{{"ping", 1}}, + flags: flagSlaveOk, + limit: -1, + } + for { + if loop { + time.Sleep(delay) + } + op := op + socket, _, err := server.AcquireSocket(0, 3*delay) + if err == nil { + start := time.Now() + _, _ = socket.SimpleQuery(&op) + delay := time.Now().Sub(start) + + server.pingWindow[server.pingIndex] = delay + server.pingIndex = (server.pingIndex + 1) % len(server.pingWindow) + server.pingCount++ + var max time.Duration + for i := 0; i < len(server.pingWindow) && uint32(i) < server.pingCount; i++ { + if server.pingWindow[i] > max { + max = server.pingWindow[i] + } + } + socket.Release() + server.Lock() + if server.closed { + loop = false + } + server.pingValue = max + server.Unlock() + logf("Ping for %s is %d ms", server.Addr, max/time.Millisecond) + } else if err == errServerClosed { + return + } + if !loop { + return + } + } +} + +type mongoServerSlice []*mongoServer + +func (s mongoServerSlice) Len() int { + return len(s) +} + +func (s mongoServerSlice) Less(i, j int) bool { + return s[i].ResolvedAddr < s[j].ResolvedAddr +} + +func (s mongoServerSlice) Swap(i, j int) { + s[i], s[j] = s[j], s[i] +} + +func (s mongoServerSlice) Sort() { + sort.Sort(s) +} + +func (s mongoServerSlice) Search(resolvedAddr string) (i int, ok bool) { + n := len(s) + i = sort.Search(n, func(i int) bool { + return s[i].ResolvedAddr >= resolvedAddr + }) + return i, i != n && s[i].ResolvedAddr == resolvedAddr +} + +type mongoServers struct { + slice mongoServerSlice +} + +func (servers *mongoServers) Search(resolvedAddr string) (server *mongoServer) { + if i, ok := servers.slice.Search(resolvedAddr); ok { + return servers.slice[i] + } + return nil +} + +func (servers *mongoServers) Add(server *mongoServer) { + servers.slice = append(servers.slice, server) + servers.slice.Sort() +} + +func (servers *mongoServers) Remove(other 
*mongoServer) (server *mongoServer) { + if i, found := servers.slice.Search(other.ResolvedAddr); found { + server = servers.slice[i] + copy(servers.slice[i:], servers.slice[i+1:]) + n := len(servers.slice) - 1 + servers.slice[n] = nil // Help GC. + servers.slice = servers.slice[:n] + } + return +} + +func (servers *mongoServers) Slice() []*mongoServer { + return ([]*mongoServer)(servers.slice) +} + +func (servers *mongoServers) Get(i int) *mongoServer { + return servers.slice[i] +} + +func (servers *mongoServers) Len() int { + return len(servers.slice) +} + +func (servers *mongoServers) Empty() bool { + return len(servers.slice) == 0 +} + +// BestFit returns the best guess of what would be the most interesting +// server to perform operations on at this point in time. +func (servers *mongoServers) BestFit(serverTags []bson.D) *mongoServer { + var best *mongoServer + for _, next := range servers.slice { + if best == nil { + best = next + best.RLock() + if serverTags != nil && !next.info.Mongos && !best.hasTags(serverTags) { + best.RUnlock() + best = nil + } + continue + } + next.RLock() + swap := false + switch { + case serverTags != nil && !next.info.Mongos && !next.hasTags(serverTags): + // Must have requested tags. + case next.info.Master != best.info.Master: + // Prefer slaves. + swap = best.info.Master + case absDuration(next.pingValue-best.pingValue) > 15*time.Millisecond: + // Prefer nearest server. + swap = next.pingValue < best.pingValue + case len(next.liveSockets)-len(next.unusedSockets) < len(best.liveSockets)-len(best.unusedSockets): + // Prefer servers with less connections. 
+ swap = true + } + if swap { + best.RUnlock() + best = next + } else { + next.RUnlock() + } + } + if best != nil { + best.RUnlock() + } + return best +} + +func absDuration(d time.Duration) time.Duration { + if d < 0 { + return -d + } + return d +} diff --git a/vendor/gopkg.in/mgo.v2/session.go b/vendor/gopkg.in/mgo.v2/session.go new file mode 100644 index 000000000..036f44a63 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/session.go @@ -0,0 +1,4224 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package mgo + +import ( + "crypto/md5" + "encoding/hex" + "errors" + "fmt" + "math" + "net" + "net/url" + "reflect" + "sort" + "strconv" + "strings" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type mode int + +const ( + Eventual mode = 0 + Monotonic mode = 1 + Strong mode = 2 +) + +// When changing the Session type, check if newSession and copySession +// need to be updated too. + +type Session struct { + m sync.RWMutex + cluster_ *mongoCluster + slaveSocket *mongoSocket + masterSocket *mongoSocket + slaveOk bool + consistency mode + queryConfig query + safeOp *queryOp + syncTimeout time.Duration + sockTimeout time.Duration + defaultdb string + sourcedb string + dialCred *Credential + creds []Credential + poolLimit int +} + +type Database struct { + Session *Session + Name string +} + +type Collection struct { + Database *Database + Name string // "collection" + FullName string // "db.collection" +} + +type Query struct { + m sync.Mutex + session *Session + query // Enables default settings in session. +} + +type query struct { + op queryOp + prefetch float64 + limit int32 +} + +type getLastError struct { + CmdName int "getLastError,omitempty" + W interface{} "w,omitempty" + WTimeout int "wtimeout,omitempty" + FSync bool "fsync,omitempty" + J bool "j,omitempty" +} + +type Iter struct { + m sync.Mutex + gotReply sync.Cond + session *Session + server *mongoServer + docData queue + err error + op getMoreOp + prefetch float64 + limit int32 + docsToReceive int + docsBeforeMore int + timeout time.Duration + timedout bool +} + +var ( + ErrNotFound = errors.New("not found") + ErrCursor = errors.New("invalid cursor") +) + +const defaultPrefetch = 0.25 + +// Dial establishes a new session to the cluster identified by the given seed +// server(s). The session will enable communication with all of the servers in +// the cluster, so the seed servers are used only to find out about the cluster +// topology. 
+// +// Dial will timeout after 10 seconds if a server isn't reached. The returned +// session will timeout operations after one minute by default if servers +// aren't available. To customize the timeout, see DialWithTimeout, +// SetSyncTimeout, and SetSocketTimeout. +// +// This method is generally called just once for a given cluster. Further +// sessions to the same cluster are then established using the New or Copy +// methods on the obtained session. This will make them share the underlying +// cluster, and manage the pool of connections appropriately. +// +// Once the session is not useful anymore, Close must be called to release the +// resources appropriately. +// +// The seed servers must be provided in the following format: +// +// [mongodb://][user:pass@]host1[:port1][,host2[:port2],...][/database][?options] +// +// For example, it may be as simple as: +// +// localhost +// +// Or more involved like: +// +// mongodb://myuser:mypass@localhost:40001,otherhost:40001/mydb +// +// If the port number is not provided for a server, it defaults to 27017. +// +// The username and password provided in the URL will be used to authenticate +// into the database named after the slash at the end of the host names, or +// into the "admin" database if none is provided. The authentication information +// will persist in sessions obtained through the New method as well. +// +// The following connection options are supported after the question mark: +// +// connect=direct +// +// Disables the automatic replica set server discovery logic, and +// forces the use of servers provided only (even if secondaries). +// Note that to talk to a secondary the consistency requirements +// must be relaxed to Monotonic or Eventual via SetMode. +// +// +// authSource= +// +// Informs the database used to establish credentials and privileges +// with a MongoDB server. Defaults to the database name provided via +// the URL path, and "admin" if that's unset. 
+// +// +// authMechanism= +// +// Defines the protocol for credential negotiation. Defaults to "MONGODB-CR", +// which is the default username/password challenge-response mechanism. +// +// +// gssapiServiceName= +// +// Defines the service name to use when authenticating with the GSSAPI +// mechanism. Defaults to "mongodb". +// +// maxPoolSize= +// +// Defines the per-server socket pool limit. Defaults to 4096. +// See Session.SetPoolLimit for details. +// +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/connection-string/ +// +func Dial(url string) (*Session, error) { + session, err := DialWithTimeout(url, 10*time.Second) + if err == nil { + session.SetSyncTimeout(1 * time.Minute) + session.SetSocketTimeout(1 * time.Minute) + } + return session, err +} + +// DialWithTimeout works like Dial, but uses timeout as the amount of time to +// wait for a server to respond when first connecting and also on follow up +// operations in the session. If timeout is zero, the call may block +// forever waiting for a connection to be made. +// +// See SetSyncTimeout for customizing the timeout for the session. +func DialWithTimeout(url string, timeout time.Duration) (*Session, error) { + info, err := ParseURL(url) + if err != nil { + return nil, err + } + info.Timeout = timeout + return DialWithInfo(info) +} + +// ParseURL parses a MongoDB URL as accepted by the Dial function and returns +// a value suitable for providing into DialWithInfo. +// +// See Dial for more details on the format of url. 
+func ParseURL(url string) (*DialInfo, error) { + uinfo, err := extractURL(url) + if err != nil { + return nil, err + } + direct := false + mechanism := "" + service := "" + source := "" + setName := "" + poolLimit := 0 + for k, v := range uinfo.options { + switch k { + case "authSource": + source = v + case "authMechanism": + mechanism = v + case "gssapiServiceName": + service = v + case "replicaSet": + setName = v + case "maxPoolSize": + poolLimit, err = strconv.Atoi(v) + if err != nil { + return nil, errors.New("bad value for maxPoolSize: " + v) + } + case "connect": + if v == "direct" { + direct = true + break + } + if v == "replicaSet" { + break + } + fallthrough + default: + return nil, errors.New("unsupported connection URL option: " + k + "=" + v) + } + } + info := DialInfo{ + Addrs: uinfo.addrs, + Direct: direct, + Database: uinfo.db, + Username: uinfo.user, + Password: uinfo.pass, + Mechanism: mechanism, + Service: service, + Source: source, + PoolLimit: poolLimit, + ReplicaSetName: setName, + } + return &info, nil +} + +// DialInfo holds options for establishing a session with a MongoDB cluster. +// To use a URL, see the Dial function. +type DialInfo struct { + // Addrs holds the addresses for the seed servers. + Addrs []string + + // Direct informs whether to establish connections only with the + // specified seed servers, or to obtain information for the whole + // cluster and establish connections with further servers too. + Direct bool + + // Timeout is the amount of time to wait for a server to respond when + // first connecting and on follow up operations in the session. If + // timeout is zero, the call may block forever waiting for a connection + // to be established. + Timeout time.Duration + + // FailFast will cause connection and query attempts to fail faster when + // the server is unavailable, instead of retrying until the configured + // timeout period. 
Note that an unavailable server may silently drop + // packets instead of rejecting them, in which case it's impossible to + // distinguish it from a slow server, so the timeout stays relevant. + FailFast bool + + // Database is the default database name used when the Session.DB method + // is called with an empty name, and is also used during the intial + // authentication if Source is unset. + Database string + + // ReplicaSetName, if specified, will prevent the obtained session from + // communicating with any server which is not part of a replica set + // with the given name. The default is to communicate with any server + // specified or discovered via the servers contacted. + ReplicaSetName string + + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the value of Database, if that is + // set, or "admin" otherwise. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // ServiceHost defines which hostname to use when authenticating + // with the GSSAPI mechanism. If not specified, defaults to the MongoDB + // server's address. + ServiceHost string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string + + // Username and Password inform the credentials for the initial authentication + // done on the database defined by the Source field. See Session.Login. + Username string + Password string + + // PoolLimit defines the per-server socket pool limit. Defaults to 4096. + // See Session.SetPoolLimit for details. + PoolLimit int + + // DialServer optionally specifies the dial function for establishing + // connections with the MongoDB servers. + DialServer func(addr *ServerAddr) (net.Conn, error) + + // WARNING: This field is obsolete. See DialServer above. 
+ Dial func(addr net.Addr) (net.Conn, error) +} + +// ServerAddr represents the address for establishing a connection to an +// individual MongoDB server. +type ServerAddr struct { + str string + tcp *net.TCPAddr +} + +// String returns the address that was provided for the server before resolution. +func (addr *ServerAddr) String() string { + return addr.str +} + +// TCPAddr returns the resolved TCP address for the server. +func (addr *ServerAddr) TCPAddr() *net.TCPAddr { + return addr.tcp +} + +// DialWithInfo establishes a new session to the cluster identified by info. +func DialWithInfo(info *DialInfo) (*Session, error) { + addrs := make([]string, len(info.Addrs)) + for i, addr := range info.Addrs { + p := strings.LastIndexAny(addr, "]:") + if p == -1 || addr[p] != ':' { + // XXX This is untested. The test suite doesn't use the standard port. + addr += ":27017" + } + addrs[i] = addr + } + cluster := newCluster(addrs, info.Direct, info.FailFast, dialer{info.Dial, info.DialServer}, info.ReplicaSetName) + session := newSession(Eventual, cluster, info.Timeout) + session.defaultdb = info.Database + if session.defaultdb == "" { + session.defaultdb = "test" + } + session.sourcedb = info.Source + if session.sourcedb == "" { + session.sourcedb = info.Database + if session.sourcedb == "" { + session.sourcedb = "admin" + } + } + if info.Username != "" { + source := session.sourcedb + if info.Source == "" && + (info.Mechanism == "GSSAPI" || info.Mechanism == "PLAIN" || info.Mechanism == "MONGODB-X509") { + source = "$external" + } + session.dialCred = &Credential{ + Username: info.Username, + Password: info.Password, + Mechanism: info.Mechanism, + Service: info.Service, + ServiceHost: info.ServiceHost, + Source: source, + } + session.creds = []Credential{*session.dialCred} + } + if info.PoolLimit > 0 { + session.poolLimit = info.PoolLimit + } + cluster.Release() + + // People get confused when we return a session that is not actually + // established to any servers yet 
(e.g. what if url was wrong). So, + // ping the server to ensure there's someone there, and abort if it + // fails. + if err := session.Ping(); err != nil { + session.Close() + return nil, err + } + session.SetMode(Strong, true) + return session, nil +} + +func isOptSep(c rune) bool { + return c == ';' || c == '&' +} + +type urlInfo struct { + addrs []string + user string + pass string + db string + options map[string]string +} + +func extractURL(s string) (*urlInfo, error) { + if strings.HasPrefix(s, "mongodb://") { + s = s[10:] + } + info := &urlInfo{options: make(map[string]string)} + if c := strings.Index(s, "?"); c != -1 { + for _, pair := range strings.FieldsFunc(s[c+1:], isOptSep) { + l := strings.SplitN(pair, "=", 2) + if len(l) != 2 || l[0] == "" || l[1] == "" { + return nil, errors.New("connection option must be key=value: " + pair) + } + info.options[l[0]] = l[1] + } + s = s[:c] + } + if c := strings.Index(s, "@"); c != -1 { + pair := strings.SplitN(s[:c], ":", 2) + if len(pair) > 2 || pair[0] == "" { + return nil, errors.New("credentials must be provided as user:pass@host") + } + var err error + info.user, err = url.QueryUnescape(pair[0]) + if err != nil { + return nil, fmt.Errorf("cannot unescape username in URL: %q", pair[0]) + } + if len(pair) > 1 { + info.pass, err = url.QueryUnescape(pair[1]) + if err != nil { + return nil, fmt.Errorf("cannot unescape password in URL") + } + } + s = s[c+1:] + } + if c := strings.Index(s, "/"); c != -1 { + info.db = s[c+1:] + s = s[:c] + } + info.addrs = strings.Split(s, ",") + return info, nil +} + +func newSession(consistency mode, cluster *mongoCluster, timeout time.Duration) (session *Session) { + cluster.Acquire() + session = &Session{ + cluster_: cluster, + syncTimeout: timeout, + sockTimeout: timeout, + poolLimit: 4096, + } + debugf("New session %p on cluster %p", session, cluster) + session.SetMode(consistency, true) + session.SetSafe(&Safe{}) + session.queryConfig.prefetch = defaultPrefetch + return session 
+} + +func copySession(session *Session, keepCreds bool) (s *Session) { + cluster := session.cluster() + cluster.Acquire() + if session.masterSocket != nil { + session.masterSocket.Acquire() + } + if session.slaveSocket != nil { + session.slaveSocket.Acquire() + } + var creds []Credential + if keepCreds { + creds = make([]Credential, len(session.creds)) + copy(creds, session.creds) + } else if session.dialCred != nil { + creds = []Credential{*session.dialCred} + } + scopy := *session + scopy.m = sync.RWMutex{} + scopy.creds = creds + s = &scopy + debugf("New session %p on cluster %p (copy from %p)", s, cluster, session) + return s +} + +// LiveServers returns a list of server addresses which are +// currently known to be alive. +func (s *Session) LiveServers() (addrs []string) { + s.m.RLock() + addrs = s.cluster().LiveServers() + s.m.RUnlock() + return addrs +} + +// DB returns a value representing the named database. If name +// is empty, the database name provided in the dialed URL is +// used instead. If that is also empty, "test" is used as a +// fallback in a way equivalent to the mongo shell. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. +func (s *Session) DB(name string) *Database { + if name == "" { + name = s.defaultdb + } + return &Database{s, name} +} + +// C returns a value representing the named collection. +// +// Creating this value is a very lightweight operation, and +// involves no network communication. +func (db *Database) C(name string) *Collection { + return &Collection{db, name, db.Name + "." + name} +} + +// With returns a copy of db that uses session s. +func (db *Database) With(s *Session) *Database { + newdb := *db + newdb.Session = s + return &newdb +} + +// With returns a copy of c that uses session s. 
+func (c *Collection) With(s *Session) *Collection { + newdb := *c.Database + newdb.Session = s + newc := *c + newc.Database = &newdb + return &newc +} + +// GridFS returns a GridFS value representing collections in db that +// follow the standard GridFS specification. +// The provided prefix (sometimes known as root) will determine which +// collections to use, and is usually set to "fs" when there is a +// single GridFS in the database. +// +// See the GridFS Create, Open, and OpenId methods for more details. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/GridFS +// http://www.mongodb.org/display/DOCS/GridFS+Tools +// http://www.mongodb.org/display/DOCS/GridFS+Specification +// +func (db *Database) GridFS(prefix string) *GridFS { + return newGridFS(db, prefix) +} + +// Run issues the provided command on the db database and unmarshals +// its result in the respective argument. The cmd argument may be either +// a string with the command name itself, in which case an empty document of +// the form bson.M{cmd: 1} will be used, or it may be a full command document. +// +// Note that MongoDB considers the first marshalled key as the command +// name, so when providing a command with options, it's important to +// use an ordering-preserving document, such as a struct value or an +// instance of bson.D. For instance: +// +// db.Run(bson.D{{"create", "mycollection"}, {"size", 1024}}) +// +// For privilleged commands typically run on the "admin" database, see +// the Run method in the Session type. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Commands +// http://www.mongodb.org/display/DOCS/List+of+Database+CommandSkips +// +func (db *Database) Run(cmd interface{}, result interface{}) error { + socket, err := db.Session.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + // This is an optimized form of db.C("$cmd").Find(cmd).One(result). 
+ return db.run(socket, cmd, result) +} + +// Credential holds details to authenticate with a MongoDB server. +type Credential struct { + // Username and Password hold the basic details for authentication. + // Password is optional with some authentication mechanisms. + Username string + Password string + + // Source is the database used to establish credentials and privileges + // with a MongoDB server. Defaults to the default database provided + // during dial, or "admin" if that was unset. + Source string + + // Service defines the service name to use when authenticating with the GSSAPI + // mechanism. Defaults to "mongodb". + Service string + + // ServiceHost defines which hostname to use when authenticating + // with the GSSAPI mechanism. If not specified, defaults to the MongoDB + // server's address. + ServiceHost string + + // Mechanism defines the protocol for credential negotiation. + // Defaults to "MONGODB-CR". + Mechanism string +} + +// Login authenticates with MongoDB using the provided credential. The +// authentication is valid for the whole session and will stay valid until +// Logout is explicitly called for the same database, or the session is +// closed. +func (db *Database) Login(user, pass string) error { + return db.Session.Login(&Credential{Username: user, Password: pass, Source: db.Name}) +} + +// Login authenticates with MongoDB using the provided credential. The +// authentication is valid for the whole session and will stay valid until +// Logout is explicitly called for the same database, or the session is +// closed. 
+func (s *Session) Login(cred *Credential) error { + socket, err := s.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + credCopy := *cred + if cred.Source == "" { + if cred.Mechanism == "GSSAPI" { + credCopy.Source = "$external" + } else { + credCopy.Source = s.sourcedb + } + } + err = socket.Login(credCopy) + if err != nil { + return err + } + + s.m.Lock() + s.creds = append(s.creds, credCopy) + s.m.Unlock() + return nil +} + +func (s *Session) socketLogin(socket *mongoSocket) error { + for _, cred := range s.creds { + if err := socket.Login(cred); err != nil { + return err + } + } + return nil +} + +// Logout removes any established authentication credentials for the database. +func (db *Database) Logout() { + session := db.Session + dbname := db.Name + session.m.Lock() + found := false + for i, cred := range session.creds { + if cred.Source == dbname { + copy(session.creds[i:], session.creds[i+1:]) + session.creds = session.creds[:len(session.creds)-1] + found = true + break + } + } + if found { + if session.masterSocket != nil { + session.masterSocket.Logout(dbname) + } + if session.slaveSocket != nil { + session.slaveSocket.Logout(dbname) + } + } + session.m.Unlock() +} + +// LogoutAll removes all established authentication credentials for the session. +func (s *Session) LogoutAll() { + s.m.Lock() + for _, cred := range s.creds { + if s.masterSocket != nil { + s.masterSocket.Logout(cred.Source) + } + if s.slaveSocket != nil { + s.slaveSocket.Logout(cred.Source) + } + } + s.creds = s.creds[0:0] + s.m.Unlock() +} + +// User represents a MongoDB user. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// http://docs.mongodb.org/manual/reference/user-privileges/ +// +type User struct { + // Username is how the user identifies itself to the system. + Username string `bson:"user"` + + // Password is the plaintext password for the user. 
If set, + // the UpsertUser method will hash it into PasswordHash and + // unset it before the user is added to the database. + Password string `bson:",omitempty"` + + // PasswordHash is the MD5 hash of Username+":mongo:"+Password. + PasswordHash string `bson:"pwd,omitempty"` + + // CustomData holds arbitrary data admins decide to associate + // with this user, such as the full name or employee id. + CustomData interface{} `bson:"customData,omitempty"` + + // Roles indicates the set of roles the user will be provided. + // See the Role constants. + Roles []Role `bson:"roles"` + + // OtherDBRoles allows assigning roles in other databases from + // user documents inserted in the admin database. This field + // only works in the admin database. + OtherDBRoles map[string][]Role `bson:"otherDBRoles,omitempty"` + + // UserSource indicates where to look for this user's credentials. + // It may be set to a database name, or to "$external" for + // consulting an external resource such as Kerberos. UserSource + // must not be set if Password or PasswordHash are present. + // + // WARNING: This setting was only ever supported in MongoDB 2.4, + // and is now obsolete. + UserSource string `bson:"userSource,omitempty"` +} + +type Role string + +const ( + // Relevant documentation: + // + // http://docs.mongodb.org/manual/reference/user-privileges/ + // + RoleRoot Role = "root" + RoleRead Role = "read" + RoleReadAny Role = "readAnyDatabase" + RoleReadWrite Role = "readWrite" + RoleReadWriteAny Role = "readWriteAnyDatabase" + RoleDBAdmin Role = "dbAdmin" + RoleDBAdminAny Role = "dbAdminAnyDatabase" + RoleUserAdmin Role = "userAdmin" + RoleUserAdminAny Role = "userAdminAnyDatabase" + RoleClusterAdmin Role = "clusterAdmin" +) + +// UpsertUser updates the authentication credentials and the roles for +// a MongoDB user within the db database. If the named user doesn't exist +// it will be created. +// +// This method should only be used from MongoDB 2.4 and on. 
For older +// MongoDB releases, use the obsolete AddUser method instead. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/user-privileges/ +// http://docs.mongodb.org/manual/reference/privilege-documents/ +// +func (db *Database) UpsertUser(user *User) error { + if user.Username == "" { + return fmt.Errorf("user has no Username") + } + if (user.Password != "" || user.PasswordHash != "") && user.UserSource != "" { + return fmt.Errorf("user has both Password/PasswordHash and UserSource set") + } + if len(user.OtherDBRoles) > 0 && db.Name != "admin" && db.Name != "$external" { + return fmt.Errorf("user with OtherDBRoles is only supported in the admin or $external databases") + } + + // Attempt to run this using 2.6+ commands. + rundb := db + if user.UserSource != "" { + // Compatibility logic for the userSource field of MongoDB <= 2.4.X + rundb = db.Session.DB(user.UserSource) + } + err := rundb.runUserCmd("updateUser", user) + // retry with createUser when isAuthError in order to enable the "localhost exception" + if isNotFound(err) || isAuthError(err) { + return rundb.runUserCmd("createUser", user) + } + if !isNoCmd(err) { + return err + } + + // Command does not exist. Fallback to pre-2.6 behavior. 
+ var set, unset bson.D + if user.Password != "" { + psum := md5.New() + psum.Write([]byte(user.Username + ":mongo:" + user.Password)) + set = append(set, bson.DocElem{"pwd", hex.EncodeToString(psum.Sum(nil))}) + unset = append(unset, bson.DocElem{"userSource", 1}) + } else if user.PasswordHash != "" { + set = append(set, bson.DocElem{"pwd", user.PasswordHash}) + unset = append(unset, bson.DocElem{"userSource", 1}) + } + if user.UserSource != "" { + set = append(set, bson.DocElem{"userSource", user.UserSource}) + unset = append(unset, bson.DocElem{"pwd", 1}) + } + if user.Roles != nil || user.OtherDBRoles != nil { + set = append(set, bson.DocElem{"roles", user.Roles}) + if len(user.OtherDBRoles) > 0 { + set = append(set, bson.DocElem{"otherDBRoles", user.OtherDBRoles}) + } else { + unset = append(unset, bson.DocElem{"otherDBRoles", 1}) + } + } + users := db.C("system.users") + err = users.Update(bson.D{{"user", user.Username}}, bson.D{{"$unset", unset}, {"$set", set}}) + if err == ErrNotFound { + set = append(set, bson.DocElem{"user", user.Username}) + if user.Roles == nil && user.OtherDBRoles == nil { + // Roles must be sent, as it's the way MongoDB distinguishes + // old-style documents from new-style documents in pre-2.6. 
+ set = append(set, bson.DocElem{"roles", user.Roles}) + } + err = users.Insert(set) + } + return err +} + +func isNoCmd(err error) bool { + e, ok := err.(*QueryError) + return ok && (e.Code == 59 || e.Code == 13390 || strings.HasPrefix(e.Message, "no such cmd:")) +} + +func isNotFound(err error) bool { + e, ok := err.(*QueryError) + return ok && e.Code == 11 +} + +func isAuthError(err error) bool { + e, ok := err.(*QueryError) + return ok && e.Code == 13 +} + +func (db *Database) runUserCmd(cmdName string, user *User) error { + cmd := make(bson.D, 0, 16) + cmd = append(cmd, bson.DocElem{cmdName, user.Username}) + if user.Password != "" { + cmd = append(cmd, bson.DocElem{"pwd", user.Password}) + } + var roles []interface{} + for _, role := range user.Roles { + roles = append(roles, role) + } + for db, dbroles := range user.OtherDBRoles { + for _, role := range dbroles { + roles = append(roles, bson.D{{"role", role}, {"db", db}}) + } + } + if roles != nil || user.Roles != nil || cmdName == "createUser" { + cmd = append(cmd, bson.DocElem{"roles", roles}) + } + err := db.Run(cmd, nil) + if !isNoCmd(err) && user.UserSource != "" && (user.UserSource != "$external" || db.Name != "$external") { + return fmt.Errorf("MongoDB 2.6+ does not support the UserSource setting") + } + return err +} + +// AddUser creates or updates the authentication credentials of user within +// the db database. +// +// WARNING: This method is obsolete and should only be used with MongoDB 2.2 +// or earlier. For MongoDB 2.4 and on, use UpsertUser instead. 
+func (db *Database) AddUser(username, password string, readOnly bool) error { + // Try to emulate the old behavior on 2.6+ + user := &User{Username: username, Password: password} + if db.Name == "admin" { + if readOnly { + user.Roles = []Role{RoleReadAny} + } else { + user.Roles = []Role{RoleReadWriteAny} + } + } else { + if readOnly { + user.Roles = []Role{RoleRead} + } else { + user.Roles = []Role{RoleReadWrite} + } + } + err := db.runUserCmd("updateUser", user) + if isNotFound(err) { + return db.runUserCmd("createUser", user) + } + if !isNoCmd(err) { + return err + } + + // Command doesn't exist. Fallback to pre-2.6 behavior. + psum := md5.New() + psum.Write([]byte(username + ":mongo:" + password)) + digest := hex.EncodeToString(psum.Sum(nil)) + c := db.C("system.users") + _, err = c.Upsert(bson.M{"user": username}, bson.M{"$set": bson.M{"user": username, "pwd": digest, "readOnly": readOnly}}) + return err +} + +// RemoveUser removes the authentication credentials of user from the database. 
+func (db *Database) RemoveUser(user string) error { + err := db.Run(bson.D{{"dropUser", user}}, nil) + if isNoCmd(err) { + users := db.C("system.users") + return users.Remove(bson.M{"user": user}) + } + if isNotFound(err) { + return ErrNotFound + } + return err +} + +type indexSpec struct { + Name, NS string + Key bson.D + Unique bool ",omitempty" + DropDups bool "dropDups,omitempty" + Background bool ",omitempty" + Sparse bool ",omitempty" + Bits, Min, Max int ",omitempty" + BucketSize float64 "bucketSize,omitempty" + ExpireAfter int "expireAfterSeconds,omitempty" + Weights bson.D ",omitempty" + DefaultLanguage string "default_language,omitempty" + LanguageOverride string "language_override,omitempty" +} + +type Index struct { + Key []string // Index key fields; prefix name with dash (-) for descending order + Unique bool // Prevent two documents from having the same index key + DropDups bool // Drop documents with the same index key as a previously indexed one + Background bool // Build index in background and return immediately + Sparse bool // Only index documents containing the Key fields + + // If ExpireAfter is defined the server will periodically delete + // documents with indexed time.Time older than the provided delta. + ExpireAfter time.Duration + + // Name holds the stored index name. On creation this field is ignored and the index name + // is automatically computed by EnsureIndex based on the index key + Name string + + // Properties for spatial indexes. + Bits, Min, Max int + BucketSize float64 + + // Properties for text indexes. + DefaultLanguage string + LanguageOverride string + + // Weights defines the significance of provided fields relative to other + // fields in a text index. The score for a given word in a document is derived + // from the weighted sum of the frequency for each of the indexed fields in + // that document. The default field weight is 1. 
+ Weights map[string]int +} + +type indexKeyInfo struct { + name string + key bson.D + weights bson.D +} + +func parseIndexKey(key []string) (*indexKeyInfo, error) { + var keyInfo indexKeyInfo + isText := false + var order interface{} + for _, field := range key { + raw := field + if keyInfo.name != "" { + keyInfo.name += "_" + } + var kind string + if field != "" { + if field[0] == '$' { + if c := strings.Index(field, ":"); c > 1 && c < len(field)-1 { + kind = field[1:c] + field = field[c+1:] + keyInfo.name += field + "_" + kind + } else { + field = "\x00" + } + } + switch field[0] { + case 0: + // Logic above failed. Reset and error. + field = "" + case '@': + order = "2d" + field = field[1:] + // The shell used to render this field as key_ instead of key_2d, + // and mgo followed suit. This has been fixed in recent server + // releases, and mgo followed as well. + keyInfo.name += field + "_2d" + case '-': + order = -1 + field = field[1:] + keyInfo.name += field + "_-1" + case '+': + field = field[1:] + fallthrough + default: + if kind == "" { + order = 1 + keyInfo.name += field + "_1" + } else { + order = kind + } + } + } + if field == "" || kind != "" && order != kind { + return nil, fmt.Errorf(`invalid index key: want "[$:][-]", got %q`, raw) + } + if kind == "text" { + if !isText { + keyInfo.key = append(keyInfo.key, bson.DocElem{"_fts", "text"}, bson.DocElem{"_ftsx", 1}) + isText = true + } + keyInfo.weights = append(keyInfo.weights, bson.DocElem{field, 1}) + } else { + keyInfo.key = append(keyInfo.key, bson.DocElem{field, order}) + } + } + if keyInfo.name == "" { + return nil, errors.New("invalid index key: no fields provided") + } + return &keyInfo, nil +} + +// EnsureIndexKey ensures an index with the given key exists, creating it +// if necessary. 
+// +// This example: +// +// err := collection.EnsureIndexKey("a", "b") +// +// Is equivalent to: +// +// err := collection.EnsureIndex(mgo.Index{Key: []string{"a", "b"}}) +// +// See the EnsureIndex method for more details. +func (c *Collection) EnsureIndexKey(key ...string) error { + return c.EnsureIndex(Index{Key: key}) +} + +// EnsureIndex ensures an index with the given key exists, creating it with +// the provided parameters if necessary. EnsureIndex does not modify a previously +// existent index with a matching key. The old index must be dropped first instead. +// +// Once EnsureIndex returns successfully, following requests for the same index +// will not contact the server unless Collection.DropIndex is used to drop the +// same index, or Session.ResetIndexCache is called. +// +// For example: +// +// index := Index{ +// Key: []string{"lastname", "firstname"}, +// Unique: true, +// DropDups: true, +// Background: true, // See notes. +// Sparse: true, +// } +// err := collection.EnsureIndex(index) +// +// The Key value determines which fields compose the index. The index ordering +// will be ascending by default. To obtain an index with a descending order, +// the field name should be prefixed by a dash (e.g. []string{"-time"}). It can +// also be optionally prefixed by an index kind, as in "$text:summary" or +// "$2d:-point". The key string format is: +// +// [$:][-] +// +// If the Unique field is true, the index must necessarily contain only a single +// document per Key. With DropDups set to true, documents with the same key +// as a previously indexed one will be dropped rather than an error returned. +// +// If Background is true, other connections will be allowed to proceed using +// the collection without the index while it's being built. Note that the +// session executing EnsureIndex will be blocked for as long as it takes for +// the index to be built. 
+// +// If Sparse is true, only documents containing the provided Key fields will be +// included in the index. When using a sparse index for sorting, only indexed +// documents will be returned. +// +// If ExpireAfter is non-zero, the server will periodically scan the collection +// and remove documents containing an indexed time.Time field with a value +// older than ExpireAfter. See the documentation for details: +// +// http://docs.mongodb.org/manual/tutorial/expire-data +// +// Other kinds of indexes are also supported through that API. Here is an example: +// +// index := Index{ +// Key: []string{"$2d:loc"}, +// Bits: 26, +// } +// err := collection.EnsureIndex(index) +// +// The example above requests the creation of a "2d" index for the "loc" field. +// +// The 2D index bounds may be changed using the Min and Max attributes of the +// Index value. The default bound setting of (-180, 180) is suitable for +// latitude/longitude pairs. +// +// The Bits parameter sets the precision of the 2D geohash values. If not +// provided, 26 bits are used, which is roughly equivalent to 1 foot of +// precision for the default (-180, 180) index bounds. 
+// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Indexes +// http://www.mongodb.org/display/DOCS/Indexing+Advice+and+FAQ +// http://www.mongodb.org/display/DOCS/Indexing+as+a+Background+Operation +// http://www.mongodb.org/display/DOCS/Geospatial+Indexing +// http://www.mongodb.org/display/DOCS/Multikeys +// +func (c *Collection) EnsureIndex(index Index) error { + keyInfo, err := parseIndexKey(index.Key) + if err != nil { + return err + } + + session := c.Database.Session + cacheKey := c.FullName + "\x00" + keyInfo.name + if session.cluster().HasCachedIndex(cacheKey) { + return nil + } + + spec := indexSpec{ + Name: keyInfo.name, + NS: c.FullName, + Key: keyInfo.key, + Unique: index.Unique, + DropDups: index.DropDups, + Background: index.Background, + Sparse: index.Sparse, + Bits: index.Bits, + Min: index.Min, + Max: index.Max, + BucketSize: index.BucketSize, + ExpireAfter: int(index.ExpireAfter / time.Second), + Weights: keyInfo.weights, + DefaultLanguage: index.DefaultLanguage, + LanguageOverride: index.LanguageOverride, + } + +NextField: + for name, weight := range index.Weights { + for i, elem := range spec.Weights { + if elem.Name == name { + spec.Weights[i].Value = weight + continue NextField + } + } + panic("weight provided for field that is not part of index key: " + name) + } + + cloned := session.Clone() + defer cloned.Close() + cloned.SetMode(Strong, false) + cloned.EnsureSafe(&Safe{}) + db := c.Database.With(cloned) + + // Try with a command first. + err = db.Run(bson.D{{"createIndexes", c.Name}, {"indexes", []indexSpec{spec}}}, nil) + if isNoCmd(err) { + // Command not yet supported. Insert into the indexes collection instead. + err = db.C("system.indexes").Insert(&spec) + } + if err == nil { + session.cluster().CacheIndex(cacheKey, true) + } + return err +} + +// DropIndex removes the index with key from the collection. +// +// The key value determines which fields compose the index. 
The index ordering +// will be ascending by default. To obtain an index with a descending order, +// the field name should be prefixed by a dash (e.g. []string{"-time"}). +// +// For example: +// +// err := collection.DropIndex("lastname", "firstname") +// +// See the EnsureIndex method for more details on indexes. +func (c *Collection) DropIndex(key ...string) error { + keyInfo, err := parseIndexKey(key) + if err != nil { + return err + } + + session := c.Database.Session + cacheKey := c.FullName + "\x00" + keyInfo.name + session.cluster().CacheIndex(cacheKey, false) + + session = session.Clone() + defer session.Close() + session.SetMode(Strong, false) + + db := c.Database.With(session) + result := struct { + ErrMsg string + Ok bool + }{} + err = db.Run(bson.D{{"dropIndexes", c.Name}, {"index", keyInfo.name}}, &result) + if err != nil { + return err + } + if !result.Ok { + return errors.New(result.ErrMsg) + } + return nil +} + +// Indexes returns a list of all indexes for the collection. +// +// For example, this snippet would drop all available indexes: +// +// indexes, err := collection.Indexes() +// if err != nil { +// return err +// } +// for _, index := range indexes { +// err = collection.DropIndex(index.Key...) +// if err != nil { +// return err +// } +// } +// +// See the EnsureIndex method for more details on indexes. +func (c *Collection) Indexes() (indexes []Index, err error) { + // Clone session and set it to Monotonic mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + session := c.Database.Session + cloned := session.Clone() + cloned.SetMode(Monotonic, false) + defer cloned.Close() + + batchSize := int(cloned.queryConfig.op.limit) + + // Try with a command. 
+ var result struct { + Indexes []bson.Raw + + Cursor struct { + FirstBatch []bson.Raw "firstBatch" + NS string + Id int64 + } + } + var iter *Iter + err = c.Database.With(cloned).Run(bson.D{{"listIndexes", c.Name}, {"cursor", bson.D{{"batchSize", batchSize}}}}, &result) + if err == nil { + firstBatch := result.Indexes + if firstBatch == nil { + firstBatch = result.Cursor.FirstBatch + } + ns := strings.SplitN(result.Cursor.NS, ".", 2) + if len(ns) < 2 { + iter = c.With(cloned).NewIter(nil, firstBatch, result.Cursor.Id, nil) + } else { + iter = cloned.DB(ns[0]).C(ns[1]).NewIter(nil, firstBatch, result.Cursor.Id, nil) + } + } else if isNoCmd(err) { + // Command not yet supported. Query the database instead. + iter = c.Database.C("system.indexes").Find(bson.M{"ns": c.FullName}).Iter() + } else { + return nil, err + } + + var spec indexSpec + for iter.Next(&spec) { + indexes = append(indexes, indexFromSpec(spec)) + } + if err = iter.Close(); err != nil { + return nil, err + } + sort.Sort(indexSlice(indexes)) + return indexes, nil +} + +func indexFromSpec(spec indexSpec) Index { + return Index{ + Name: spec.Name, + Key: simpleIndexKey(spec.Key), + Unique: spec.Unique, + DropDups: spec.DropDups, + Background: spec.Background, + Sparse: spec.Sparse, + ExpireAfter: time.Duration(spec.ExpireAfter) * time.Second, + } +} + +type indexSlice []Index + +func (idxs indexSlice) Len() int { return len(idxs) } +func (idxs indexSlice) Less(i, j int) bool { return idxs[i].Name < idxs[j].Name } +func (idxs indexSlice) Swap(i, j int) { idxs[i], idxs[j] = idxs[j], idxs[i] } + +func simpleIndexKey(realKey bson.D) (key []string) { + for i := range realKey { + field := realKey[i].Name + vi, ok := realKey[i].Value.(int) + if !ok { + vf, _ := realKey[i].Value.(float64) + vi = int(vf) + } + if vi == 1 { + key = append(key, field) + continue + } + if vi == -1 { + key = append(key, "-"+field) + continue + } + if vs, ok := realKey[i].Value.(string); ok { + key = append(key, "$"+vs+":"+field) + 
continue + } + panic("Got unknown index key type for field " + field) + } + return +} + +// ResetIndexCache() clears the cache of previously ensured indexes. +// Following requests to EnsureIndex will contact the server. +func (s *Session) ResetIndexCache() { + s.cluster().ResetIndexCache() +} + +// New creates a new session with the same parameters as the original +// session, including consistency, batch size, prefetching, safety mode, +// etc. The returned session will use sockets from the pool, so there's +// a chance that writes just performed in another session may not yet +// be visible. +// +// Login information from the original session will not be copied over +// into the new session unless it was provided through the initial URL +// for the Dial function. +// +// See the Copy and Clone methods. +// +func (s *Session) New() *Session { + s.m.Lock() + scopy := copySession(s, false) + s.m.Unlock() + scopy.Refresh() + return scopy +} + +// Copy works just like New, but preserves the exact authentication +// information from the original session. +func (s *Session) Copy() *Session { + s.m.Lock() + scopy := copySession(s, true) + s.m.Unlock() + scopy.Refresh() + return scopy +} + +// Clone works just like Copy, but also reuses the same socket as the original +// session, in case it had already reserved one due to its consistency +// guarantees. This behavior ensures that writes performed in the old session +// are necessarily observed when using the new session, as long as it was a +// strong or monotonic session. That said, it also means that long operations +// may cause other goroutines using the original session to wait. +func (s *Session) Clone() *Session { + s.m.Lock() + scopy := copySession(s, true) + s.m.Unlock() + return scopy +} + +// Close terminates the session. It's a runtime error to use a session +// after it has been closed. 
+func (s *Session) Close() { + s.m.Lock() + if s.cluster_ != nil { + debugf("Closing session %p", s) + s.unsetSocket() + s.cluster_.Release() + s.cluster_ = nil + } + s.m.Unlock() +} + +func (s *Session) cluster() *mongoCluster { + if s.cluster_ == nil { + panic("Session already closed") + } + return s.cluster_ +} + +// Refresh puts back any reserved sockets in use and restarts the consistency +// guarantees according to the current consistency setting for the session. +func (s *Session) Refresh() { + s.m.Lock() + s.slaveOk = s.consistency != Strong + s.unsetSocket() + s.m.Unlock() +} + +// SetMode changes the consistency mode for the session. +// +// In the Strong consistency mode reads and writes will always be made to +// the primary server using a unique connection so that reads and writes are +// fully consistent, ordered, and observing the most up-to-date data. +// This offers the least benefits in terms of distributing load, but the +// most guarantees. See also Monotonic and Eventual. +// +// In the Monotonic consistency mode reads may not be entirely up-to-date, +// but they will always see the history of changes moving forward, the data +// read will be consistent across sequential queries in the same session, +// and modifications made within the session will be observed in following +// queries (read-your-writes). +// +// In practice, the Monotonic mode is obtained by performing initial reads +// on a unique connection to an arbitrary secondary, if one is available, +// and once the first write happens, the session connection is switched over +// to the primary server. This manages to distribute some of the reading +// load with secondaries, while maintaining some useful guarantees. +// +// In the Eventual consistency mode reads will be made to any secondary in the +// cluster, if one is available, and sequential reads will not necessarily +// be made with the same connection. This means that data may be observed +// out of order. 
Writes will of course be issued to the primary, but +// independent writes in the same Eventual session may also be made with +// independent connections, so there are also no guarantees in terms of +// write ordering (no read-your-writes guarantees either). +// +// The Eventual mode is the fastest and most resource-friendly, but is +// also the one offering the least guarantees about ordering of the data +// read and written. +// +// If refresh is true, in addition to ensuring the session is in the given +// consistency mode, the consistency guarantees will also be reset (e.g. +// a Monotonic session will be allowed to read from secondaries again). +// This is equivalent to calling the Refresh function. +// +// Shifting between Monotonic and Strong modes will keep a previously +// reserved connection for the session unless refresh is true or the +// connection is unsuitable (to a secondary server in a Strong session). +func (s *Session) SetMode(consistency mode, refresh bool) { + s.m.Lock() + debugf("Session %p: setting mode %d with refresh=%v (master=%p, slave=%p)", s, consistency, refresh, s.masterSocket, s.slaveSocket) + s.consistency = consistency + if refresh { + s.slaveOk = s.consistency != Strong + s.unsetSocket() + } else if s.consistency == Strong { + s.slaveOk = false + } else if s.masterSocket == nil { + s.slaveOk = true + } + s.m.Unlock() +} + +// Mode returns the current consistency mode for the session. +func (s *Session) Mode() mode { + s.m.RLock() + mode := s.consistency + s.m.RUnlock() + return mode +} + +// SetSyncTimeout sets the amount of time an operation with this session +// will wait before returning an error in case a connection to a usable +// server can't be established. Set it to zero to wait forever. The +// default value is 7 seconds. 
+func (s *Session) SetSyncTimeout(d time.Duration) { + s.m.Lock() + s.syncTimeout = d + s.m.Unlock() +} + +// SetSocketTimeout sets the amount of time to wait for a non-responding +// socket to the database before it is forcefully closed. +func (s *Session) SetSocketTimeout(d time.Duration) { + s.m.Lock() + s.sockTimeout = d + if s.masterSocket != nil { + s.masterSocket.SetTimeout(d) + } + if s.slaveSocket != nil { + s.slaveSocket.SetTimeout(d) + } + s.m.Unlock() +} + +// SetCursorTimeout changes the standard timeout period that the server +// enforces on created cursors. The only supported value right now is +// 0, which disables the timeout. The standard server timeout is 10 minutes. +func (s *Session) SetCursorTimeout(d time.Duration) { + s.m.Lock() + if d == 0 { + s.queryConfig.op.flags |= flagNoCursorTimeout + } else { + panic("SetCursorTimeout: only 0 (disable timeout) supported for now") + } + s.m.Unlock() +} + +// SetPoolLimit sets the maximum number of sockets in use in a single server +// before this session will block waiting for a socket to be available. +// The default limit is 4096. +// +// This limit must be set to cover more than any expected workload of the +// application. It is a bad practice and an unsupported use case to use the +// database driver to define the concurrency limit of an application. Prevent +// such concurrency "at the door" instead, by properly restricting the amount +// of used resources and number of goroutines before they are created. +func (s *Session) SetPoolLimit(limit int) { + s.m.Lock() + s.poolLimit = limit + s.m.Unlock() +} + +// SetBatch sets the default batch size used when fetching documents from the +// database. It's possible to change this setting on a per-query basis as +// well, using the Query.Batch method. +// +// The default batch size is defined by the database itself. As of this +// writing, MongoDB will use an initial size of min(100 docs, 4MB) on the +// first batch, and 4MB on remaining ones. 
+func (s *Session) SetBatch(n int) { + if n == 1 { + // Server interprets 1 as -1 and closes the cursor (!?) + n = 2 + } + s.m.Lock() + s.queryConfig.op.limit = int32(n) + s.m.Unlock() +} + +// SetPrefetch sets the default point at which the next batch of results will be +// requested. When there are p*batch_size remaining documents cached in an +// Iter, the next batch will be requested in background. For instance, when +// using this: +// +// session.SetBatch(200) +// session.SetPrefetch(0.25) +// +// and there are only 50 documents cached in the Iter to be processed, the +// next batch of 200 will be requested. It's possible to change this setting on +// a per-query basis as well, using the Prefetch method of Query. +// +// The default prefetch value is 0.25. +func (s *Session) SetPrefetch(p float64) { + s.m.Lock() + s.queryConfig.prefetch = p + s.m.Unlock() +} + +// See SetSafe for details on the Safe type. +type Safe struct { + W int // Min # of servers to ack before success + WMode string // Write mode for MongoDB 2.0+ (e.g. "majority") + WTimeout int // Milliseconds to wait for W before timing out + FSync bool // Should servers sync to disk before returning success + J bool // Wait for next group commit if journaling; no effect otherwise +} + +// Safe returns the current safety mode for the session. +func (s *Session) Safe() (safe *Safe) { + s.m.Lock() + defer s.m.Unlock() + if s.safeOp != nil { + cmd := s.safeOp.query.(*getLastError) + safe = &Safe{WTimeout: cmd.WTimeout, FSync: cmd.FSync, J: cmd.J} + switch w := cmd.W.(type) { + case string: + safe.WMode = w + case int: + safe.W = w + } + } + return +} + +// SetSafe changes the session safety mode. +// +// If the safe parameter is nil, the session is put in unsafe mode, and writes +// become fire-and-forget, without error checking. The unsafe mode is faster +// since operations won't hold on waiting for a confirmation. +// +// If the safe parameter is not nil, any changing query (insert, update, ...) 
+// will be followed by a getLastError command with the specified parameters, +// to ensure the request was correctly processed. +// +// The safe.W parameter determines how many servers should confirm a write +// before the operation is considered successful. If set to 0 or 1, the +// command will return as soon as the primary is done with the request. +// If safe.WTimeout is greater than zero, it determines how many milliseconds +// to wait for the safe.W servers to respond before returning an error. +// +// Starting with MongoDB 2.0.0 the safe.WMode parameter can be used instead +// of W to request for richer semantics. If set to "majority" the server will +// wait for a majority of members from the replica set to respond before +// returning. Custom modes may also be defined within the server to create +// very detailed placement schemas. See the data awareness documentation in +// the links below for more details (note that MongoDB internally reuses the +// "w" field name for WMode). +// +// If safe.FSync is true and journaling is disabled, the servers will be +// forced to sync all files to disk immediately before returning. If the +// same option is true but journaling is enabled, the server will instead +// await for the next group commit before returning. +// +// Since MongoDB 2.0.0, the safe.J option can also be used instead of FSync +// to force the server to wait for a group commit in case journaling is +// enabled. The option has no effect if the server has journaling disabled. 
+// +// For example, the following statement will make the session check for +// errors, without imposing further constraints: +// +// session.SetSafe(&mgo.Safe{}) +// +// The following statement will force the server to wait for a majority of +// members of a replica set to return (MongoDB 2.0+ only): +// +// session.SetSafe(&mgo.Safe{WMode: "majority"}) +// +// The following statement, on the other hand, ensures that at least two +// servers have flushed the change to disk before confirming the success +// of operations: +// +// session.EnsureSafe(&mgo.Safe{W: 2, FSync: true}) +// +// The following statement, on the other hand, disables the verification +// of errors entirely: +// +// session.SetSafe(nil) +// +// See also the EnsureSafe method. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/getLastError+Command +// http://www.mongodb.org/display/DOCS/Verifying+Propagation+of+Writes+with+getLastError +// http://www.mongodb.org/display/DOCS/Data+Center+Awareness +// +func (s *Session) SetSafe(safe *Safe) { + s.m.Lock() + s.safeOp = nil + s.ensureSafe(safe) + s.m.Unlock() +} + +// EnsureSafe compares the provided safety parameters with the ones +// currently in use by the session and picks the most conservative +// choice for each setting. +// +// That is: +// +// - safe.WMode is always used if set. +// - safe.W is used if larger than the current W and WMode is empty. +// - safe.FSync is always used if true. +// - safe.J is used if FSync is false. +// - safe.WTimeout is used if set and smaller than the current WTimeout. +// +// For example, the following statement will ensure the session is +// at least checking for errors, without enforcing further constraints. +// If a more conservative SetSafe or EnsureSafe call was previously done, +// the following call will be ignored. +// +// session.EnsureSafe(&mgo.Safe{}) +// +// See also the SetSafe method for details on what each option means. 
+// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/getLastError+Command +// http://www.mongodb.org/display/DOCS/Verifying+Propagation+of+Writes+with+getLastError +// http://www.mongodb.org/display/DOCS/Data+Center+Awareness +// +func (s *Session) EnsureSafe(safe *Safe) { + s.m.Lock() + s.ensureSafe(safe) + s.m.Unlock() +} + +func (s *Session) ensureSafe(safe *Safe) { + if safe == nil { + return + } + + var w interface{} + if safe.WMode != "" { + w = safe.WMode + } else if safe.W > 0 { + w = safe.W + } + + var cmd getLastError + if s.safeOp == nil { + cmd = getLastError{1, w, safe.WTimeout, safe.FSync, safe.J} + } else { + // Copy. We don't want to mutate the existing query. + cmd = *(s.safeOp.query.(*getLastError)) + if cmd.W == nil { + cmd.W = w + } else if safe.WMode != "" { + cmd.W = safe.WMode + } else if i, ok := cmd.W.(int); ok && safe.W > i { + cmd.W = safe.W + } + if safe.WTimeout > 0 && safe.WTimeout < cmd.WTimeout { + cmd.WTimeout = safe.WTimeout + } + if safe.FSync { + cmd.FSync = true + cmd.J = false + } else if safe.J && !cmd.FSync { + cmd.J = true + } + } + s.safeOp = &queryOp{ + query: &cmd, + collection: "admin.$cmd", + limit: -1, + } +} + +// Run issues the provided command on the "admin" database and +// and unmarshals its result in the respective argument. The cmd +// argument may be either a string with the command name itself, in +// which case an empty document of the form bson.M{cmd: 1} will be used, +// or it may be a full command document. +// +// Note that MongoDB considers the first marshalled key as the command +// name, so when providing a command with options, it's important to +// use an ordering-preserving document, such as a struct value or an +// instance of bson.D. For instance: +// +// db.Run(bson.D{{"create", "mycollection"}, {"size", 1024}}) +// +// For commands on arbitrary databases, see the Run method in +// the Database type. 
+// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Commands +// http://www.mongodb.org/display/DOCS/List+of+Database+CommandSkips +// +func (s *Session) Run(cmd interface{}, result interface{}) error { + return s.DB("admin").Run(cmd, result) +} + +// SelectServers restricts communication to servers configured with the +// given tags. For example, the following statement restricts servers +// used for reading operations to those with both tag "disk" set to +// "ssd" and tag "rack" set to 1: +// +// session.SelectSlaves(bson.D{{"disk", "ssd"}, {"rack", 1}}) +// +// Multiple sets of tags may be provided, in which case the used server +// must match all tags within any one set. +// +// If a connection was previously assigned to the session due to the +// current session mode (see Session.SetMode), the tag selection will +// only be enforced after the session is refreshed. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/tutorial/configure-replica-set-tag-sets +// +func (s *Session) SelectServers(tags ...bson.D) { + s.m.Lock() + s.queryConfig.op.serverTags = tags + s.m.Unlock() +} + +// Ping runs a trivial ping command just to get in touch with the server. +func (s *Session) Ping() error { + return s.Run("ping", nil) +} + +// Fsync flushes in-memory writes to disk on the server the session +// is established with. If async is true, the call returns immediately, +// otherwise it returns after the flush has been made. +func (s *Session) Fsync(async bool) error { + return s.Run(bson.D{{"fsync", 1}, {"async", async}}, nil) +} + +// FsyncLock locks all writes in the specific server the session is +// established with and returns. Any writes attempted to the server +// after it is successfully locked will block until FsyncUnlock is +// called for the same server. 
+// +// This method works on secondaries as well, preventing the oplog from +// being flushed while the server is locked, but since only the server +// connected to is locked, for locking specific secondaries it may be +// necessary to establish a connection directly to the secondary (see +// Dial's connect=direct option). +// +// As an important caveat, note that once a write is attempted and +// blocks, follow up reads will block as well due to the way the +// lock is internally implemented in the server. More details at: +// +// https://jira.mongodb.org/browse/SERVER-4243 +// +// FsyncLock is often used for performing consistent backups of +// the database files on disk. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/fsync+Command +// http://www.mongodb.org/display/DOCS/Backups +// +func (s *Session) FsyncLock() error { + return s.Run(bson.D{{"fsync", 1}, {"lock", true}}, nil) +} + +// FsyncUnlock releases the server for writes. See FsyncLock for details. +func (s *Session) FsyncUnlock() error { + return s.DB("admin").C("$cmd.sys.unlock").Find(nil).One(nil) // WTF? +} + +// Find prepares a query using the provided document. The document may be a +// map or a struct value capable of being marshalled with bson. The map +// may be a generic one using interface{} for its key and/or values, such as +// bson.M, or it may be a properly typed map. Providing nil as the document +// is equivalent to providing an empty document such as bson.M{}. +// +// Further details of the query may be tweaked using the resulting Query value, +// and then executed to retrieve results using methods such as One, For, +// Iter, or Tail. +// +// In case the resulting document includes a field named $err or errmsg, which +// are standard ways for MongoDB to return query errors, the returned err will +// be set to a *QueryError value including the Err message and the Code. 
In +// those cases, the result argument is still unmarshalled into with the +// received document so that any other custom values may be obtained if +// desired. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Querying +// http://www.mongodb.org/display/DOCS/Advanced+Queries +// +func (c *Collection) Find(query interface{}) *Query { + session := c.Database.Session + session.m.RLock() + q := &Query{session: session, query: session.queryConfig} + session.m.RUnlock() + q.op.query = query + q.op.collection = c.FullName + return q +} + +type repairCmd struct { + RepairCursor string `bson:"repairCursor"` + Cursor *repairCmdCursor ",omitempty" +} + +type repairCmdCursor struct { + BatchSize int `bson:"batchSize,omitempty"` +} + +// Repair returns an iterator that goes over all recovered documents in the +// collection, in a best-effort manner. This is most useful when there are +// damaged data files. Multiple copies of the same document may be returned +// by the iterator. +// +// Repair is supported in MongoDB 2.7.8 and later. +func (c *Collection) Repair() *Iter { + // Clone session and set it to Monotonic mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + session := c.Database.Session + cloned := session.Clone() + cloned.SetMode(Monotonic, false) + defer cloned.Close() + + batchSize := int(cloned.queryConfig.op.limit) + + var result struct { + Cursor struct { + FirstBatch []bson.Raw "firstBatch" + Id int64 + } + } + + cmd := repairCmd{ + RepairCursor: c.Name, + Cursor: &repairCmdCursor{batchSize}, + } + + clonedc := c.With(cloned) + err := clonedc.Database.Run(cmd, &result) + return clonedc.NewIter(session, result.Cursor.FirstBatch, result.Cursor.Id, err) +} + +// FindId is a convenience helper equivalent to: +// +// query := collection.Find(bson.M{"_id": id}) +// +// See the Find method for more details. 
+func (c *Collection) FindId(id interface{}) *Query { + return c.Find(bson.D{{"_id", id}}) +} + +type Pipe struct { + session *Session + collection *Collection + pipeline interface{} + allowDisk bool + batchSize int +} + +type pipeCmd struct { + Aggregate string + Pipeline interface{} + Cursor *pipeCmdCursor ",omitempty" + Explain bool ",omitempty" + AllowDisk bool "allowDiskUse,omitempty" +} + +type pipeCmdCursor struct { + BatchSize int `bson:"batchSize,omitempty"` +} + +// Pipe prepares a pipeline to aggregate. The pipeline document +// must be a slice built in terms of the aggregation framework language. +// +// For example: +// +// pipe := collection.Pipe([]bson.M{{"$match": bson.M{"name": "Otavio"}}}) +// iter := pipe.Iter() +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/aggregation +// http://docs.mongodb.org/manual/applications/aggregation +// http://docs.mongodb.org/manual/tutorial/aggregation-examples +// +func (c *Collection) Pipe(pipeline interface{}) *Pipe { + session := c.Database.Session + session.m.RLock() + batchSize := int(session.queryConfig.op.limit) + session.m.RUnlock() + return &Pipe{ + session: session, + collection: c, + pipeline: pipeline, + batchSize: batchSize, + } +} + +// Iter executes the pipeline and returns an iterator capable of going +// over all the generated results. +func (p *Pipe) Iter() *Iter { + // Clone session and set it to Monotonic mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. + cloned := p.session.Clone() + cloned.SetMode(Monotonic, false) + defer cloned.Close() + c := p.collection.With(cloned) + + var result struct { + // 2.4, no cursors. + Result []bson.Raw + + // 2.6+, with cursors. 
+ Cursor struct { + FirstBatch []bson.Raw "firstBatch" + Id int64 + } + } + + cmd := pipeCmd{ + Aggregate: c.Name, + Pipeline: p.pipeline, + AllowDisk: p.allowDisk, + Cursor: &pipeCmdCursor{p.batchSize}, + } + err := c.Database.Run(cmd, &result) + if e, ok := err.(*QueryError); ok && e.Message == `unrecognized field "cursor` { + cmd.Cursor = nil + cmd.AllowDisk = false + err = c.Database.Run(cmd, &result) + } + firstBatch := result.Result + if firstBatch == nil { + firstBatch = result.Cursor.FirstBatch + } + return c.NewIter(p.session, firstBatch, result.Cursor.Id, err) +} + +// NewIter returns a newly created iterator with the provided parameters. +// Using this method is not recommended unless the desired functionality +// is not yet exposed via a more convenient interface (Find, Pipe, etc). +// +// The optional session parameter associates the lifetime of the returned +// iterator to an arbitrary session. If nil, the iterator will be bound to +// c's session. +// +// Documents in firstBatch will be individually provided by the returned +// iterator before documents from cursorId are made available. If cursorId +// is zero, only the documents in firstBatch are provided. +// +// If err is not nil, the iterator's Err method will report it after +// exhausting documents in firstBatch. +// +// NewIter must be called right after the cursor id is obtained, and must not +// be called on a collection in Eventual mode, because the cursor id is +// associated with the specific server that returned it. The provided session +// parameter may be in any mode or state, though. 
+// +func (c *Collection) NewIter(session *Session, firstBatch []bson.Raw, cursorId int64, err error) *Iter { + var server *mongoServer + csession := c.Database.Session + csession.m.RLock() + socket := csession.masterSocket + if socket == nil { + socket = csession.slaveSocket + } + if socket != nil { + server = socket.Server() + } + csession.m.RUnlock() + + if server == nil { + if csession.Mode() == Eventual { + panic("Collection.NewIter called in Eventual mode") + } + if err == nil { + err = errors.New("server not available") + } + } + + if session == nil { + session = csession + } + + iter := &Iter{ + session: session, + server: server, + timeout: -1, + err: err, + } + iter.gotReply.L = &iter.m + for _, doc := range firstBatch { + iter.docData.Push(doc.Data) + } + if cursorId != 0 { + iter.op.cursorId = cursorId + iter.op.collection = c.FullName + iter.op.replyFunc = iter.replyFunc() + } + return iter +} + +// All works like Iter.All. +func (p *Pipe) All(result interface{}) error { + return p.Iter().All(result) +} + +// One executes the pipeline and unmarshals the first item from the +// result set into the result parameter. +// It returns ErrNotFound if no items are generated by the pipeline. +func (p *Pipe) One(result interface{}) error { + iter := p.Iter() + if iter.Next(result) { + return nil + } + if err := iter.Err(); err != nil { + return err + } + return ErrNotFound +} + +// Explain returns a number of details about how the MongoDB server would +// execute the requested pipeline, such as the number of objects examined, +// the number of times the read lock was yielded to allow writes to go in, +// and so on. 
+// +// For example: +// +// var m bson.M +// err := collection.Pipe(pipeline).Explain(&m) +// if err == nil { +// fmt.Printf("Explain: %#v\n", m) +// } +// +func (p *Pipe) Explain(result interface{}) error { + c := p.collection + cmd := pipeCmd{ + Aggregate: c.Name, + Pipeline: p.pipeline, + AllowDisk: p.allowDisk, + Explain: true, + } + return c.Database.Run(cmd, result) +} + +// AllowDiskUse enables writing to the "/_tmp" server directory so +// that aggregation pipelines do not have to be held entirely in memory. +func (p *Pipe) AllowDiskUse() *Pipe { + p.allowDisk = true + return p +} + +// Batch sets the batch size used when fetching documents from the database. +// It's possible to change this setting on a per-session basis as well, using +// the Batch method of Session. +// +// The default batch size is defined by the database server. +func (p *Pipe) Batch(n int) *Pipe { + p.batchSize = n + return p +} + +type LastError struct { + Err string + Code, N, Waited int + FSyncFiles int `bson:"fsyncFiles"` + WTimeout bool + UpdatedExisting bool `bson:"updatedExisting"` + UpsertedId interface{} `bson:"upserted"` +} + +func (err *LastError) Error() string { + return err.Err +} + +type queryError struct { + Err string "$err" + ErrMsg string + Assertion string + Code int + AssertionCode int "assertionCode" + LastError *LastError "lastErrorObject" +} + +type QueryError struct { + Code int + Message string + Assertion bool +} + +func (err *QueryError) Error() string { + return err.Message +} + +// IsDup returns whether err informs of a duplicate key error because +// a primary key index or a secondary unique index already has an entry +// with the given value. +func IsDup(err error) bool { + // Besides being handy, helps with MongoDB bugs SERVER-7164 and SERVER-11493. + // What follows makes me sad. Hopefully conventions will be more clear over time. 
+ switch e := err.(type) { + case *LastError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 || e.Code == 16460 && strings.Contains(e.Err, " E11000 ") + case *QueryError: + return e.Code == 11000 || e.Code == 11001 || e.Code == 12582 + } + return false +} + +// Insert inserts one or more documents in the respective collection. In +// case the session is in safe mode (see the SetSafe method) and an error +// happens while inserting the provided documents, the returned error will +// be of type *LastError. +func (c *Collection) Insert(docs ...interface{}) error { + _, err := c.writeQuery(&insertOp{c.FullName, docs, 0}) + return err +} + +// Update finds a single document matching the provided selector document +// and modifies it according to the update document. +// If the session is in safe mode (see SetSafe) a ErrNotFound error is +// returned if a document isn't found, or a value of type *LastError +// when some other error is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) Update(selector interface{}, update interface{}) error { + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, update, 0}) + if err == nil && lerr != nil && !lerr.UpdatedExisting { + return ErrNotFound + } + return err +} + +// UpdateId is a convenience helper equivalent to: +// +// err := collection.Update(bson.M{"_id": id}, update) +// +// See the Update method for more details. +func (c *Collection) UpdateId(id interface{}, update interface{}) error { + return c.Update(bson.D{{"_id", id}}, update) +} + +// ChangeInfo holds details about the outcome of an update operation. 
+type ChangeInfo struct { + Updated int // Number of existing documents updated + Removed int // Number of documents removed + UpsertedId interface{} // Upserted _id field, when not explicitly provided +} + +// UpdateAll finds all documents matching the provided selector document +// and modifies them according to the update document. +// If the session is in safe mode (see SetSafe) details of the executed +// operation are returned in info or an error of type *LastError when +// some problem is detected. It is not an error for the update to not be +// applied on any documents because the selector doesn't match. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) UpdateAll(selector interface{}, update interface{}) (info *ChangeInfo, err error) { + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, update, 2}) + if err == nil && lerr != nil { + info = &ChangeInfo{Updated: lerr.N} + } + return info, err +} + +// Upsert finds a single document matching the provided selector document +// and modifies it according to the update document. If no document matching +// the selector is found, the update document is applied to the selector +// document and the result is inserted in the collection. +// If the session is in safe mode (see SetSafe) details of the executed +// operation are returned in info, or an error of type *LastError when +// some problem is detected. 
+// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Updating +// http://www.mongodb.org/display/DOCS/Atomic+Operations +// +func (c *Collection) Upsert(selector interface{}, update interface{}) (info *ChangeInfo, err error) { + lerr, err := c.writeQuery(&updateOp{c.FullName, selector, update, 1}) + if err == nil && lerr != nil { + info = &ChangeInfo{} + if lerr.UpdatedExisting { + info.Updated = lerr.N + } else { + info.UpsertedId = lerr.UpsertedId + } + } + return info, err +} + +// UpsertId is a convenience helper equivalent to: +// +// info, err := collection.Upsert(bson.M{"_id": id}, update) +// +// See the Upsert method for more details. +func (c *Collection) UpsertId(id interface{}, update interface{}) (info *ChangeInfo, err error) { + return c.Upsert(bson.D{{"_id", id}}, update) +} + +// Remove finds a single document matching the provided selector document +// and removes it from the database. +// If the session is in safe mode (see SetSafe) a ErrNotFound error is +// returned if a document isn't found, or a value of type *LastError +// when some other error is detected. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Removing +// +func (c *Collection) Remove(selector interface{}) error { + lerr, err := c.writeQuery(&deleteOp{c.FullName, selector, 1}) + if err == nil && lerr != nil && lerr.N == 0 { + return ErrNotFound + } + return err +} + +// RemoveId is a convenience helper equivalent to: +// +// err := collection.Remove(bson.M{"_id": id}) +// +// See the Remove method for more details. +func (c *Collection) RemoveId(id interface{}) error { + return c.Remove(bson.D{{"_id", id}}) +} + +// RemoveAll finds all documents matching the provided selector document +// and removes them from the database. In case the session is in safe mode +// (see the SetSafe method) and an error happens when attempting the change, +// the returned error will be of type *LastError. 
+// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Removing +// +func (c *Collection) RemoveAll(selector interface{}) (info *ChangeInfo, err error) { + lerr, err := c.writeQuery(&deleteOp{c.FullName, selector, 0}) + if err == nil && lerr != nil { + info = &ChangeInfo{Removed: lerr.N} + } + return info, err +} + +// DropDatabase removes the entire database including all of its collections. +func (db *Database) DropDatabase() error { + return db.Run(bson.D{{"dropDatabase", 1}}, nil) +} + +// DropCollection removes the entire collection including all of its documents. +func (c *Collection) DropCollection() error { + return c.Database.Run(bson.D{{"drop", c.Name}}, nil) +} + +// The CollectionInfo type holds metadata about a collection. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/createCollection+Command +// http://www.mongodb.org/display/DOCS/Capped+Collections +// +type CollectionInfo struct { + // DisableIdIndex prevents the automatic creation of the index + // on the _id field for the collection. + DisableIdIndex bool + + // ForceIdIndex enforces the automatic creation of the index + // on the _id field for the collection. Capped collections, + // for example, do not have such an index by default. + ForceIdIndex bool + + // If Capped is true new documents will replace old ones when + // the collection is full. MaxBytes must necessarily be set + // to define the size when the collection wraps around. + // MaxDocs optionally defines the number of documents when it + // wraps, but MaxBytes still needs to be set. + Capped bool + MaxBytes int + MaxDocs int +} + +// Create explicitly creates the c collection with details of info. +// MongoDB creates collections automatically on use, so this method +// is only necessary when creating collection with non-default +// characteristics, such as capped collections. 
+// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/createCollection+Command +// http://www.mongodb.org/display/DOCS/Capped+Collections +// +func (c *Collection) Create(info *CollectionInfo) error { + cmd := make(bson.D, 0, 4) + cmd = append(cmd, bson.DocElem{"create", c.Name}) + if info.Capped { + if info.MaxBytes < 1 { + return fmt.Errorf("Collection.Create: with Capped, MaxBytes must also be set") + } + cmd = append(cmd, bson.DocElem{"capped", true}) + cmd = append(cmd, bson.DocElem{"size", info.MaxBytes}) + if info.MaxDocs > 0 { + cmd = append(cmd, bson.DocElem{"max", info.MaxDocs}) + } + } + if info.DisableIdIndex { + cmd = append(cmd, bson.DocElem{"autoIndexId", false}) + } + if info.ForceIdIndex { + cmd = append(cmd, bson.DocElem{"autoIndexId", true}) + } + return c.Database.Run(cmd, nil) +} + +// Batch sets the batch size used when fetching documents from the database. +// It's possible to change this setting on a per-session basis as well, using +// the Batch method of Session. +// +// The default batch size is defined by the database itself. As of this +// writing, MongoDB will use an initial size of min(100 docs, 4MB) on the +// first batch, and 4MB on remaining ones. +func (q *Query) Batch(n int) *Query { + if n == 1 { + // Server interprets 1 as -1 and closes the cursor (!?) + n = 2 + } + q.m.Lock() + q.op.limit = int32(n) + q.m.Unlock() + return q +} + +// Prefetch sets the point at which the next batch of results will be requested. +// When there are p*batch_size remaining documents cached in an Iter, the next +// batch will be requested in background. For instance, when using this: +// +// query.Batch(200).Prefetch(0.25) +// +// and there are only 50 documents cached in the Iter to be processed, the +// next batch of 200 will be requested. It's possible to change this setting on +// a per-session basis as well, using the SetPrefetch method of Session. +// +// The default prefetch value is 0.25. 
+func (q *Query) Prefetch(p float64) *Query { + q.m.Lock() + q.prefetch = p + q.m.Unlock() + return q +} + +// Skip skips over the n initial documents from the query results. Note that +// this only makes sense with capped collections where documents are naturally +// ordered by insertion time, or with sorted results. +func (q *Query) Skip(n int) *Query { + q.m.Lock() + q.op.skip = int32(n) + q.m.Unlock() + return q +} + +// Limit restricts the maximum number of documents retrieved to n, and also +// changes the batch size to the same value. Once n documents have been +// returned by Next, the following call will return ErrNotFound. +func (q *Query) Limit(n int) *Query { + q.m.Lock() + switch { + case n == 1: + q.limit = 1 + q.op.limit = -1 + case n == math.MinInt32: // -MinInt32 == -MinInt32 + q.limit = math.MaxInt32 + q.op.limit = math.MinInt32 + 1 + case n < 0: + q.limit = int32(-n) + q.op.limit = int32(n) + default: + q.limit = int32(n) + q.op.limit = int32(n) + } + q.m.Unlock() + return q +} + +// Select enables selecting which fields should be retrieved for the results +// found. For example, the following query would only retrieve the name field: +// +// err := collection.Find(nil).Select(bson.M{"name": 1}).One(&result) +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Retrieving+a+Subset+of+Fields +// +func (q *Query) Select(selector interface{}) *Query { + q.m.Lock() + q.op.selector = selector + q.m.Unlock() + return q +} + +// Sort asks the database to order returned documents according to the +// provided field names. A field name may be prefixed by - (minus) for +// it to be sorted in reverse order. 
+// +// For example: +// +// query1 := collection.Find(nil).Sort("firstname", "lastname") +// query2 := collection.Find(nil).Sort("-age") +// query3 := collection.Find(nil).Sort("$natural") +// query4 := collection.Find(nil).Select(bson.M{"score": bson.M{"$meta": "textScore"}}).Sort("$textScore:score") +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Sorting+and+Natural+Order +// +func (q *Query) Sort(fields ...string) *Query { + q.m.Lock() + var order bson.D + for _, field := range fields { + n := 1 + var kind string + if field != "" { + if field[0] == '$' { + if c := strings.Index(field, ":"); c > 1 && c < len(field)-1 { + kind = field[1:c] + field = field[c+1:] + } + } + switch field[0] { + case '+': + field = field[1:] + case '-': + n = -1 + field = field[1:] + } + } + if field == "" { + panic("Sort: empty field name") + } + if kind == "textScore" { + order = append(order, bson.DocElem{field, bson.M{"$meta": kind}}) + } else { + order = append(order, bson.DocElem{field, n}) + } + } + q.op.options.OrderBy = order + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Explain returns a number of details about how the MongoDB server would +// execute the requested query, such as the number of objects examined, +// the number of times the read lock was yielded to allow writes to go in, +// and so on. 
+// +// For example: +// +// m := bson.M{} +// err := collection.Find(bson.M{"filename": name}).Explain(m) +// if err == nil { +// fmt.Printf("Explain: %#v\n", m) +// } +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Optimization +// http://www.mongodb.org/display/DOCS/Query+Optimizer +// +func (q *Query) Explain(result interface{}) error { + q.m.Lock() + clone := &Query{session: q.session, query: q.query} + q.m.Unlock() + clone.op.options.Explain = true + clone.op.hasOptions = true + if clone.op.limit > 0 { + clone.op.limit = -q.op.limit + } + iter := clone.Iter() + if iter.Next(result) { + return nil + } + return iter.Close() +} + +// Hint will include an explicit "hint" in the query to force the server +// to use a specified index, potentially improving performance in some +// situations. The provided parameters are the fields that compose the +// key of the index to be used. For details on how the indexKey may be +// built, see the EnsureIndex method. +// +// For example: +// +// query := collection.Find(bson.M{"firstname": "Joe", "lastname": "Winter"}) +// query.Hint("lastname", "firstname") +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Optimization +// http://www.mongodb.org/display/DOCS/Query+Optimizer +// +func (q *Query) Hint(indexKey ...string) *Query { + q.m.Lock() + keyInfo, err := parseIndexKey(indexKey) + q.op.options.Hint = keyInfo.key + q.op.hasOptions = true + q.m.Unlock() + if err != nil { + panic(err) + } + return q +} + +// SetMaxScan constrains the query to stop after scanning the specified +// number of documents. +// +// This modifier is generally used to prevent potentially long running +// queries from disrupting performance by scanning through too much data. +func (q *Query) SetMaxScan(n int) *Query { + q.m.Lock() + q.op.options.MaxScan = n + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// SetMaxTime constrains the query to stop after running for the specified time. 
+// +// When the time limit is reached MongoDB automatically cancels the query. +// This can be used to efficiently prevent and identify unexpectedly slow queries. +// +// A few important notes about the mechanism enforcing this limit: +// +// - Requests can block behind locking operations on the server, and that blocking +// time is not accounted for. In other words, the timer starts ticking only after +// the actual start of the query when it initially acquires the appropriate lock; +// +// - Operations are interrupted only at interrupt points where an operation can be +// safely aborted – the total execution time may exceed the specified value; +// +// - The limit can be applied to both CRUD operations and commands, but not all +// commands are interruptible; +// +// - While iterating over results, computing follow up batches is included in the +// total time and the iteration continues until the alloted time is over, but +// network roundtrips are not taken into account for the limit. +// +// - This limit does not override the inactive cursor timeout for idle cursors +// (default is 10 min). +// +// This mechanism was introduced in MongoDB 2.6. +// +// Relevant documentation: +// +// http://blog.mongodb.org/post/83621787773/maxtimems-and-query-optimizer-introspection-in +// +func (q *Query) SetMaxTime(d time.Duration) *Query { + q.m.Lock() + q.op.options.MaxTimeMS = int(d / time.Millisecond) + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Snapshot will force the performed query to make use of an available +// index on the _id field to prevent the same document from being returned +// more than once in a single iteration. This might happen without this +// setting in situations when the document changes in size and thus has to +// be moved while the iteration is running. +// +// Because snapshot mode traverses the _id index, it may not be used with +// sorting or explicit hints. It also cannot use any other index for the +// query. 
+// +// Even with snapshot mode, items inserted or deleted during the query may +// or may not be returned; that is, this mode is not a true point-in-time +// snapshot. +// +// The same effect of Snapshot may be obtained by using any unique index on +// field(s) that will not be modified (best to use Hint explicitly too). +// A non-unique index (such as creation time) may be made unique by +// appending _id to the index when creating it. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/How+to+do+Snapshotted+Queries+in+the+Mongo+Database +// +func (q *Query) Snapshot() *Query { + q.m.Lock() + q.op.options.Snapshot = true + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// Comment adds a comment to the query to identify it in the database profiler output. +// +// Relevant documentation: +// +// http://docs.mongodb.org/manual/reference/operator/meta/comment +// http://docs.mongodb.org/manual/reference/command/profile +// http://docs.mongodb.org/manual/administration/analyzing-mongodb-performance/#database-profiling +// +func (q *Query) Comment(comment string) *Query { + q.m.Lock() + q.op.options.Comment = comment + q.op.hasOptions = true + q.m.Unlock() + return q +} + +// LogReplay enables an option that optimizes queries that are typically +// made on the MongoDB oplog for replaying it. This is an internal +// implementation aspect and most likely uninteresting for other uses. +// It has seen at least one use case, though, so it's exposed via the API. 
+func (q *Query) LogReplay() *Query { + q.m.Lock() + q.op.flags |= flagLogReplay + q.m.Unlock() + return q +} + +func checkQueryError(fullname string, d []byte) error { + l := len(d) + if l < 16 { + return nil + } + if d[5] == '$' && d[6] == 'e' && d[7] == 'r' && d[8] == 'r' && d[9] == '\x00' && d[4] == '\x02' { + goto Error + } + if len(fullname) < 5 || fullname[len(fullname)-5:] != ".$cmd" { + return nil + } + for i := 0; i+8 < l; i++ { + if d[i] == '\x02' && d[i+1] == 'e' && d[i+2] == 'r' && d[i+3] == 'r' && d[i+4] == 'm' && d[i+5] == 's' && d[i+6] == 'g' && d[i+7] == '\x00' { + goto Error + } + } + return nil + +Error: + result := &queryError{} + bson.Unmarshal(d, result) + if result.LastError != nil { + return result.LastError + } + if result.Err == "" && result.ErrMsg == "" { + return nil + } + if result.AssertionCode != 0 && result.Assertion != "" { + return &QueryError{Code: result.AssertionCode, Message: result.Assertion, Assertion: true} + } + if result.Err != "" { + return &QueryError{Code: result.Code, Message: result.Err} + } + return &QueryError{Code: result.Code, Message: result.ErrMsg} +} + +// One executes the query and unmarshals the first obtained document into the +// result argument. The result must be a struct or map value capable of being +// unmarshalled into by gobson. This function blocks until either a result +// is available or an error happens. For example: +// +// err := collection.Find(bson.M{"a", 1}).One(&result) +// +// In case the resulting document includes a field named $err or errmsg, which +// are standard ways for MongoDB to return query errors, the returned err will +// be set to a *QueryError value including the Err message and the Code. In +// those cases, the result argument is still unmarshalled into with the +// received document so that any other custom values may be obtained if +// desired. +// +func (q *Query) One(result interface{}) (err error) { + q.m.Lock() + session := q.session + op := q.op // Copy. 
+ q.m.Unlock() + + socket, err := session.acquireSocket(true) + if err != nil { + return err + } + defer socket.Release() + + op.flags |= session.slaveOkFlag() + op.limit = -1 + + data, err := socket.SimpleQuery(&op) + if err != nil { + return err + } + if data == nil { + return ErrNotFound + } + if result != nil { + err = bson.Unmarshal(data, result) + if err == nil { + debugf("Query %p document unmarshaled: %#v", q, result) + } else { + debugf("Query %p document unmarshaling failed: %#v", q, err) + return err + } + } + return checkQueryError(op.collection, data) +} + +// run duplicates the behavior of collection.Find(query).One(&result) +// as performed by Database.Run, specializing the logic for running +// database commands on a given socket. +func (db *Database) run(socket *mongoSocket, cmd, result interface{}) (err error) { + // Database.Run: + if name, ok := cmd.(string); ok { + cmd = bson.D{{name, 1}} + } + + // Collection.Find: + session := db.Session + session.m.RLock() + op := session.queryConfig.op // Copy. + session.m.RUnlock() + op.query = cmd + op.collection = db.Name + ".$cmd" + + // Query.One: + op.flags |= session.slaveOkFlag() + op.limit = -1 + + data, err := socket.SimpleQuery(&op) + if err != nil { + return err + } + if data == nil { + return ErrNotFound + } + if result != nil { + err = bson.Unmarshal(data, result) + if err == nil { + var res bson.M + bson.Unmarshal(data, &res) + debugf("Run command unmarshaled: %#v, result: %#v", op, res) + } else { + debugf("Run command unmarshaling failed: %#v", op, err) + return err + } + } + return checkQueryError(op.collection, data) +} + +// The DBRef type implements support for the database reference MongoDB +// convention as supported by multiple drivers. This convention enables +// cross-referencing documents between collections and databases using +// a structure which includes a collection name, a document id, and +// optionally a database name. 
+// +// See the FindRef methods on Session and on Database. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +type DBRef struct { + Collection string `bson:"$ref"` + Id interface{} `bson:"$id"` + Database string `bson:"$db,omitempty"` +} + +// NOTE: Order of fields for DBRef above does matter, per documentation. + +// FindRef returns a query that looks for the document in the provided +// reference. If the reference includes the DB field, the document will +// be retrieved from the respective database. +// +// See also the DBRef type and the FindRef method on Session. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +func (db *Database) FindRef(ref *DBRef) *Query { + var c *Collection + if ref.Database == "" { + c = db.C(ref.Collection) + } else { + c = db.Session.DB(ref.Database).C(ref.Collection) + } + return c.FindId(ref.Id) +} + +// FindRef returns a query that looks for the document in the provided +// reference. For a DBRef to be resolved correctly at the session level +// it must necessarily have the optional DB field defined. +// +// See also the DBRef type and the FindRef method on Database. +// +// Relevant documentation: +// +// http://www.mongodb.org/display/DOCS/Database+References +// +func (s *Session) FindRef(ref *DBRef) *Query { + if ref.Database == "" { + panic(errors.New(fmt.Sprintf("Can't resolve database for %#v", ref))) + } + c := s.DB(ref.Database).C(ref.Collection) + return c.FindId(ref.Id) +} + +// CollectionNames returns the collection names present in the db database. +func (db *Database) CollectionNames() (names []string, err error) { + // Clone session and set it to Monotonic mode so that the server + // used for the query may be safely obtained afterwards, if + // necessary for iteration when a cursor is received. 
+ session := db.Session + cloned := session.Clone() + cloned.SetMode(Monotonic, false) + defer cloned.Close() + + batchSize := int(cloned.queryConfig.op.limit) + + // Try with a command. + var result struct { + Collections []bson.Raw + + Cursor struct { + FirstBatch []bson.Raw "firstBatch" + NS string + Id int64 + } + } + err = db.With(cloned).Run(bson.D{{"listCollections", 1}, {"cursor", bson.D{{"batchSize", batchSize}}}}, &result) + if err == nil { + firstBatch := result.Collections + if firstBatch == nil { + firstBatch = result.Cursor.FirstBatch + } + var iter *Iter + ns := strings.SplitN(result.Cursor.NS, ".", 2) + if len(ns) < 2 { + iter = db.With(cloned).C("").NewIter(nil, firstBatch, result.Cursor.Id, nil) + } else { + iter = cloned.DB(ns[0]).C(ns[1]).NewIter(nil, firstBatch, result.Cursor.Id, nil) + } + var coll struct{ Name string } + for iter.Next(&coll) { + names = append(names, coll.Name) + } + if err := iter.Close(); err != nil { + return nil, err + } + sort.Strings(names) + return names, err + } + if err != nil && !isNoCmd(err) { + return nil, err + } + + // Command not yet supported. Query the database instead. + nameIndex := len(db.Name) + 1 + iter := db.C("system.namespaces").Find(nil).Iter() + var coll struct{ Name string } + for iter.Next(&coll) { + if strings.Index(coll.Name, "$") < 0 || strings.Index(coll.Name, ".oplog.$") >= 0 { + names = append(names, coll.Name[nameIndex:]) + } + } + if err := iter.Close(); err != nil { + return nil, err + } + sort.Strings(names) + return names, nil +} + +type dbNames struct { + Databases []struct { + Name string + Empty bool + } +} + +// DatabaseNames returns the names of non-empty databases present in the cluster. 
+func (s *Session) DatabaseNames() (names []string, err error) { + var result dbNames + err = s.Run("listDatabases", &result) + if err != nil { + return nil, err + } + for _, db := range result.Databases { + if !db.Empty { + names = append(names, db.Name) + } + } + sort.Strings(names) + return names, nil +} + +// Iter executes the query and returns an iterator capable of going over all +// the results. Results will be returned in batches of configurable +// size (see the Batch method) and more documents will be requested when a +// configurable number of documents is iterated over (see the Prefetch method). +func (q *Query) Iter() *Iter { + q.m.Lock() + session := q.session + op := q.op + prefetch := q.prefetch + limit := q.limit + q.m.Unlock() + + iter := &Iter{ + session: session, + prefetch: prefetch, + limit: limit, + timeout: -1, + } + iter.gotReply.L = &iter.m + iter.op.collection = op.collection + iter.op.limit = op.limit + iter.op.replyFunc = iter.replyFunc() + iter.docsToReceive++ + op.replyFunc = iter.op.replyFunc + op.flags |= session.slaveOkFlag() + + socket, err := session.acquireSocket(true) + if err != nil { + iter.err = err + } else { + iter.server = socket.Server() + err = socket.Query(&op) + if err != nil { + // Must lock as the query above may call replyFunc. + iter.m.Lock() + iter.err = err + iter.m.Unlock() + } + socket.Release() + } + return iter +} + +// Tail returns a tailable iterator. Unlike a normal iterator, a +// tailable iterator may wait for new values to be inserted in the +// collection once the end of the current result set is reached, +// A tailable iterator may only be used with capped collections. +// +// The timeout parameter indicates how long Next will block waiting +// for a result before timing out. If set to -1, Next will not +// timeout, and will continue waiting for a result for as long as +// the cursor is valid and the session is not closed. 
// If set to 0, Next times out as soon as it reaches the end of the result
// set. Otherwise, Next will wait for at least the given number of
// seconds for a new document to be available before timing out.
//
// On timeouts, Next will unblock and return false, and the Timeout
// method will return true if called. In these cases, Next may still
// be called again on the same iterator to check if a new value is
// available at the current cursor position, and again it will block
// according to the specified timeoutSecs. If the cursor becomes
// invalid, though, both Next and Timeout will return false and
// the query must be restarted.
//
// The following example demonstrates timeout handling and query
// restarting:
//
//     iter := collection.Find(nil).Sort("$natural").Tail(5 * time.Second)
//     for {
//          for iter.Next(&result) {
//              fmt.Println(result.Id)
//              lastId = result.Id
//          }
//          if iter.Err() != nil {
//              return iter.Close()
//          }
//          if iter.Timeout() {
//              continue
//          }
//          query := collection.Find(bson.M{"_id": bson.M{"$gt": lastId}})
//          iter = query.Sort("$natural").Tail(5 * time.Second)
//     }
//     iter.Close()
//
// Relevant documentation:
//
//     http://www.mongodb.org/display/DOCS/Tailable+Cursors
//     http://www.mongodb.org/display/DOCS/Capped+Collections
//     http://www.mongodb.org/display/DOCS/Sorting+and+Natural+Order
//
func (q *Query) Tail(timeout time.Duration) *Iter {
	q.m.Lock()
	session := q.session
	op := q.op
	prefetch := q.prefetch
	q.m.Unlock()

	iter := &Iter{session: session, prefetch: prefetch}
	iter.gotReply.L = &iter.m
	iter.timeout = timeout
	iter.op.collection = op.collection
	iter.op.limit = op.limit
	iter.op.replyFunc = iter.replyFunc()
	iter.docsToReceive++
	op.replyFunc = iter.op.replyFunc
	// Tailable + AwaitData make the server keep the cursor open and block
	// briefly for new documents instead of returning cursor-exhausted.
	op.flags |= flagTailable | flagAwaitData | session.slaveOkFlag()

	socket, err := session.acquireSocket(true)
	if err != nil {
		iter.err = err
	} else {
		iter.server = socket.Server()
		err = socket.Query(&op)
		if err != nil {
			// Must lock as the query above may call replyFunc.
			iter.m.Lock()
			iter.err = err
			iter.m.Unlock()
		}
		socket.Release()
	}
	return iter
}

// slaveOkFlag returns flagSlaveOk when the session currently allows
// reads from secondaries, and zero otherwise.
func (s *Session) slaveOkFlag() (flag queryOpFlags) {
	s.m.RLock()
	if s.slaveOk {
		flag = flagSlaveOk
	}
	s.m.RUnlock()
	return
}

// Err returns nil if no errors happened during iteration, or the actual
// error otherwise.
//
// In case a resulting document included a field named $err or errmsg, which are
// standard ways for MongoDB to report an improper query, the returned value has
// a *QueryError type, and includes the Err message and the Code.
func (iter *Iter) Err() error {
	iter.m.Lock()
	err := iter.err
	iter.m.Unlock()
	// ErrNotFound marks normal exhaustion internally; it is not an error
	// from the caller's point of view.
	if err == ErrNotFound {
		return nil
	}
	return err
}

// Close kills the server cursor used by the iterator, if any, and returns
// nil if no errors happened during iteration, or the actual error otherwise.
//
// Server cursors are automatically closed at the end of an iteration, which
// means close will do nothing unless the iteration was interrupted before
// the server finished sending results to the driver. If Close is not called
// in such a situation, the cursor will remain available at the server until
// the default cursor timeout period is reached. No further problems arise.
//
// Close is idempotent. That means it can be called repeatedly and will
// return the same result every time.
//
// In case a resulting document included a field named $err or errmsg, which are
// standard ways for MongoDB to report an improper query, the returned value has
// a *QueryError type.
func (iter *Iter) Close() error {
	iter.m.Lock()
	// Zeroing cursorId under the lock is what makes Close idempotent.
	cursorId := iter.op.cursorId
	iter.op.cursorId = 0
	err := iter.err
	iter.m.Unlock()
	if cursorId == 0 {
		if err == ErrNotFound {
			return nil
		}
		return err
	}
	socket, err := iter.acquireSocket()
	if err == nil {
		// TODO Batch kills.
		err = socket.Query(&killCursorsOp{[]int64{cursorId}})
		socket.Release()
	}

	iter.m.Lock()
	if err != nil && (iter.err == nil || iter.err == ErrNotFound) {
		iter.err = err
	} else if iter.err != ErrNotFound {
		err = iter.err
	}
	iter.m.Unlock()
	return err
}

// Timeout returns true if Next returned false due to a timeout of
// a tailable cursor. In those cases, Next may be called again to continue
// the iteration at the previous cursor position.
func (iter *Iter) Timeout() bool {
	iter.m.Lock()
	result := iter.timedout
	iter.m.Unlock()
	return result
}

// Next retrieves the next document from the result set, blocking if necessary.
// This method will also automatically retrieve another batch of documents from
// the server when the current one is exhausted, or before that in background
// if pre-fetching is enabled (see the Query.Prefetch and Session.SetPrefetch
// methods).
//
// Next returns true if a document was successfully unmarshalled onto result,
// and false at the end of the result set or if an error happened.
// When Next returns false, the Err method should be called to verify if
// there was an error during iteration.
//
// For example:
//
//    iter := collection.Find(nil).Iter()
//    for iter.Next(&result) {
//        fmt.Printf("Result: %v\n", result.Id)
//    }
//    if err := iter.Close(); err != nil {
//        return err
//    }
//
func (iter *Iter) Next(result interface{}) bool {
	iter.m.Lock()
	iter.timedout = false
	timeout := time.Time{}
	// Block until data is buffered, the cursor is exhausted, or an error
	// (including a tailable-cursor timeout) occurs.
	for iter.err == nil && iter.docData.Len() == 0 && (iter.docsToReceive > 0 || iter.op.cursorId != 0) {
		if iter.docsToReceive == 0 {
			if iter.timeout >= 0 {
				if timeout.IsZero() {
					timeout = time.Now().Add(iter.timeout)
				}
				if time.Now().After(timeout) {
					iter.timedout = true
					iter.m.Unlock()
					return false
				}
			}
			iter.getMore()
			if iter.err != nil {
				break
			}
		}
		iter.gotReply.Wait()
	}

	// Exhaust available data before reporting any errors.
	if docData, ok := iter.docData.Pop().([]byte); ok {
		// NOTE(review): `close` shadows the builtin here; it only flags
		// that the limit was reached and Close() must be called below.
		close := false
		if iter.limit > 0 {
			iter.limit--
			if iter.limit == 0 {
				if iter.docData.Len() > 0 {
					iter.m.Unlock()
					panic(fmt.Errorf("data remains after limit exhausted: %d", iter.docData.Len()))
				}
				iter.err = ErrNotFound
				close = true
			}
		}
		if iter.op.cursorId != 0 && iter.err == nil {
			// Pre-fetch the next batch once the configured threshold of
			// remaining buffered documents is crossed.
			iter.docsBeforeMore--
			if iter.docsBeforeMore == -1 {
				iter.getMore()
			}
		}
		iter.m.Unlock()

		if close {
			iter.Close()
		}
		err := bson.Unmarshal(docData, result)
		if err != nil {
			debugf("Iter %p document unmarshaling failed: %#v", iter, err)
			iter.m.Lock()
			if iter.err == nil {
				iter.err = err
			}
			iter.m.Unlock()
			return false
		}
		debugf("Iter %p document unmarshaled: %#v", iter, result)
		// XXX Only have to check first document for a query error?
		err = checkQueryError(iter.op.collection, docData)
		if err != nil {
			iter.m.Lock()
			if iter.err == nil {
				iter.err = err
			}
			iter.m.Unlock()
			return false
		}
		return true
	} else if iter.err != nil {
		debugf("Iter %p returning false: %s", iter, iter.err)
		iter.m.Unlock()
		return false
	} else if iter.op.cursorId == 0 {
		iter.err = ErrNotFound
		debugf("Iter %p exhausted with cursor=0", iter)
		iter.m.Unlock()
		return false
	}

	panic("unreachable")
}

// All retrieves all documents from the result set into the provided slice
// and closes the iterator.
//
// The result argument must necessarily be the address for a slice. The slice
// may be nil or previously allocated.
//
// WARNING: Obviously, All must not be used with result sets that may be
// potentially large, since it may consume all memory until the system
// crashes. Consider building the query with a Limit clause to ensure the
// result size is bounded.
//
// For instance:
//
//    var result []struct{ Value int }
//    iter := collection.Find(nil).Limit(100).Iter()
//    err := iter.All(&result)
//    if err != nil {
//        return err
//    }
//
func (iter *Iter) All(result interface{}) error {
	resultv := reflect.ValueOf(result)
	if resultv.Kind() != reflect.Ptr || resultv.Elem().Kind() != reflect.Slice {
		panic("result argument must be a slice address")
	}
	slicev := resultv.Elem()
	// Reuse the slice's existing capacity before growing via Append.
	slicev = slicev.Slice(0, slicev.Cap())
	elemt := slicev.Type().Elem()
	i := 0
	for {
		if slicev.Len() == i {
			// Capacity exhausted: unmarshal into a fresh element, then append.
			elemp := reflect.New(elemt)
			if !iter.Next(elemp.Interface()) {
				break
			}
			slicev = reflect.Append(slicev, elemp.Elem())
			slicev = slicev.Slice(0, slicev.Cap())
		} else {
			// Unmarshal directly into the pre-allocated element.
			if !iter.Next(slicev.Index(i).Addr().Interface()) {
				break
			}
		}
		i++
	}
	resultv.Elem().Set(slicev.Slice(0, i))
	return iter.Close()
}

// All works like Iter.All.
func (q *Query) All(result interface{}) error {
	return q.Iter().All(result)
}

// The For method is obsolete and will be removed in a future release.
// See Iter as an elegant replacement.
func (q *Query) For(result interface{}, f func() error) error {
	return q.Iter().For(result, f)
}

// The For method is obsolete and will be removed in a future release.
// See Iter as an elegant replacement.
func (iter *Iter) For(result interface{}, f func() error) (err error) {
	valid := false
	v := reflect.ValueOf(result)
	if v.Kind() == reflect.Ptr {
		v = v.Elem()
		switch v.Kind() {
		case reflect.Map, reflect.Ptr, reflect.Interface, reflect.Slice:
			valid = v.IsNil()
		}
	}
	if !valid {
		panic("For needs a pointer to nil reference value. See the documentation.")
	}
	zero := reflect.Zero(v.Type())
	for {
		// Reset the target before each document so stale fields from the
		// previous iteration cannot leak through.
		v.Set(zero)
		if !iter.Next(result) {
			break
		}
		err = f()
		if err != nil {
			return err
		}
	}
	return iter.Err()
}

// acquireSocket acquires a socket from the same server that the iterator
// cursor was obtained from.
//
// WARNING: This method must not be called with iter.m locked. Acquiring the
// socket depends on the cluster sync loop, and the cluster sync loop might
// attempt actions which cause replyFunc to be called, inducing a deadlock.
func (iter *Iter) acquireSocket() (*mongoSocket, error) {
	socket, err := iter.session.acquireSocket(true)
	if err != nil {
		return nil, err
	}
	if socket.Server() != iter.server {
		// Socket server changed during iteration. This may happen
		// with Eventual sessions, if a Refresh is done, or if a
		// monotonic session gets a write and shifts from secondary
		// to primary. Our cursor is in a specific server, though.
		iter.session.m.Lock()
		sockTimeout := iter.session.sockTimeout
		iter.session.m.Unlock()
		socket.Release()
		socket, _, err = iter.server.AcquireSocket(0, sockTimeout)
		if err != nil {
			return nil, err
		}
		err := iter.session.socketLogin(socket)
		if err != nil {
			socket.Release()
			return nil, err
		}
	}
	return socket, nil
}

// getMore issues an OP_GET_MORE for the iterator's cursor. It must be
// called with iter.m locked; the lock is dropped while the socket is
// acquired (see acquireSocket's deadlock warning) and retaken after.
func (iter *Iter) getMore() {
	// Increment now so that unlocking the iterator won't cause a
	// different goroutine to get here as well.
	iter.docsToReceive++
	iter.m.Unlock()
	socket, err := iter.acquireSocket()
	iter.m.Lock()
	if err != nil {
		iter.err = err
		return
	}
	defer socket.Release()

	debugf("Iter %p requesting more documents", iter)
	if iter.limit > 0 {
		// The -1 below accounts for the fact docsToReceive was incremented above.
		limit := iter.limit - int32(iter.docsToReceive-1) - int32(iter.docData.Len())
		if limit < iter.op.limit {
			iter.op.limit = limit
		}
	}
	if err := socket.Query(&iter.op); err != nil {
		iter.docsToReceive--
		iter.err = err
	}
}

// countCmd mirrors the server's "count" command document.
type countCmd struct {
	Count string
	Query interface{}
	Limit int32 ",omitempty"
	Skip  int32 ",omitempty"
}

// Count returns the total number of documents in the result set.
func (q *Query) Count() (n int, err error) {
	q.m.Lock()
	session := q.session
	op := q.op
	limit := q.limit
	q.m.Unlock()

	// op.collection is "db.collection"; split it to address the command.
	c := strings.Index(op.collection, ".")
	if c < 0 {
		return 0, errors.New("Bad collection name: " + op.collection)
	}

	dbname := op.collection[:c]
	cname := op.collection[c+1:]
	query := op.query
	if query == nil {
		query = bson.D{}
	}
	result := struct{ N int }{}
	err = session.DB(dbname).Run(countCmd{cname, query, limit, op.skip}, &result)
	return result.N, err
}

// Count returns the total number of documents in the collection.
func (c *Collection) Count() (n int, err error) {
	return c.Find(nil).Count()
}

// distinctCmd mirrors the server's "distinct" command document.
type distinctCmd struct {
	Collection string "distinct"
	Key        string
	Query      interface{} ",omitempty"
}

// Distinct unmarshals into result the list of distinct values for the given key.
//
// For example:
//
//     var result []int
//     err := collection.Find(bson.M{"gender": "F"}).Distinct("age", &result)
//
// Relevant documentation:
//
//     http://www.mongodb.org/display/DOCS/Aggregation
//
func (q *Query) Distinct(key string, result interface{}) error {
	q.m.Lock()
	session := q.session
	op := q.op // Copy.
	q.m.Unlock()

	c := strings.Index(op.collection, ".")
	if c < 0 {
		return errors.New("Bad collection name: " + op.collection)
	}

	dbname := op.collection[:c]
	cname := op.collection[c+1:]

	var doc struct{ Values bson.Raw }
	err := session.DB(dbname).Run(distinctCmd{cname, key, op.query}, &doc)
	if err != nil {
		return err
	}
	// Defer decoding of the values array so it lands directly in the
	// caller-provided slice type.
	return doc.Values.Unmarshal(result)
}

// mapReduceCmd mirrors the server's "mapreduce" command document.
type mapReduceCmd struct {
	Collection string "mapreduce"
	Map        string ",omitempty"
	Reduce     string ",omitempty"
	Finalize   string ",omitempty"
	Limit      int32  ",omitempty"
	Out        interface{}
	Query      interface{} ",omitempty"
	Sort       interface{} ",omitempty"
	Scope      interface{} ",omitempty"
	Verbose    bool        ",omitempty"
}

// mapReduceResult holds the raw reply of a mapreduce command.
type mapReduceResult struct {
	Results    bson.Raw
	Result     bson.Raw
	TimeMillis int64 "timeMillis"
	Counts     struct{ Input, Emit, Output int }
	Ok         bool
	Err        string
	Timing     *MapReduceTime
}

type MapReduce struct {
	Map      string      // Map Javascript function code (required)
	Reduce   string      // Reduce Javascript function code (required)
	Finalize string      // Finalize Javascript function code (optional)
	Out      interface{} // Output collection name or document. If nil, results are inlined into the result parameter.
	Scope    interface{} // Optional global scope for Javascript functions
	Verbose  bool
}

type MapReduceInfo struct {
	InputCount  int            // Number of documents mapped
	EmitCount   int            // Number of times reduce called emit
	OutputCount int            // Number of documents in resulting collection
	Database    string         // Output database, if results are not inlined
	Collection  string         // Output collection, if results are not inlined
	Time        int64          // Time to run the job, in nanoseconds
	VerboseTime *MapReduceTime // Only defined if Verbose was true
}

type MapReduceTime struct {
	Total    int64            // Total time, in nanoseconds
	Map      int64 "mapTime"  // Time within map function, in nanoseconds
	EmitLoop int64 "emitLoop" // Time within the emit/map loop, in nanoseconds
}

// MapReduce executes a map/reduce job for documents covered by the query.
// That kind of job is suitable for very flexible bulk aggregation of data
// performed at the server side via Javascript functions.
//
// Results from the job may be returned as a result of the query itself
// through the result parameter in case they'll certainly fit in memory
// and in a single document. If there's the possibility that the amount
// of data might be too large, results must be stored back in an alternative
// collection or even a separate database, by setting the Out field of the
// provided MapReduce job. In that case, provide nil as the result parameter.
//
// These are some of the ways to set Out:
//
//     nil
//         Inline results into the result parameter.
//
//     bson.M{"replace": "mycollection"}
//         The output will be inserted into a collection which replaces any
//         existing collection with the same name.
//
//     bson.M{"merge": "mycollection"}
//         This option will merge new data into the old output collection. In
//         other words, if the same key exists in both the result set and the
//         old collection, the new key will overwrite the old one.
//
//     bson.M{"reduce": "mycollection"}
//         If documents exist for a given key in the result set and in the old
//         collection, then a reduce operation (using the specified reduce
//         function) will be performed on the two values and the result will be
//         written to the output collection. If a finalize function was
//         provided, this will be run after the reduce as well.
//
//     bson.M{...., "db": "mydb"}
//         Any of the above options can have the "db" key included for doing
//         the respective action in a separate database.
//
// The following is a trivial example which will count the number of
// occurrences of a field named n on each document in a collection, and
// will return results inline:
//
//     job := &mgo.MapReduce{
//             Map:      "function() { emit(this.n, 1) }",
//             Reduce:   "function(key, values) { return Array.sum(values) }",
//     }
//     var result []struct { Id int "_id"; Value int }
//     _, err := collection.Find(nil).MapReduce(job, &result)
//     if err != nil {
//         return err
//     }
//     for _, item := range result {
//         fmt.Println(item.Value)
//     }
//
// This function is compatible with MongoDB 1.7.4+.
//
// Relevant documentation:
//
//     http://www.mongodb.org/display/DOCS/MapReduce
//
func (q *Query) MapReduce(job *MapReduce, result interface{}) (info *MapReduceInfo, err error) {
	q.m.Lock()
	session := q.session
	op := q.op // Copy.
	limit := q.limit
	q.m.Unlock()

	c := strings.Index(op.collection, ".")
	if c < 0 {
		return nil, errors.New("Bad collection name: " + op.collection)
	}

	dbname := op.collection[:c]
	cname := op.collection[c+1:]

	cmd := mapReduceCmd{
		Collection: cname,
		Map:        job.Map,
		Reduce:     job.Reduce,
		Finalize:   job.Finalize,
		Out:        fixMROut(job.Out),
		Scope:      job.Scope,
		Verbose:    job.Verbose,
		Query:      op.query,
		Sort:       op.options.OrderBy,
		Limit:      limit,
	}

	if cmd.Out == nil {
		cmd.Out = bson.D{{"inline", 1}}
	}

	var doc mapReduceResult
	err = session.DB(dbname).Run(&cmd, &doc)
	if err != nil {
		return nil, err
	}
	if doc.Err != "" {
		return nil, errors.New(doc.Err)
	}

	info = &MapReduceInfo{
		InputCount:  doc.Counts.Input,
		EmitCount:   doc.Counts.Emit,
		OutputCount: doc.Counts.Output,
		// Server reports milliseconds; convert to nanoseconds.
		Time: doc.TimeMillis * 1e6,
	}

	// Result kind 0x02 is a BSON string (plain collection name);
	// 0x03 is a document carrying {collection, db}.
	if doc.Result.Kind == 0x02 {
		err = doc.Result.Unmarshal(&info.Collection)
		info.Database = dbname
	} else if doc.Result.Kind == 0x03 {
		var v struct{ Collection, Db string }
		err = doc.Result.Unmarshal(&v)
		info.Collection = v.Collection
		info.Database = v.Db
	}

	if doc.Timing != nil {
		info.VerboseTime = doc.Timing
		info.VerboseTime.Total *= 1e6
		info.VerboseTime.Map *= 1e6
		info.VerboseTime.EmitLoop *= 1e6
	}

	if err != nil {
		return nil, err
	}
	if result != nil {
		return info, doc.Results.Unmarshal(result)
	}
	return info, nil
}

// The "out" option in the MapReduce command must be ordered. This was
// found after the implementation was accepting maps for a long time,
// so rather than breaking the API, we'll fix the order if necessary.
// Details about the order requirement may be seen in MongoDB's code:
//
//     http://goo.gl/L8jwJX
//
func fixMROut(out interface{}) interface{} {
	outv := reflect.ValueOf(out)
	if outv.Kind() != reflect.Map || outv.Type().Key() != reflect.TypeOf("") {
		// Not a map with string keys; pass through unchanged.
		return out
	}
	outs := make(bson.D, outv.Len())

	outTypeIndex := -1
	for i, k := range outv.MapKeys() {
		ks := k.String()
		outs[i].Name = ks
		outs[i].Value = outv.MapIndex(k).Interface()
		switch ks {
		case "normal", "replace", "merge", "reduce", "inline":
			outTypeIndex = i
		}
	}
	// The output-type key must come first in the command document.
	if outTypeIndex > 0 {
		outs[0], outs[outTypeIndex] = outs[outTypeIndex], outs[0]
	}
	return outs
}

// Change holds fields for running a findAndModify MongoDB command via
// the Query.Apply method.
type Change struct {
	Update    interface{} // The update document
	Upsert    bool        // Whether to insert in case the document isn't found
	Remove    bool        // Whether to remove the document found rather than updating
	ReturnNew bool        // Should the modified document be returned rather than the old one
}

// findModifyCmd mirrors the server's "findAndModify" command document.
type findModifyCmd struct {
	Collection                   string "findAndModify"
	Query, Update, Sort, Fields  interface{} ",omitempty"
	Upsert, Remove, New          bool        ",omitempty"
}

// valueResult holds the raw findAndModify reply: the affected document
// plus the server's lastErrorObject bookkeeping.
type valueResult struct {
	Value     bson.Raw
	LastError LastError "lastErrorObject"
}

// Apply runs the findAndModify MongoDB command, which allows updating, upserting
// or removing a document matching a query and atomically returning either the old
// version (the default) or the new version of the document (when ReturnNew is true).
// If no objects are found Apply returns ErrNotFound.
//
// The Sort and Select query methods affect the result of Apply. In case
// multiple documents match the query, Sort enables selecting which document to
// act upon by ordering it first. Select enables retrieving only a selection
// of fields of the new or old document.
//
// This simple example increments a counter and prints its new value:
//
//     change := mgo.Change{
//             Update: bson.M{"$inc": bson.M{"n": 1}},
//             ReturnNew: true,
//     }
//     info, err = col.Find(M{"_id": id}).Apply(change, &doc)
//     fmt.Println(doc.N)
//
// This method depends on MongoDB >= 2.0 to work properly.
//
// Relevant documentation:
//
//     http://www.mongodb.org/display/DOCS/findAndModify+Command
//     http://www.mongodb.org/display/DOCS/Updating
//     http://www.mongodb.org/display/DOCS/Atomic+Operations
//
func (q *Query) Apply(change Change, result interface{}) (info *ChangeInfo, err error) {
	q.m.Lock()
	session := q.session
	op := q.op // Copy.
	q.m.Unlock()

	c := strings.Index(op.collection, ".")
	if c < 0 {
		return nil, errors.New("bad collection name: " + op.collection)
	}

	dbname := op.collection[:c]
	cname := op.collection[c+1:]

	cmd := findModifyCmd{
		Collection: cname,
		Update:     change.Update,
		Upsert:     change.Upsert,
		Remove:     change.Remove,
		New:        change.ReturnNew,
		Query:      op.query,
		Sort:       op.options.OrderBy,
		Fields:     op.selector,
	}

	// findAndModify must run against the primary to be meaningful.
	session = session.Clone()
	defer session.Close()
	session.SetMode(Strong, false)

	var doc valueResult
	err = session.DB(dbname).Run(&cmd, &doc)
	if err != nil {
		if qerr, ok := err.(*QueryError); ok && qerr.Message == "No matching object found" {
			return nil, ErrNotFound
		}
		return nil, err
	}
	if doc.LastError.N == 0 {
		return nil, ErrNotFound
	}
	// Kind 0x0A is BSON null: no document to unmarshal.
	if doc.Value.Kind != 0x0A && result != nil {
		err = doc.Value.Unmarshal(result)
		if err != nil {
			return nil, err
		}
	}
	info = &ChangeInfo{}
	lerr := &doc.LastError
	if lerr.UpdatedExisting {
		info.Updated = lerr.N
	} else if change.Remove {
		info.Removed = lerr.N
	} else if change.Upsert {
		info.UpsertedId = lerr.UpsertedId
	}
	return info, nil
}

// The BuildInfo type encapsulates details about the running MongoDB server.
+// +// Note that the VersionArray field was introduced in MongoDB 2.0+, but it is +// internally assembled from the Version information for previous versions. +// In both cases, VersionArray is guaranteed to have at least 4 entries. +type BuildInfo struct { + Version string + VersionArray []int `bson:"versionArray"` // On MongoDB 2.0+; assembled from Version otherwise + GitVersion string `bson:"gitVersion"` + OpenSSLVersion string `bson:"OpenSSLVersion"` + SysInfo string `bson:"sysInfo"` + Bits int + Debug bool + MaxObjectSize int `bson:"maxBsonObjectSize"` +} + +// VersionAtLeast returns whether the BuildInfo version is greater than or +// equal to the provided version number. If more than one number is +// provided, numbers will be considered as major, minor, and so on. +func (bi *BuildInfo) VersionAtLeast(version ...int) bool { + for i := range version { + if i == len(bi.VersionArray) { + return false + } + if bi.VersionArray[i] < version[i] { + return false + } + } + return true +} + +// BuildInfo retrieves the version and other details about the +// running MongoDB server. +func (s *Session) BuildInfo() (info BuildInfo, err error) { + err = s.Run(bson.D{{"buildInfo", "1"}}, &info) + if len(info.VersionArray) == 0 { + for _, a := range strings.Split(info.Version, ".") { + i, err := strconv.Atoi(a) + if err != nil { + break + } + info.VersionArray = append(info.VersionArray, i) + } + } + for len(info.VersionArray) < 4 { + info.VersionArray = append(info.VersionArray, 0) + } + if i := strings.IndexByte(info.GitVersion, ' '); i >= 0 { + // Strip off the " modules: enterprise" suffix. This is a _git version_. + // That information may be moved to another field if people need it. + info.GitVersion = info.GitVersion[:i] + } + return +} + +// --------------------------------------------------------------------------- +// Internal session handling helpers. 
+ +func (s *Session) acquireSocket(slaveOk bool) (*mongoSocket, error) { + + // Read-only lock to check for previously reserved socket. + s.m.RLock() + if s.masterSocket != nil { + socket := s.masterSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil + } + if s.slaveSocket != nil && s.slaveOk && slaveOk { + socket := s.slaveSocket + socket.Acquire() + s.m.RUnlock() + return socket, nil + } + s.m.RUnlock() + + // No go. We may have to request a new socket and change the session, + // so try again but with an exclusive lock now. + s.m.Lock() + defer s.m.Unlock() + + if s.masterSocket != nil { + s.masterSocket.Acquire() + return s.masterSocket, nil + } + if s.slaveSocket != nil && s.slaveOk && slaveOk { + s.slaveSocket.Acquire() + return s.slaveSocket, nil + } + + // Still not good. We need a new socket. + sock, err := s.cluster().AcquireSocket(slaveOk && s.slaveOk, s.syncTimeout, s.sockTimeout, s.queryConfig.op.serverTags, s.poolLimit) + if err != nil { + return nil, err + } + + // Authenticate the new socket. + if err = s.socketLogin(sock); err != nil { + sock.Release() + return nil, err + } + + // Keep track of the new socket, if necessary. + // Note that, as a special case, if the Eventual session was + // not refreshed (s.slaveSocket != nil), it means the developer + // asked to preserve an existing reserved socket, so we'll + // keep a master one around too before a Refresh happens. + if s.consistency != Eventual || s.slaveSocket != nil { + s.setSocket(sock) + } + + // Switch over a Monotonic session to the master. + if !slaveOk && s.consistency == Monotonic { + s.slaveOk = false + } + + return sock, nil +} + +// setSocket binds socket to this section. 
+func (s *Session) setSocket(socket *mongoSocket) { + info := socket.Acquire() + if info.Master { + if s.masterSocket != nil { + panic("setSocket(master) with existing master socket reserved") + } + s.masterSocket = socket + } else { + if s.slaveSocket != nil { + panic("setSocket(slave) with existing slave socket reserved") + } + s.slaveSocket = socket + } +} + +// unsetSocket releases any slave and/or master sockets reserved. +func (s *Session) unsetSocket() { + if s.masterSocket != nil { + s.masterSocket.Release() + } + if s.slaveSocket != nil { + s.slaveSocket.Release() + } + s.masterSocket = nil + s.slaveSocket = nil +} + +func (iter *Iter) replyFunc() replyFunc { + return func(err error, op *replyOp, docNum int, docData []byte) { + iter.m.Lock() + iter.docsToReceive-- + if err != nil { + iter.err = err + debugf("Iter %p received an error: %s", iter, err.Error()) + } else if docNum == -1 { + debugf("Iter %p received no documents (cursor=%d).", iter, op.cursorId) + if op != nil && op.cursorId != 0 { + // It's a tailable cursor. + iter.op.cursorId = op.cursorId + } else if op != nil && op.cursorId == 0 && op.flags&1 == 1 { + // Cursor likely timed out. + iter.err = ErrCursor + } else { + iter.err = ErrNotFound + } + } else { + rdocs := int(op.replyDocs) + if docNum == 0 { + iter.docsToReceive += rdocs - 1 + docsToProcess := iter.docData.Len() + rdocs + if iter.limit == 0 || int32(docsToProcess) < iter.limit { + iter.docsBeforeMore = docsToProcess - int(iter.prefetch*float64(rdocs)) + } else { + iter.docsBeforeMore = -1 + } + iter.op.cursorId = op.cursorId + } + // XXX Handle errors and flags. 
+ debugf("Iter %p received reply document %d/%d (cursor=%d)", iter, docNum+1, rdocs, op.cursorId) + iter.docData.Push(docData) + } + iter.gotReply.Broadcast() + iter.m.Unlock() + } +} + +type writeCmdResult struct { + Ok bool + N int + NModified int `bson:"nModified"` + Upserted []struct { + Index int + Id interface{} `_id` + } + Errors []struct { + Ok bool + Index int + Code int + N int + ErrMsg string + } `bson:"writeErrors"` + ConcernError struct { + Code int + ErrMsg string + } `bson:"writeConcernError"` +} + +// writeQuery runs the given modifying operation, potentially followed up +// by a getLastError command in case the session is in safe mode. The +// LastError result is made available in lerr, and if lerr.Err is set it +// will also be returned as err. +func (c *Collection) writeQuery(op interface{}) (lerr *LastError, err error) { + s := c.Database.Session + dbname := c.Database.Name + socket, err := s.acquireSocket(dbname == "local") + if err != nil { + return nil, err + } + defer socket.Release() + + s.m.RLock() + safeOp := s.safeOp + s.m.RUnlock() + + // TODO Enable this path for wire version 2 as well. + if socket.ServerInfo().MaxWireVersion >= 3 { + // Servers with a more recent write protocol benefit from write commands. + if op, ok := op.(*insertOp); ok && len(op.documents) > 1000 { + var firstErr error + // Maximum batch size is 1000. Must split out in separate operations for compatibility. 
+ all := op.documents + for i := 0; i < len(all); i += 1000 { + l := i + 1000 + if l > len(all) { + l = len(all) + } + op.documents = all[i:l] + _, err := c.writeCommand(socket, safeOp, op) + if err != nil { + if op.flags&1 != 0 { + if firstErr == nil { + firstErr = err + } + } else { + return nil, err + } + } + } + return nil, firstErr + } + return c.writeCommand(socket, safeOp, op) + } + + if safeOp == nil { + return nil, socket.Query(op) + } + + var mutex sync.Mutex + var replyData []byte + var replyErr error + mutex.Lock() + query := *safeOp // Copy the data. + query.collection = dbname + ".$cmd" + query.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + replyData = docData + replyErr = err + mutex.Unlock() + } + err = socket.Query(op, &query) + if err != nil { + return nil, err + } + mutex.Lock() // Wait. + if replyErr != nil { + return nil, replyErr // XXX TESTME + } + if hasErrMsg(replyData) { + // Looks like getLastError itself failed. + err = checkQueryError(query.collection, replyData) + if err != nil { + return nil, err + } + } + result := &LastError{} + bson.Unmarshal(replyData, &result) + debugf("Result from writing query: %#v", result) + if result.Err != "" { + return result, result + } + return result, nil +} + +func (c *Collection) writeCommand(socket *mongoSocket, safeOp *queryOp, op interface{}) (lerr *LastError, err error) { + var writeConcern interface{} + if safeOp == nil { + writeConcern = bson.D{{"w", 0}} + } else { + writeConcern = safeOp.query.(*getLastError) + } + + var cmd bson.D + switch op := op.(type) { + case *insertOp: + // http://docs.mongodb.org/manual/reference/command/insert + cmd = bson.D{ + {"insert", c.Name}, + {"documents", op.documents}, + {"writeConcern", writeConcern}, + {"ordered", op.flags&1 == 0}, + } + case *updateOp: + // http://docs.mongodb.org/manual/reference/command/update + selector := op.selector + if selector == nil { + selector = bson.D{} + } + cmd = bson.D{ + {"update", c.Name}, + 
{"updates", []bson.D{{{"q", selector}, {"u", op.update}, {"upsert", op.flags&1 != 0}, {"multi", op.flags&2 != 0}}}}, + {"writeConcern", writeConcern}, + //{"ordered", }, + } + case *deleteOp: + // http://docs.mongodb.org/manual/reference/command/delete + selector := op.selector + if selector == nil { + selector = bson.D{} + } + cmd = bson.D{ + {"delete", c.Name}, + {"deletes", []bson.D{{{"q", selector}, {"limit", op.flags & 1}}}}, + {"writeConcern", writeConcern}, + //{"ordered", }, + } + } + + var result writeCmdResult + err = c.Database.run(socket, cmd, &result) + debugf("Write command result: %#v (err=%v)", result, err) + lerr = &LastError{ + UpdatedExisting: result.N > 0 && len(result.Upserted) == 0, + N: result.N, + } + if len(result.Upserted) > 0 { + lerr.UpsertedId = result.Upserted[0].Id + } + if len(result.Errors) > 0 { + e := result.Errors[0] + if !e.Ok { + lerr.Code = e.Code + lerr.Err = e.ErrMsg + err = lerr + } + } else if result.ConcernError.Code != 0 { + e := result.ConcernError + lerr.Code = e.Code + lerr.Err = e.ErrMsg + err = lerr + } + + if err == nil && safeOp == nil { + return nil, nil + } + return lerr, err +} + +func hasErrMsg(d []byte) bool { + l := len(d) + for i := 0; i+8 < l; i++ { + if d[i] == '\x02' && d[i+1] == 'e' && d[i+2] == 'r' && d[i+3] == 'r' && d[i+4] == 'm' && d[i+5] == 's' && d[i+6] == 'g' && d[i+7] == '\x00' { + return true + } + } + return false +} diff --git a/vendor/gopkg.in/mgo.v2/session_test.go b/vendor/gopkg.in/mgo.v2/session_test.go new file mode 100644 index 000000000..3ec5550a4 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/session_test.go @@ -0,0 +1,3704 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. 
Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "flag" + "fmt" + "math" + "reflect" + "runtime" + "sort" + "strconv" + "strings" + "time" + + . 
"gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +func (s *S) TestRunString(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run("ping", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestRunValue(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run(M{"ping": 1}, &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestPing(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Just ensure the nonce has been received. + result := struct{}{} + err = session.Run("ping", &result) + + mgo.ResetStats() + + err = session.Ping() + c.Assert(err, IsNil) + + // Pretty boring. + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 1) + c.Assert(stats.ReceivedOps, Equals, 1) +} + +func (s *S) TestURLSingle(c *C) { + session, err := mgo.Dial("mongodb://localhost:40001/") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run("ping", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestURLMany(c *C) { + session, err := mgo.Dial("mongodb://localhost:40011,localhost:40012/") + c.Assert(err, IsNil) + defer session.Close() + + result := struct{ Ok int }{} + err = session.Run("ping", &result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestURLParsing(c *C) { + urls := []string{ + "localhost:40001?foo=1&bar=2", + "localhost:40001?foo=1;bar=2", + } + for _, url := range urls { + session, err := mgo.Dial(url) + if session != nil { + session.Close() + } + c.Assert(err, ErrorMatches, "unsupported connection URL option: (foo=1|bar=2)") + } +} + +func (s *S) TestInsertFindOne(c *C) { + session, err := mgo.Dial("localhost:40001") + 
c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + err = coll.Insert(M{"a": 1, "b": 3}) + c.Assert(err, IsNil) + + result := struct{ A, B int }{} + + err = coll.Find(M{"a": 1}).Sort("b").One(&result) + c.Assert(err, IsNil) + c.Assert(result.A, Equals, 1) + c.Assert(result.B, Equals, 2) +} + +func (s *S) TestInsertFindOneNil(c *C) { + session, err := mgo.Dial("localhost:40002") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Find(nil).One(nil) + c.Assert(err, ErrorMatches, "unauthorized.*|not authorized.*") +} + +func (s *S) TestInsertFindOneMap(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + result := make(M) + err = coll.Find(M{"a": 1}).One(result) + c.Assert(err, IsNil) + c.Assert(result["a"], Equals, 1) + c.Assert(result["b"], Equals, 2) +} + +func (s *S) TestInsertFindAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"a": 1, "b": 2}) + c.Assert(err, IsNil) + err = coll.Insert(M{"a": 3, "b": 4}) + c.Assert(err, IsNil) + + type R struct{ A, B int } + var result []R + + assertResult := func() { + c.Assert(len(result), Equals, 2) + c.Assert(result[0].A, Equals, 1) + c.Assert(result[0].B, Equals, 2) + c.Assert(result[1].A, Equals, 3) + c.Assert(result[1].B, Equals, 4) + } + + // nil slice + err = coll.Find(nil).Sort("a").All(&result) + c.Assert(err, IsNil) + assertResult() + + // Previously allocated slice + allocd := make([]R, 5) + result = allocd + err = coll.Find(nil).Sort("a").All(&result) + c.Assert(err, IsNil) + assertResult() + + // Ensure result is backed by the originally allocated array + c.Assert(&result[0], Equals, &allocd[0]) 
+ + // Non-pointer slice error + f := func() { coll.Find(nil).All(result) } + c.Assert(f, Panics, "result argument must be a slice address") + + // Non-slice error + f = func() { coll.Find(nil).All(new(int)) } + c.Assert(f, Panics, "result argument must be a slice address") +} + +func (s *S) TestFindRef(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db1 := session.DB("db1") + db1col1 := db1.C("col1") + + db2 := session.DB("db2") + db2col1 := db2.C("col1") + + err = db1col1.Insert(M{"_id": 1, "n": 1}) + c.Assert(err, IsNil) + err = db1col1.Insert(M{"_id": 2, "n": 2}) + c.Assert(err, IsNil) + err = db2col1.Insert(M{"_id": 2, "n": 3}) + c.Assert(err, IsNil) + + result := struct{ N int }{} + + ref1 := &mgo.DBRef{Collection: "col1", Id: 1} + ref2 := &mgo.DBRef{Collection: "col1", Id: 2, Database: "db2"} + + err = db1.FindRef(ref1).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) + + err = db1.FindRef(ref2).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 3) + + err = db2.FindRef(ref1).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = db2.FindRef(ref2).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 3) + + err = session.FindRef(ref2).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 3) + + f := func() { session.FindRef(ref1).One(&result) } + c.Assert(f, PanicMatches, "Can't resolve database for &mgo.DBRef{Collection:\"col1\", Id:1, Database:\"\"}") +} + +func (s *S) TestDatabaseAndCollectionNames(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db1 := session.DB("db1") + db1col1 := db1.C("col1") + db1col2 := db1.C("col2") + + db2 := session.DB("db2") + db2col1 := db2.C("col3") + + err = db1col1.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = db1col2.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = db2col1.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + names, err := 
session.DatabaseNames() + c.Assert(err, IsNil) + if !reflect.DeepEqual(names, []string{"db1", "db2"}) { + // 2.4+ has "local" as well. + c.Assert(names, DeepEquals, []string{"db1", "db2", "local"}) + } + + // Try to exercise cursor logic. 2.8.0-rc3 still ignores this. + session.SetBatch(2) + + names, err = db1.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"col1", "col2", "system.indexes"}) + + names, err = db2.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"col3", "system.indexes"}) +} + +func (s *S) TestSelect(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + result := struct{ A, B int }{} + + err = coll.Find(M{"a": 1}).Select(M{"b": 1}).One(&result) + c.Assert(err, IsNil) + c.Assert(result.A, Equals, 0) + c.Assert(result.B, Equals, 2) +} + +func (s *S) TestInlineMap(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + var v, result1 struct { + A int + M map[string]int ",inline" + } + + v.A = 1 + v.M = map[string]int{"b": 2} + err = coll.Insert(v) + c.Assert(err, IsNil) + + noId := M{"_id": 0} + + err = coll.Find(nil).Select(noId).One(&result1) + c.Assert(err, IsNil) + c.Assert(result1.A, Equals, 1) + c.Assert(result1.M, DeepEquals, map[string]int{"b": 2}) + + var result2 M + err = coll.Find(nil).Select(noId).One(&result2) + c.Assert(err, IsNil) + c.Assert(result2, DeepEquals, M{"a": 1, "b": 2}) + +} + +func (s *S) TestUpdate(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"k": n, "n": n}) + c.Assert(err, IsNil) + } + + // No changes is a no-op and shouldn't return an error. 
+ err = coll.Update(M{"k": 42}, M{"$set": M{"n": 42}}) + c.Assert(err, IsNil) + + err = coll.Update(M{"k": 42}, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + result := make(M) + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + err = coll.Update(M{"k": 47}, M{"k": 47, "n": 47}) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"k": 47}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestUpdateId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"_id": n, "n": n}) + c.Assert(err, IsNil) + } + + err = coll.UpdateId(42, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + result := make(M) + err = coll.FindId(42).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + err = coll.UpdateId(47, M{"k": 47, "n": 47}) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.FindId(47).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestUpdateNil(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"k": 42, "n": 42}) + c.Assert(err, IsNil) + err = coll.Update(nil, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + result := make(M) + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 43) + + err = coll.Insert(M{"k": 45, "n": 45}) + c.Assert(err, IsNil) + _, err = coll.UpdateAll(nil, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 44) + err = coll.Find(M{"k": 45}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 46) +} + +func (s *S) TestUpsert(c *C) { + session, err := mgo.Dial("localhost:40001") + 
c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"k": n, "n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.Upsert(M{"k": 42}, M{"k": 42, "n": 24}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.UpsertedId, IsNil) + + result := M{} + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 24) + + // Insert with internally created id. + info, err = coll.Upsert(M{"k": 47}, M{"k": 47, "n": 47}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.UpsertedId, NotNil) + + err = coll.Find(M{"k": 47}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 47) + + result = M{} + err = coll.Find(M{"_id": info.UpsertedId}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 47) + + // Insert with provided id. + info, err = coll.Upsert(M{"k": 48}, M{"k": 48, "n": 48, "_id": 48}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + if s.versionAtLeast(2, 6) { + c.Assert(info.UpsertedId, Equals, 48) + } else { + c.Assert(info.UpsertedId, IsNil) // Unfortunate, but that's what Mongo gave us. 
+ } + + err = coll.Find(M{"k": 48}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 48) +} + +func (s *S) TestUpsertId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"_id": n, "n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.UpsertId(42, M{"n": 24}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.UpsertedId, IsNil) + + result := M{} + err = coll.FindId(42).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 24) + + info, err = coll.UpsertId(47, M{"_id": 47, "n": 47}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + if s.versionAtLeast(2, 6) { + c.Assert(info.UpsertedId, Equals, 47) + } else { + c.Assert(info.UpsertedId, IsNil) + } + + err = coll.FindId(47).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 47) +} + +func (s *S) TestUpdateAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"k": n, "n": n}) + c.Assert(err, IsNil) + } + + // Don't actually modify the documents. Should still report 4 matching updates. 
+ info, err := coll.UpdateAll(M{"k": M{"$gt": 42}}, M{"$unset": M{"missing": 1}}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 4) + + info, err = coll.UpdateAll(M{"k": M{"$gt": 42}}, M{"$inc": M{"n": 1}}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 4) + + result := make(M) + err = coll.Find(M{"k": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 42) + + err = coll.Find(M{"k": 43}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 44) + + err = coll.Find(M{"k": 44}).One(result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 45) + + if !s.versionAtLeast(2, 6) { + // 2.6 made this invalid. + info, err = coll.UpdateAll(M{"k": 47}, M{"k": 47, "n": 47}) + c.Assert(err, Equals, nil) + c.Assert(info.Updated, Equals, 0) + } +} + +func (s *S) TestRemove(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + err = coll.Remove(M{"n": M{"$gt": 42}}) + c.Assert(err, IsNil) + + result := &struct{ N int }{} + err = coll.Find(M{"n": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 42) + + err = coll.Find(M{"n": 43}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"n": 44}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 44) +} + +func (s *S) TestRemoveId(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"_id": 40}, M{"_id": 41}, M{"_id": 42}) + c.Assert(err, IsNil) + + err = coll.RemoveId(41) + c.Assert(err, IsNil) + + c.Assert(coll.FindId(40).One(nil), IsNil) + c.Assert(coll.FindId(41).One(nil), Equals, mgo.ErrNotFound) + c.Assert(coll.FindId(42).One(nil), IsNil) +} + +func (s *S) TestRemoveUnsafe(c *C) { + session, 
err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + session.SetSafe(nil) + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"_id": 40}, M{"_id": 41}, M{"_id": 42}) + c.Assert(err, IsNil) + + err = coll.RemoveId(41) + c.Assert(err, IsNil) + + c.Assert(coll.FindId(40).One(nil), IsNil) + c.Assert(coll.FindId(41).One(nil), Equals, mgo.ErrNotFound) + c.Assert(coll.FindId(42).One(nil), IsNil) +} + +func (s *S) TestRemoveAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + info, err := coll.RemoveAll(M{"n": M{"$gt": 42}}) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 4) + c.Assert(info.UpsertedId, IsNil) + + result := &struct{ N int }{} + err = coll.Find(M{"n": 42}).One(result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 42) + + err = coll.Find(M{"n": 43}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.Find(M{"n": 44}).One(result) + c.Assert(err, Equals, mgo.ErrNotFound) + + info, err = coll.RemoveAll(nil) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 3) + c.Assert(info.UpsertedId, IsNil) + + n, err := coll.Find(nil).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 0) +} + +func (s *S) TestDropDatabase(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db1 := session.DB("db1") + db1.C("col").Insert(M{"_id": 1}) + + db2 := session.DB("db2") + db2.C("col").Insert(M{"_id": 1}) + + err = db1.DropDatabase() + c.Assert(err, IsNil) + + names, err := session.DatabaseNames() + c.Assert(err, IsNil) + if !reflect.DeepEqual(names, []string{"db2"}) { + // 2.4+ has "local" as well. 
+ c.Assert(names, DeepEquals, []string{"db2", "local"}) + } + + err = db2.DropDatabase() + c.Assert(err, IsNil) + + names, err = session.DatabaseNames() + c.Assert(err, IsNil) + if !reflect.DeepEqual(names, []string(nil)) { + // 2.4+ has "local" as well. + c.Assert(names, DeepEquals, []string{"local"}) + } +} + +func (s *S) TestDropCollection(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("db1") + db.C("col1").Insert(M{"_id": 1}) + db.C("col2").Insert(M{"_id": 1}) + + err = db.C("col1").DropCollection() + c.Assert(err, IsNil) + + names, err := db.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"col2", "system.indexes"}) + + err = db.C("col2").DropCollection() + c.Assert(err, IsNil) + + names, err = db.CollectionNames() + c.Assert(err, IsNil) + c.Assert(names, DeepEquals, []string{"system.indexes"}) +} + +func (s *S) TestCreateCollectionCapped(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + Capped: true, + MaxBytes: 1024, + MaxDocs: 3, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + ns := []int{1, 2, 3, 4, 5} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(nil).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) +} + +func (s *S) TestCreateCollectionNoIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + DisableIdIndex: true, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(indexes, HasLen, 0) +} + +func (s *S) TestCreateCollectionForceIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer 
session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + ForceIdIndex: true, + Capped: true, + MaxBytes: 1024, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(indexes, HasLen, 1) +} + +func (s *S) TestIsDupValues(c *C) { + c.Assert(mgo.IsDup(nil), Equals, false) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 1}), Equals, false) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 1}), Equals, false) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 11000}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 11000}), Equals, true) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 11001}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 11001}), Equals, true) + c.Assert(mgo.IsDup(&mgo.LastError{Code: 12582}), Equals, true) + c.Assert(mgo.IsDup(&mgo.QueryError{Code: 12582}), Equals, true) + lerr := &mgo.LastError{Code: 16460, Err: "error inserting 1 documents to shard ... 
caused by :: E11000 duplicate key error index: ..."} + c.Assert(mgo.IsDup(lerr), Equals, true) +} + +func (s *S) TestIsDupPrimary(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 1}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupUnique(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + index := mgo.Index{ + Key: []string{"a", "b"}, + Unique: true, + } + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(index) + c.Assert(err, IsNil) + + err = coll.Insert(M{"a": 1, "b": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"a": 1, "b": 1}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupCapped(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + info := &mgo.CollectionInfo{ + ForceIdIndex: true, + Capped: true, + MaxBytes: 1024, + } + err = coll.Create(info) + c.Assert(err, IsNil) + + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 1}) + // The error was different for capped collections before 2.6. + c.Assert(err, ErrorMatches, ".*duplicate key.*") + // The issue is reduced by using IsDup. 
+ c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestIsDupFindAndModify(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(mgo.Index{Key: []string{"n"}, Unique: true}) + c.Assert(err, IsNil) + + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + err = coll.Insert(M{"n": 2}) + c.Assert(err, IsNil) + _, err = coll.Find(M{"n": 1}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, bson.M{}) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") + c.Assert(mgo.IsDup(err), Equals, true) +} + +func (s *S) TestFindAndModify(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"n": 42}) + + session.SetMode(mgo.Monotonic, true) + + result := M{} + info, err := coll.Find(M{"n": 42}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 42) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + // A nil result parameter should be acceptable. 
+ info, err = coll.Find(M{"n": 43}).Apply(mgo.Change{Update: M{"$unset": M{"missing": 1}}}, nil) + c.Assert(err, IsNil) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 43}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}, ReturnNew: true}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 44) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 50}).Apply(mgo.Change{Upsert: true, Update: M{"n": 51, "o": 52}}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], IsNil) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, NotNil) + + result = M{} + info, err = coll.Find(nil).Sort("-n").Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}, ReturnNew: true}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], Equals, 52) + c.Assert(info.Updated, Equals, 1) + c.Assert(info.Removed, Equals, 0) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 52}).Select(M{"o": 1}).Apply(mgo.Change{Remove: true}, result) + c.Assert(err, IsNil) + c.Assert(result["n"], IsNil) + c.Assert(result["o"], Equals, 52) + c.Assert(info.Updated, Equals, 0) + c.Assert(info.Removed, Equals, 1) + c.Assert(info.UpsertedId, IsNil) + + result = M{} + info, err = coll.Find(M{"n": 60}).Apply(mgo.Change{Remove: true}, result) + c.Assert(err, Equals, mgo.ErrNotFound) + c.Assert(len(result), Equals, 0) + c.Assert(info, IsNil) +} + +func (s *S) TestFindAndModifyBug997828(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"n": "not-a-number"}) + + result := make(M) + _, err = coll.Find(M{"n": "not-a-number"}).Apply(mgo.Change{Update: M{"$inc": M{"n": 1}}}, result) + c.Assert(err, ErrorMatches, 
`(exception: )?Cannot apply \$inc .*`) + if s.versionAtLeast(2, 1) { + qerr, _ := err.(*mgo.QueryError) + c.Assert(qerr, NotNil, Commentf("err: %#v", err)) + if s.versionAtLeast(2, 6) { + // Oh, the dance of error codes. :-( + c.Assert(qerr.Code, Equals, 16837) + } else { + c.Assert(qerr.Code, Equals, 10140) + } + } else { + lerr, _ := err.(*mgo.LastError) + c.Assert(lerr, NotNil, Commentf("err: %#v", err)) + c.Assert(lerr.Code, Equals, 10140) + } +} + +func (s *S) TestCountCollection(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 3) +} + +func (s *S) TestCountQuery(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(M{"n": M{"$gt": 40}}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 2) +} + +func (s *S) TestCountQuerySorted(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(M{"n": M{"$gt": 40}}).Sort("n").Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 2) +} + +func (s *S) TestCountSkipLimit(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + n, err := coll.Find(nil).Skip(1).Limit(3).Count() + c.Assert(err, IsNil) + 
c.Assert(n, Equals, 3) + + n, err = coll.Find(nil).Skip(1).Limit(5).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 4) +} + +func (s *S) TestQueryExplain(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + m := M{} + query := coll.Find(nil).Limit(2) + err = query.Explain(m) + c.Assert(err, IsNil) + if m["queryPlanner"] != nil { + c.Assert(m["executionStats"].(M)["totalDocsExamined"], Equals, 2) + } else { + c.Assert(m["cursor"], Equals, "BasicCursor") + c.Assert(m["nscanned"], Equals, 2) + c.Assert(m["n"], Equals, 2) + } + + n := 0 + var result M + iter := query.Iter() + for iter.Next(&result) { + n++ + } + c.Assert(iter.Close(), IsNil) + c.Assert(n, Equals, 2) +} + +func (s *S) TestQuerySetMaxScan(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + query := coll.Find(nil).SetMaxScan(2) + var result []M + err = query.All(&result) + c.Assert(err, IsNil) + c.Assert(result, HasLen, 2) +} + +func (s *S) TestQuerySetMaxTime(c *C) { + if !s.versionAtLeast(2, 6) { + c.Skip("SetMaxTime only supported in 2.6+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 1000; i++ { + err := coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + query := coll.Find(nil) + query.SetMaxTime(1 * time.Millisecond) + query.Batch(2) + var result []M + err = query.All(&result) + c.Assert(err, ErrorMatches, "operation exceeded time limit") +} + +func (s *S) TestQueryHint(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + 
coll := session.DB("mydb").C("mycoll") + coll.EnsureIndexKey("a") + + m := M{} + err = coll.Find(nil).Hint("a").Explain(m) + c.Assert(err, IsNil) + + if m["queryPlanner"] != nil { + m = m["queryPlanner"].(M) + m = m["winningPlan"].(M) + m = m["inputStage"].(M) + c.Assert(m["indexName"], Equals, "a_1") + } else { + c.Assert(m["indexBounds"], NotNil) + c.Assert(m["indexBounds"].(M)["a"], NotNil) + } +} + +func (s *S) TestQueryComment(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + db := session.DB("mydb") + coll := db.C("mycoll") + + err = db.Run(bson.M{"profile": 2}, nil) + c.Assert(err, IsNil) + + ns := []int{40, 41, 42} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + query := coll.Find(bson.M{"n": 41}) + query.Comment("some comment") + err = query.One(nil) + c.Assert(err, IsNil) + + query = coll.Find(bson.M{"n": 41}) + query.Comment("another comment") + err = query.One(nil) + c.Assert(err, IsNil) + + n, err := session.DB("mydb").C("system.profile").Find(bson.M{"query.$query.n": 41, "query.$comment": "some comment"}).Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, 1) +} + +func (s *S) TestFindOneNotFound(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + result := struct{ A, B int }{} + err = coll.Find(M{"a": 1}).One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) + c.Assert(err, ErrorMatches, "not found") + c.Assert(err == mgo.ErrNotFound, Equals, true) +} + +func (s *S) TestFindNil(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 1}) + c.Assert(err, IsNil) + + result := struct{ N int }{} + + err = coll.Find(nil).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 1) +} + +func (s *S) TestFindId(c *C) { + session, err := 
mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"_id": 41, "n": 41}) + c.Assert(err, IsNil) + err = coll.Insert(M{"_id": 42, "n": 42}) + c.Assert(err, IsNil) + + result := struct{ N int }{} + + err = coll.FindId(42).One(&result) + c.Assert(err, IsNil) + c.Assert(result.N, Equals, 42) +} + +func (s *S) TestFindIterAll(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + iter := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2).Iter() + result := struct{ N int }{} + for i := 2; i < 7; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 1 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and their REPLY_OPs. 
+ c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindIterTwiceWithSameQuery(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 40; i != 47; i++ { + coll.Insert(M{"n": i}) + } + + query := coll.Find(M{}).Sort("n") + + result1 := query.Skip(1).Iter() + result2 := query.Skip(2).Iter() + + result := struct{ N int }{} + ok := result2.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, 42) + ok = result1.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, 41) +} + +func (s *S) TestFindIterWithoutResults(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 42}) + + iter := coll.Find(M{"n": 0}).Iter() + + result := struct{ N int }{} + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + c.Assert(result.N, Equals, 0) +} + +func (s *S) TestFindIterLimit(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Limit(3) + iter := query.Iter() + + result := struct{ N int }{} + for i := 2; i < 5; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. 
+ + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 2) // 1*QUERY_OP + 1*KILL_CURSORS_OP + c.Assert(stats.ReceivedOps, Equals, 1) // and its REPLY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +var cursorTimeout = flag.Bool("cursor-timeout", false, "Enable cursor timeout test") + +func (s *S) TestFindIterCursorTimeout(c *C) { + if !*cursorTimeout { + c.Skip("-cursor-timeout") + } + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + type Doc struct { + Id int "_id" + } + + coll := session.DB("test").C("test") + coll.Remove(nil) + for i := 0; i < 100; i++ { + err = coll.Insert(Doc{i}) + c.Assert(err, IsNil) + } + + session.SetBatch(1) + iter := coll.Find(nil).Iter() + var doc Doc + if !iter.Next(&doc) { + c.Fatalf("iterator failed to return any documents") + } + + for i := 10; i > 0; i-- { + c.Logf("Sleeping... %d minutes to go...", i) + time.Sleep(1*time.Minute + 2*time.Second) + } + + // Drain any existing documents that were fetched. + if !iter.Next(&doc) { + c.Fatalf("iterator with timed out cursor failed to return previously cached document") + } + if iter.Next(&doc) { + c.Fatalf("timed out cursor returned document") + } + + c.Assert(iter.Err(), Equals, mgo.ErrCursor) +} + +func (s *S) TestTooManyItemsLimitBug(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + defer runtime.GOMAXPROCS(runtime.GOMAXPROCS(runtime.NumCPU())) + + mgo.SetDebug(false) + coll := session.DB("mydb").C("mycoll") + words := strings.Split("foo bar baz", " ") + for i := 0; i < 5; i++ { + words = append(words, words...) 
+ } + doc := bson.D{{"words", words}} + inserts := 10000 + limit := 5000 + iters := 0 + c.Assert(inserts > limit, Equals, true) + for i := 0; i < inserts; i++ { + err := coll.Insert(&doc) + c.Assert(err, IsNil) + } + iter := coll.Find(nil).Limit(limit).Iter() + for iter.Next(&doc) { + if iters%100 == 0 { + c.Logf("Seen %d docments", iters) + } + iters++ + } + c.Assert(iter.Close(), IsNil) + c.Assert(iters, Equals, limit) +} + +func serverCursorsOpen(session *mgo.Session) int { + var result struct { + Cursors struct { + TotalOpen int `bson:"totalOpen"` + TimedOut int `bson:"timedOut"` + } + } + err := session.Run("serverStatus", &result) + if err != nil { + panic(err) + } + return result.Cursors.TotalOpen +} + +func (s *S) TestFindIterLimitWithMore(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + // Insane amounts of logging otherwise due to the + // amount of data being shuffled. + mgo.SetDebug(false) + defer mgo.SetDebug(true) + + // Should amount to more than 4MB bson payload, + // the default limit per result chunk. + const total = 4096 + var d struct{ A [1024]byte } + docs := make([]interface{}, total) + for i := 0; i < total; i++ { + docs[i] = &d + } + err = coll.Insert(docs...) + c.Assert(err, IsNil) + + n, err := coll.Count() + c.Assert(err, IsNil) + c.Assert(n, Equals, total) + + // First, try restricting to a single chunk with a negative limit. + nresults := 0 + iter := coll.Find(nil).Limit(-total).Iter() + var discard struct{} + for iter.Next(&discard) { + nresults++ + } + if nresults < total/2 || nresults >= total { + c.Fatalf("Bad result size with negative limit: %d", nresults) + } + + cursorsOpen := serverCursorsOpen(session) + + // Try again, with a positive limit. Should reach the end now, + // using multiple chunks. 
+ nresults = 0 + iter = coll.Find(nil).Limit(total).Iter() + for iter.Next(&discard) { + nresults++ + } + c.Assert(nresults, Equals, total) + + // Ensure the cursor used is properly killed. + c.Assert(serverCursorsOpen(session), Equals, cursorsOpen) + + // Edge case, -MinInt == -MinInt. + nresults = 0 + iter = coll.Find(nil).Limit(math.MinInt32).Iter() + for iter.Next(&discard) { + nresults++ + } + if nresults < total/2 || nresults >= total { + c.Fatalf("Bad result size with MinInt32 limit: %d", nresults) + } +} + +func (s *S) TestFindIterLimitWithBatch(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + // Ping the database to ensure the nonce has been received already. + c.Assert(session.Ping(), IsNil) + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Limit(3).Batch(2) + iter := query.Iter() + result := struct{ N int }{} + for i := 2; i < 5; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. 
+ + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 1*GET_MORE_OP + 1*KILL_CURSORS_OP + c.Assert(stats.ReceivedOps, Equals, 2) // and its REPLY_OPs + c.Assert(stats.ReceivedDocs, Equals, 3) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindIterSortWithBatch(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + // Without this, the logic above breaks because Mongo refuses to + // return a cursor with an in-memory sort. + coll.EnsureIndexKey("n") + + // Ping the database to ensure the nonce has been received already. + c.Assert(session.Ping(), IsNil) + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$lte": 44}}).Sort("-n").Batch(2) + iter := query.Iter() + ns = []int{46, 45, 44, 43, 42, 41, 40} + result := struct{ N int }{} + for i := 2; i < len(ns); i++ { + c.Logf("i=%d", i) + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Close(), IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and its REPLY_OPs + c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +// Test tailable cursors in a situation where Next has to sleep to +// respect the timeout requested on Tail. 
+func (s *S) TestFindTailTimeoutWithSleep(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cresult := struct{ ErrMsg string }{} + + db := session.DB("mydb") + err = db.Run(bson.D{{"create", "mycoll"}, {"capped", true}, {"size", 1024}}, &cresult) + c.Assert(err, IsNil) + c.Assert(cresult.ErrMsg, Equals, "") + coll := db.C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + timeout := 3 * time.Second + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Tail(timeout) + + n := len(ns) + result := struct{ N int }{} + for i := 2; i != n; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { // The batch boundary. + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + mgo.ResetStats() + + // The following call to Next will block. + go func() { + // The internal AwaitData timing of MongoDB is around 2 seconds, + // so this should force mgo to sleep at least once by itself to + // respect the requested timeout. + time.Sleep(timeout + 5e8*time.Nanosecond) + session := session.New() + defer session.Close() + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 47}) + }() + + c.Log("Will wait for Next with N=47...") + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 47) + c.Log("Got Next with N=47!") + + // The following may break because it depends a bit on the internal + // timing used by MongoDB's AwaitData logic. 
If it does, the problem + // will be observed as more GET_MORE_OPs than predicted: + // 1*QUERY for nonce + 1*GET_MORE_OP on Next + 1*GET_MORE_OP on Next after sleep + + // 1*INSERT_OP + 1*QUERY_OP for getLastError on insert of 47 + stats := mgo.GetStats() + if s.versionAtLeast(3, 0) { // TODO Will be 2.6 when write commands are enabled for it. + c.Assert(stats.SentOps, Equals, 4) + } else { + c.Assert(stats.SentOps, Equals, 5) + } + c.Assert(stats.ReceivedOps, Equals, 4) // REPLY_OPs for 1*QUERY_OP for nonce + 2*GET_MORE_OPs + 1*QUERY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) // nonce + N=47 result + getLastError response + + c.Log("Will wait for a result which will never come...") + + started := time.Now() + ok = iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, true) + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + + c.Log("Will now reuse the timed out tail cursor...") + + coll.Insert(M{"n": 48}) + ok = iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Close(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 48) +} + +// Test tailable cursors in a situation where Next never gets to sleep once +// to respect the timeout requested on Tail. +func (s *S) TestFindTailTimeoutNoSleep(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cresult := struct{ ErrMsg string }{} + + db := session.DB("mydb") + err = db.Run(bson.D{{"create", "mycoll"}, {"capped", true}, {"size", 1024}}, &cresult) + c.Assert(err, IsNil) + c.Assert(cresult.ErrMsg, Equals, "") + coll := db.C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. 
+ + mgo.ResetStats() + + timeout := 1 * time.Second + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Tail(timeout) + + n := len(ns) + result := struct{ N int }{} + for i := 2; i != n; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { // The batch boundary. + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + mgo.ResetStats() + + // The following call to Next will block. + go func() { + // The internal AwaitData timing of MongoDB is around 2 seconds, + // so this item should arrive within the AwaitData threshold. + time.Sleep(5e8) + session := session.New() + defer session.Close() + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 47}) + }() + + c.Log("Will wait for Next with N=47...") + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 47) + c.Log("Got Next with N=47!") + + // The following may break because it depends a bit on the internal + // timing used by MongoDB's AwaitData logic. If it does, the problem + // will be observed as more GET_MORE_OPs than predicted: + // 1*QUERY_OP for nonce + 1*GET_MORE_OP on Next + + // 1*INSERT_OP + 1*QUERY_OP for getLastError on insert of 47 + stats := mgo.GetStats() + if s.versionAtLeast(3, 0) { // TODO Will be 2.6 when write commands are enabled for it. 
+ c.Assert(stats.SentOps, Equals, 3) + } else { + c.Assert(stats.SentOps, Equals, 4) + } + c.Assert(stats.ReceivedOps, Equals, 3) // REPLY_OPs for 1*QUERY_OP for nonce + 1*GET_MORE_OPs and 1*QUERY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) // nonce + N=47 result + getLastError response + + c.Log("Will wait for a result which will never come...") + + started := time.Now() + ok = iter.Next(&result) + c.Assert(ok, Equals, false) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, true) + c.Assert(started.Before(time.Now().Add(-timeout)), Equals, true) + + c.Log("Will now reuse the timed out tail cursor...") + + coll.Insert(M{"n": 48}) + ok = iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Close(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 48) +} + +// Test tailable cursors in a situation where Next never gets to sleep once +// to respect the timeout requested on Tail. +func (s *S) TestFindTailNoTimeout(c *C) { + if *fast { + c.Skip("-fast") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cresult := struct{ ErrMsg string }{} + + db := session.DB("mydb") + err = db.Run(bson.D{{"create", "mycoll"}, {"capped", true}, {"size", 1024}}, &cresult) + c.Assert(err, IsNil) + c.Assert(cresult.ErrMsg, Equals, "") + coll := db.C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Tail(-1) + c.Assert(err, IsNil) + + n := len(ns) + result := struct{ N int }{} + for i := 2; i != n; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { // The batch boundary. 
+ stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + } + + mgo.ResetStats() + + // The following call to Next will block. + go func() { + time.Sleep(5e8) + session := session.New() + defer session.Close() + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"n": 47}) + }() + + c.Log("Will wait for Next with N=47...") + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(iter.Err(), IsNil) + c.Assert(iter.Timeout(), Equals, false) + c.Assert(result.N, Equals, 47) + c.Log("Got Next with N=47!") + + // The following may break because it depends a bit on the internal + // timing used by MongoDB's AwaitData logic. If it does, the problem + // will be observed as more GET_MORE_OPs than predicted: + // 1*QUERY_OP for nonce + 1*GET_MORE_OP on Next + + // 1*INSERT_OP + 1*QUERY_OP for getLastError on insert of 47 + stats := mgo.GetStats() + if s.versionAtLeast(3, 0) { // TODO Will be 2.6 when write commands are enabled for it. + c.Assert(stats.SentOps, Equals, 3) + } else { + c.Assert(stats.SentOps, Equals, 4) + } + c.Assert(stats.ReceivedOps, Equals, 3) // REPLY_OPs for 1*QUERY_OP for nonce + 1*GET_MORE_OPs and 1*QUERY_OP + c.Assert(stats.ReceivedDocs, Equals, 3) // nonce + N=47 result + getLastError response + + c.Log("Will wait for a result which will never come...") + + gotNext := make(chan bool) + go func() { + ok := iter.Next(&result) + gotNext <- ok + }() + + select { + case ok := <-gotNext: + c.Fatalf("Next returned: %v", ok) + case <-time.After(3e9): + // Good. Should still be sleeping at that point. + } + + // Closing the session should cause Next to return. 
+ session.Close() + + select { + case ok := <-gotNext: + c.Assert(ok, Equals, false) + c.Assert(iter.Err(), ErrorMatches, "Closed explicitly") + c.Assert(iter.Timeout(), Equals, false) + case <-time.After(1e9): + c.Fatal("Closing the session did not unblock Next") + } +} + +func (s *S) TestIterNextResetsResult(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{1, 2, 3} + for _, n := range ns { + coll.Insert(M{"n" + strconv.Itoa(n): n}) + } + + query := coll.Find(nil).Sort("$natural") + + i := 0 + var sresult *struct{ N1, N2, N3 int } + iter := query.Iter() + for iter.Next(&sresult) { + switch i { + case 0: + c.Assert(sresult.N1, Equals, 1) + c.Assert(sresult.N2+sresult.N3, Equals, 0) + case 1: + c.Assert(sresult.N2, Equals, 2) + c.Assert(sresult.N1+sresult.N3, Equals, 0) + case 2: + c.Assert(sresult.N3, Equals, 3) + c.Assert(sresult.N1+sresult.N2, Equals, 0) + } + i++ + } + c.Assert(iter.Close(), IsNil) + + i = 0 + var mresult M + iter = query.Iter() + for iter.Next(&mresult) { + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, M{"n3": 3}) + } + i++ + } + c.Assert(iter.Close(), IsNil) + + i = 0 + var iresult interface{} + iter = query.Iter() + for iter.Next(&iresult) { + mresult, ok := iresult.(bson.M) + c.Assert(ok, Equals, true, Commentf("%#v", iresult)) + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, bson.M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, bson.M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, bson.M{"n3": 3}) + } + i++ + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestFindForOnIter(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 
45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + iter := query.Iter() + + i := 2 + var result *struct{ N int } + err = iter.For(&result, func() error { + c.Assert(i < 7, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 1 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and their REPLY_OPs. + c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindFor(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + session.Refresh() // Release socket. + + mgo.ResetStats() + + query := coll.Find(M{"n": M{"$gte": 42}}).Sort("$natural").Prefetch(0).Batch(2) + + i := 2 + var result *struct{ N int } + err = query.For(&result, func() error { + c.Assert(i < 7, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 1 { + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, 2) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + session.Refresh() // Release socket. + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 3) // 1*QUERY_OP + 2*GET_MORE_OP + c.Assert(stats.ReceivedOps, Equals, 3) // and their REPLY_OPs. 
+ c.Assert(stats.ReceivedDocs, Equals, 5) + c.Assert(stats.SocketsInUse, Equals, 0) +} + +func (s *S) TestFindForStopOnError(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + query := coll.Find(M{"n": M{"$gte": 42}}) + i := 2 + var result *struct{ N int } + err = query.For(&result, func() error { + c.Assert(i < 4, Equals, true) + c.Assert(result.N, Equals, ns[i]) + if i == 3 { + return fmt.Errorf("stop!") + } + i++ + return nil + }) + c.Assert(err, ErrorMatches, "stop!") +} + +func (s *S) TestFindForResetsResult(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{1, 2, 3} + for _, n := range ns { + coll.Insert(M{"n" + strconv.Itoa(n): n}) + } + + query := coll.Find(nil).Sort("$natural") + + i := 0 + var sresult *struct{ N1, N2, N3 int } + err = query.For(&sresult, func() error { + switch i { + case 0: + c.Assert(sresult.N1, Equals, 1) + c.Assert(sresult.N2+sresult.N3, Equals, 0) + case 1: + c.Assert(sresult.N2, Equals, 2) + c.Assert(sresult.N1+sresult.N3, Equals, 0) + case 2: + c.Assert(sresult.N3, Equals, 3) + c.Assert(sresult.N1+sresult.N2, Equals, 0) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + i = 0 + var mresult M + err = query.For(&mresult, func() error { + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, M{"n3": 3}) + } + i++ + return nil + }) + c.Assert(err, IsNil) + + i = 0 + var iresult interface{} + err = query.For(&iresult, func() error { + mresult, ok := iresult.(bson.M) + c.Assert(ok, Equals, true, Commentf("%#v", iresult)) + delete(mresult, "_id") + switch i { + case 0: + c.Assert(mresult, DeepEquals, 
bson.M{"n1": 1}) + case 1: + c.Assert(mresult, DeepEquals, bson.M{"n2": 2}) + case 2: + c.Assert(mresult, DeepEquals, bson.M{"n3": 3}) + } + i++ + return nil + }) + c.Assert(err, IsNil) +} + +func (s *S) TestFindIterSnapshot(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Insane amounts of logging otherwise due to the + // amount of data being shuffled. + mgo.SetDebug(false) + defer mgo.SetDebug(true) + + coll := session.DB("mydb").C("mycoll") + + var a [1024000]byte + + for n := 0; n < 10; n++ { + err := coll.Insert(M{"_id": n, "n": n, "a1": &a}) + c.Assert(err, IsNil) + } + + query := coll.Find(M{"n": M{"$gt": -1}}).Batch(2).Prefetch(0) + query.Snapshot() + iter := query.Iter() + + seen := map[int]bool{} + result := struct { + Id int "_id" + }{} + for iter.Next(&result) { + if len(seen) == 2 { + // Grow all entries so that they have to move. + // Backwards so that the order is inverted. + for n := 10; n >= 0; n-- { + _, err := coll.Upsert(M{"_id": n}, M{"$set": M{"a2": &a}}) + c.Assert(err, IsNil) + } + } + if seen[result.Id] { + c.Fatalf("seen duplicated key: %d", result.Id) + } + seen[result.Id] = true + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestSort(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + coll.Insert(M{"a": 1, "b": 1}) + coll.Insert(M{"a": 2, "b": 2}) + coll.Insert(M{"a": 2, "b": 1}) + coll.Insert(M{"a": 0, "b": 1}) + coll.Insert(M{"a": 2, "b": 0}) + coll.Insert(M{"a": 0, "b": 2}) + coll.Insert(M{"a": 1, "b": 2}) + coll.Insert(M{"a": 0, "b": 0}) + coll.Insert(M{"a": 1, "b": 0}) + + query := coll.Find(M{}) + query.Sort("-a") // Should be ignored. 
+ query.Sort("-b", "a") + iter := query.Iter() + + l := make([]int, 18) + r := struct{ A, B int }{} + for i := 0; i != len(l); i += 2 { + ok := iter.Next(&r) + c.Assert(ok, Equals, true) + c.Assert(err, IsNil) + l[i] = r.A + l[i+1] = r.B + } + + c.Assert(l, DeepEquals, []int{0, 2, 1, 2, 2, 2, 0, 1, 1, 1, 2, 1, 0, 0, 1, 0, 2, 0}) +} + +func (s *S) TestSortWithBadArgs(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + f1 := func() { coll.Find(nil).Sort("") } + f2 := func() { coll.Find(nil).Sort("+") } + f3 := func() { coll.Find(nil).Sort("foo", "-") } + + for _, f := range []func(){f1, f2, f3} { + c.Assert(f, PanicMatches, "Sort: empty field name") + } +} + +func (s *S) TestSortScoreText(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(mgo.Index{ + Key: []string{"$text:a", "$text:b"}, + }) + c.Assert(err, IsNil) + + err = coll.Insert(M{ + "a": "none", + "b": "twice: foo foo", + }) + c.Assert(err, IsNil) + err = coll.Insert(M{ + "a": "just once: foo", + "b": "none", + }) + c.Assert(err, IsNil) + err = coll.Insert(M{ + "a": "many: foo foo foo", + "b": "none", + }) + c.Assert(err, IsNil) + err = coll.Insert(M{ + "a": "none", + "b": "none", + "c": "ignore: foo", + }) + c.Assert(err, IsNil) + + query := coll.Find(M{"$text": M{"$search": "foo"}}) + query.Select(M{"score": M{"$meta": "textScore"}}) + query.Sort("$textScore:score") + iter := query.Iter() + + var r struct{ A, B string } + var results []string + for iter.Next(&r) { + results = append(results, r.A, r.B) + } + + c.Assert(results, DeepEquals, []string{ + "many: foo foo foo", "none", + "none", "twice: foo foo", + "just once: foo", "none", + }) +} + +func (s *S) TestPrefetching(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := 
session.DB("mydb").C("mycoll") + + const total = 600 + mgo.SetDebug(false) + docs := make([]interface{}, total) + for i := 0; i != total; i++ { + docs[i] = bson.D{{"n", i}} + } + err = coll.Insert(docs...) + c.Assert(err, IsNil) + + for testi := 0; testi < 5; testi++ { + mgo.ResetStats() + + var iter *mgo.Iter + var beforeMore int + + switch testi { + case 0: // The default session value. + session.SetBatch(100) + iter = coll.Find(M{}).Iter() + beforeMore = 75 + + case 2: // Changing the session value. + session.SetBatch(100) + session.SetPrefetch(0.27) + iter = coll.Find(M{}).Iter() + beforeMore = 73 + + case 1: // Changing via query methods. + iter = coll.Find(M{}).Prefetch(0.27).Batch(100).Iter() + beforeMore = 73 + + case 3: // With prefetch on first document. + iter = coll.Find(M{}).Prefetch(1.0).Batch(100).Iter() + beforeMore = 0 + + case 4: // Without prefetch. + iter = coll.Find(M{}).Prefetch(0).Batch(100).Iter() + beforeMore = 100 + } + + pings := 0 + for batchi := 0; batchi < len(docs)/100-1; batchi++ { + c.Logf("Iterating over %d documents on batch %d", beforeMore, batchi) + var result struct{ N int } + for i := 0; i < beforeMore; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true, Commentf("iter.Err: %v", iter.Err())) + } + beforeMore = 99 + c.Logf("Done iterating.") + + session.Run("ping", nil) // Roundtrip to settle down. + pings++ + + stats := mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, (batchi+1)*100+pings) + + c.Logf("Iterating over one more document on batch %d", batchi) + ok := iter.Next(&result) + c.Assert(ok, Equals, true, Commentf("iter.Err: %v", iter.Err())) + c.Logf("Done iterating.") + + session.Run("ping", nil) // Roundtrip to settle down. 
+ pings++ + + stats = mgo.GetStats() + c.Assert(stats.ReceivedDocs, Equals, (batchi+2)*100+pings) + } + } +} + +func (s *S) TestSafeSetting(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Check the default + safe := session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 0) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, false) + + // Tweak it + session.SetSafe(&mgo.Safe{W: 1, WTimeout: 2, FSync: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 1) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // Reset it again. + session.SetSafe(&mgo.Safe{}) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 0) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, false) + + // Ensure safety to something more conservative. + session.SetSafe(&mgo.Safe{W: 5, WTimeout: 6, J: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 5) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 6) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, true) + + // Ensure safety to something less conservative won't change it. + session.EnsureSafe(&mgo.Safe{W: 4, WTimeout: 7}) + safe = session.Safe() + c.Assert(safe.W, Equals, 5) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 6) + c.Assert(safe.FSync, Equals, false) + c.Assert(safe.J, Equals, true) + + // But to something more conservative will. + session.EnsureSafe(&mgo.Safe{W: 6, WTimeout: 4, FSync: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 6) + c.Assert(safe.WMode, Equals, "") + c.Assert(safe.WTimeout, Equals, 4) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // Even more conservative. 
+ session.EnsureSafe(&mgo.Safe{WMode: "majority", WTimeout: 2}) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "majority") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // WMode always overrides, whatever it is, but J doesn't. + session.EnsureSafe(&mgo.Safe{WMode: "something", J: true}) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "something") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // EnsureSafe with nil does nothing. + session.EnsureSafe(nil) + safe = session.Safe() + c.Assert(safe.W, Equals, 0) + c.Assert(safe.WMode, Equals, "something") + c.Assert(safe.WTimeout, Equals, 2) + c.Assert(safe.FSync, Equals, true) + c.Assert(safe.J, Equals, false) + + // Changing the safety of a cloned session doesn't touch the original. + clone := session.Clone() + defer clone.Close() + clone.EnsureSafe(&mgo.Safe{WMode: "foo"}) + safe = session.Safe() + c.Assert(safe.WMode, Equals, "something") +} + +func (s *S) TestSafeInsert(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + // Insert an element with a predefined key. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + mgo.ResetStats() + + // Session should be safe by default, so inserting it again must fail. + err = coll.Insert(M{"_id": 1}) + c.Assert(err, ErrorMatches, ".*E11000 duplicate.*") + c.Assert(err.(*mgo.LastError).Code, Equals, 11000) + + // It must have sent two operations (INSERT_OP + getLastError QUERY_OP) + stats := mgo.GetStats() + + // TODO Will be 2.6 when write commands are enabled for it. + if s.versionAtLeast(3, 0) { + c.Assert(stats.SentOps, Equals, 1) + } else { + c.Assert(stats.SentOps, Equals, 2) + } + + mgo.ResetStats() + + // If we disable safety, though, it won't complain. 
+ session.SetSafe(nil) + err = coll.Insert(M{"_id": 1}) + c.Assert(err, IsNil) + + // Must have sent a single operation this time (just the INSERT_OP) + stats = mgo.GetStats() + c.Assert(stats.SentOps, Equals, 1) +} + +func (s *S) TestSafeParameters(c *C) { + session, err := mgo.Dial("localhost:40011") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + // Tweak the safety parameters to something unachievable. + session.SetSafe(&mgo.Safe{W: 4, WTimeout: 100}) + err = coll.Insert(M{"_id": 1}) + c.Assert(err, ErrorMatches, "timeout|timed out waiting for slaves|Not enough data-bearing nodes|waiting for replication timed out") // :-( + if !s.versionAtLeast(2, 6) { + // 2.6 turned it into a query error. + c.Assert(err.(*mgo.LastError).WTimeout, Equals, true) + } +} + +func (s *S) TestQueryErrorOne(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + result := struct { + Err string "$err" + }{} + + err = coll.Find(M{"a": 1}).Select(M{"a": M{"b": 1}}).One(&result) + c.Assert(err, ErrorMatches, ".*Unsupported projection option:.*") + c.Assert(err.(*mgo.QueryError).Message, Matches, ".*Unsupported projection option:.*") + if s.versionAtLeast(2, 6) { + // Oh, the dance of error codes. 
:-( + c.Assert(err.(*mgo.QueryError).Code, Equals, 17287) + } else { + c.Assert(err.(*mgo.QueryError).Code, Equals, 13097) + } + + // The result should be properly unmarshalled with QueryError + c.Assert(result.Err, Matches, ".*Unsupported projection option:.*") +} + +func (s *S) TestQueryErrorNext(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + result := struct { + Err string "$err" + }{} + + iter := coll.Find(M{"a": 1}).Select(M{"a": M{"b": 1}}).Iter() + + ok := iter.Next(&result) + c.Assert(ok, Equals, false) + + err = iter.Close() + c.Assert(err, ErrorMatches, ".*Unsupported projection option:.*") + c.Assert(err.(*mgo.QueryError).Message, Matches, ".*Unsupported projection option:.*") + if s.versionAtLeast(2, 6) { + // Oh, the dance of error codes. :-( + c.Assert(err.(*mgo.QueryError).Code, Equals, 17287) + } else { + c.Assert(err.(*mgo.QueryError).Code, Equals, 13097) + } + c.Assert(iter.Err(), Equals, err) + + // The result should be properly unmarshalled with QueryError + c.Assert(result.Err, Matches, ".*Unsupported projection option:.*") +} + +var indexTests = []struct { + index mgo.Index + expected M +}{{ + mgo.Index{ + Key: []string{"a"}, + Background: true, + }, + M{ + "name": "a_1", + "key": M{"a": 1}, + "ns": "mydb.mycoll", + "background": true, + }, +}, { + mgo.Index{ + Key: []string{"a", "-b"}, + Unique: true, + DropDups: true, + }, + M{ + "name": "a_1_b_-1", + "key": M{"a": 1, "b": -1}, + "ns": "mydb.mycoll", + "unique": true, + "dropDups": true, + }, +}, { + mgo.Index{ + Key: []string{"@loc_old"}, // Obsolete + Min: -500, + Max: 500, + Bits: 32, + }, + M{ + "name": "loc_old_2d", + "key": M{"loc_old": "2d"}, + "ns": "mydb.mycoll", + "min": -500, + "max": 500, + "bits": 32, + }, +}, { + mgo.Index{ + Key: []string{"$2d:loc"}, + Min: -500, + Max: 500, + Bits: 32, + }, + M{ + "name": "loc_2d", + "key": M{"loc": "2d"}, + "ns": "mydb.mycoll", + "min": -500, 
+ "max": 500, + "bits": 32, + }, +}, { + mgo.Index{ + Key: []string{"$geoHaystack:loc", "type"}, + BucketSize: 1, + }, + M{ + "name": "loc_geoHaystack_type_1", + "key": M{"loc": "geoHaystack", "type": 1}, + "ns": "mydb.mycoll", + "bucketSize": 1.0, + }, +}, { + mgo.Index{ + Key: []string{"$text:a", "$text:b"}, + Weights: map[string]int{"b": 42}, + }, + M{ + "name": "a_text_b_text", + "key": M{"_fts": "text", "_ftsx": 1}, + "ns": "mydb.mycoll", + "weights": M{"a": 1, "b": 42}, + "default_language": "english", + "language_override": "language", + "textIndexVersion": 2, + }, +}, { + mgo.Index{ + Key: []string{"$text:a"}, + DefaultLanguage: "portuguese", + LanguageOverride: "idioma", + }, + M{ + "name": "a_text", + "key": M{"_fts": "text", "_ftsx": 1}, + "ns": "mydb.mycoll", + "weights": M{"a": 1}, + "default_language": "portuguese", + "language_override": "idioma", + "textIndexVersion": 2, + }, +}, { + mgo.Index{ + Key: []string{"$text:$**"}, + }, + M{ + "name": "$**_text", + "key": M{"_fts": "text", "_ftsx": 1}, + "ns": "mydb.mycoll", + "weights": M{"$**": 1}, + "default_language": "english", + "language_override": "language", + "textIndexVersion": 2, + }, +}} + +func (s *S) TestEnsureIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + idxs := session.DB("mydb").C("system.indexes") + + for _, test := range indexTests { + if !s.versionAtLeast(2, 4) && test.expected["weights"] != nil { + // No text indexes until 2.4. + continue + } + + err = coll.EnsureIndex(test.index) + c.Assert(err, IsNil) + + obtained := M{} + err = idxs.Find(M{"name": test.expected["name"]}).One(obtained) + c.Assert(err, IsNil) + + delete(obtained, "v") + + if s.versionAtLeast(2, 7) { + // Was deprecated in 2.6, and not being reported by 2.7+. + delete(test.expected, "dropDups") + } + + c.Assert(obtained, DeepEquals, test.expected) + + err = coll.DropIndex(test.index.Key...) 
+ c.Assert(err, IsNil) + } +} + +func (s *S) TestEnsureIndexWithBadInfo(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndex(mgo.Index{}) + c.Assert(err, ErrorMatches, "invalid index key:.*") + + err = coll.EnsureIndex(mgo.Index{Key: []string{""}}) + c.Assert(err, ErrorMatches, "invalid index key:.*") +} + +func (s *S) TestEnsureIndexWithUnsafeSession(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + session.SetSafe(nil) + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + err = coll.Insert(M{"a": 1}) + c.Assert(err, IsNil) + + // Should fail since there are duplicated entries. + index := mgo.Index{ + Key: []string{"a"}, + Unique: true, + } + + err = coll.EnsureIndex(index) + c.Assert(err, ErrorMatches, ".*duplicate key error.*") +} + +func (s *S) TestEnsureIndexKey(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("a", "-b") + c.Assert(err, IsNil) + + sysidx := session.DB("mydb").C("system.indexes") + + result1 := M{} + err = sysidx.Find(M{"name": "a_1"}).One(result1) + c.Assert(err, IsNil) + + result2 := M{} + err = sysidx.Find(M{"name": "a_1_b_-1"}).One(result2) + c.Assert(err, IsNil) + + delete(result1, "v") + expected1 := M{ + "name": "a_1", + "key": M{"a": 1}, + "ns": "mydb.mycoll", + } + c.Assert(result1, DeepEquals, expected1) + + delete(result2, "v") + expected2 := M{ + "name": "a_1_b_-1", + "key": M{"a": 1, "b": -1}, + "ns": "mydb.mycoll", + } + c.Assert(result2, DeepEquals, expected2) +} + +func (s *S) TestEnsureIndexDropIndex(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := 
session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("-b") + c.Assert(err, IsNil) + + err = coll.DropIndex("-b") + c.Assert(err, IsNil) + + sysidx := session.DB("mydb").C("system.indexes") + dummy := &struct{}{} + + err = sysidx.Find(M{"name": "a_1"}).One(dummy) + c.Assert(err, IsNil) + + err = sysidx.Find(M{"name": "b_1"}).One(dummy) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.DropIndex("a") + c.Assert(err, IsNil) + + err = sysidx.Find(M{"name": "a_1"}).One(dummy) + c.Assert(err, Equals, mgo.ErrNotFound) + + err = coll.DropIndex("a") + c.Assert(err, ErrorMatches, "index not found.*") +} + +func (s *S) TestEnsureIndexCaching(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + mgo.ResetStats() + + // Second EnsureIndex should be cached and do nothing. + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + stats := mgo.GetStats() + c.Assert(stats.SentOps, Equals, 0) + + // Resetting the cache should make it contact the server again. + session.ResetIndexCache() + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + stats = mgo.GetStats() + c.Assert(stats.SentOps > 0, Equals, true) + + // Dropping the index should also drop the cached index key. + err = coll.DropIndex("a") + c.Assert(err, IsNil) + + mgo.ResetStats() + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + stats = mgo.GetStats() + c.Assert(stats.SentOps > 0, Equals, true) +} + +func (s *S) TestEnsureIndexGetIndexes(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = coll.EnsureIndexKey("-b") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("a") + c.Assert(err, IsNil) + + // Obsolete. 
+ err = coll.EnsureIndexKey("@c") + c.Assert(err, IsNil) + + err = coll.EnsureIndexKey("$2d:d") + c.Assert(err, IsNil) + + // Try to exercise cursor logic. 2.8.0-rc3 still ignores this. + session.SetBatch(2) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + + c.Assert(indexes[0].Name, Equals, "_id_") + c.Assert(indexes[1].Name, Equals, "a_1") + c.Assert(indexes[1].Key, DeepEquals, []string{"a"}) + c.Assert(indexes[2].Name, Equals, "b_-1") + c.Assert(indexes[2].Key, DeepEquals, []string{"-b"}) + c.Assert(indexes[3].Name, Equals, "c_2d") + c.Assert(indexes[3].Key, DeepEquals, []string{"$2d:c"}) + c.Assert(indexes[4].Name, Equals, "d_2d") + c.Assert(indexes[4].Key, DeepEquals, []string{"$2d:d"}) +} + +func (s *S) TestEnsureIndexEvalGetIndexes(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({b: -1})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({a: 1})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({c: -1, e: 1})"}}, nil) + c.Assert(err, IsNil) + err = session.Run(bson.D{{"eval", "db.getSiblingDB('mydb').mycoll.ensureIndex({d: '2d'})"}}, nil) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + + c.Assert(indexes[0].Name, Equals, "_id_") + c.Assert(indexes[1].Name, Equals, "a_1") + c.Assert(indexes[1].Key, DeepEquals, []string{"a"}) + c.Assert(indexes[2].Name, Equals, "b_-1") + c.Assert(indexes[2].Key, DeepEquals, []string{"-b"}) + c.Assert(indexes[3].Name, Equals, "c_-1_e_1") + c.Assert(indexes[3].Key, DeepEquals, []string{"-c", "e"}) + if s.versionAtLeast(2, 2) { + c.Assert(indexes[4].Name, Equals, "d_2d") + c.Assert(indexes[4].Key, DeepEquals, []string{"$2d:d"}) + } else { + c.Assert(indexes[4].Name, Equals, "d_") + 
c.Assert(indexes[4].Key, DeepEquals, []string{"$2d:d"}) + } +} + +var testTTL = flag.Bool("test-ttl", false, "test TTL collections (may take 1 minute)") + +func (s *S) TestEnsureIndexExpireAfter(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + session.SetSafe(nil) + + coll := session.DB("mydb").C("mycoll") + + err = coll.Insert(M{"n": 1, "t": time.Now().Add(-120 * time.Second)}) + c.Assert(err, IsNil) + err = coll.Insert(M{"n": 2, "t": time.Now()}) + c.Assert(err, IsNil) + + // Should fail since there are duplicated entries. + index := mgo.Index{ + Key: []string{"t"}, + ExpireAfter: 1 * time.Minute, + } + + err = coll.EnsureIndex(index) + c.Assert(err, IsNil) + + indexes, err := coll.Indexes() + c.Assert(err, IsNil) + c.Assert(indexes[1].Name, Equals, "t_1") + c.Assert(indexes[1].ExpireAfter, Equals, 1*time.Minute) + + if *testTTL { + worked := false + stop := time.Now().Add(70 * time.Second) + for time.Now().Before(stop) { + n, err := coll.Count() + c.Assert(err, IsNil) + if n == 1 { + worked = true + break + } + c.Assert(n, Equals, 2) + c.Logf("Still has 2 entries...") + time.Sleep(1 * time.Second) + } + if !worked { + c.Fatalf("TTL index didn't work") + } + } +} + +func (s *S) TestDistinct(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + var result []int + err = coll.Find(M{"n": M{"$gt": 2}}).Sort("n").Distinct("n", &result) + + sort.IntSlice(result).Sort() + c.Assert(result, DeepEquals, []int{3, 4, 6}) +} + +func (s *S) TestMapReduce(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: 
"function(key, values) { return Array.sum(values); }", + } + var result []struct { + Id int "_id" + Value int + } + + info, err := coll.Find(M{"n": M{"$gt": 2}}).MapReduce(job, &result) + c.Assert(err, IsNil) + c.Assert(info.InputCount, Equals, 4) + c.Assert(info.EmitCount, Equals, 4) + c.Assert(info.OutputCount, Equals, 3) + c.Assert(info.VerboseTime, IsNil) + + expected := map[int]int{3: 1, 4: 2, 6: 1} + for _, item := range result { + c.Logf("Item: %#v", &item) + c.Assert(item.Value, Equals, expected[item.Id]) + expected[item.Id] = -1 + } +} + +func (s *S) TestMapReduceFinalize(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1) }", + Reduce: "function(key, values) { return Array.sum(values) }", + Finalize: "function(key, count) { return {count: count} }", + } + var result []struct { + Id int "_id" + Value struct{ Count int } + } + _, err = coll.Find(nil).MapReduce(job, &result) + c.Assert(err, IsNil) + + expected := map[int]int{1: 1, 2: 2, 3: 1, 4: 2, 6: 1} + for _, item := range result { + c.Logf("Item: %#v", &item) + c.Assert(item.Value.Count, Equals, expected[item.Id]) + expected[item.Id] = -1 + } +} + +func (s *S) TestMapReduceToCollection(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Out: "mr", + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.InputCount, Equals, 7) + c.Assert(info.EmitCount, Equals, 7) + c.Assert(info.OutputCount, Equals, 5) + c.Assert(info.Collection, 
Equals, "mr") + c.Assert(info.Database, Equals, "mydb") + + expected := map[int]int{1: 1, 2: 2, 3: 1, 4: 2, 6: 1} + var item *struct { + Id int "_id" + Value int + } + mr := session.DB("mydb").C("mr") + iter := mr.Find(nil).Iter() + for iter.Next(&item) { + c.Logf("Item: %#v", &item) + c.Assert(item.Value, Equals, expected[item.Id]) + expected[item.Id] = -1 + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestMapReduceToOtherDb(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Out: bson.D{{"replace", "mr"}, {"db", "otherdb"}}, + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.InputCount, Equals, 7) + c.Assert(info.EmitCount, Equals, 7) + c.Assert(info.OutputCount, Equals, 5) + c.Assert(info.Collection, Equals, "mr") + c.Assert(info.Database, Equals, "otherdb") + + expected := map[int]int{1: 1, 2: 2, 3: 1, 4: 2, 6: 1} + var item *struct { + Id int "_id" + Value int + } + mr := session.DB("otherdb").C("mr") + iter := mr.Find(nil).Iter() + for iter.Next(&item) { + c.Logf("Item: %#v", &item) + c.Assert(item.Value, Equals, expected[item.Id]) + expected[item.Id] = -1 + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestMapReduceOutOfOrder(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Out: bson.M{"a": "a", "z": "z", "replace": "mr", "db": "otherdb", "b": "b", "y": "y"}, + } + + info, err := 
coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.Collection, Equals, "mr") + c.Assert(info.Database, Equals, "otherdb") +} + +func (s *S) TestMapReduceScope(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + coll.Insert(M{"n": 1}) + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, x); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Scope: M{"x": 42}, + } + + var result []bson.M + _, err = coll.Find(nil).MapReduce(job, &result) + c.Assert(len(result), Equals, 1) + c.Assert(result[0]["value"], Equals, 42.0) +} + +func (s *S) TestMapReduceVerbose(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 100; i++ { + err = coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + Verbose: true, + } + + info, err := coll.Find(nil).MapReduce(job, nil) + c.Assert(err, IsNil) + c.Assert(info.VerboseTime, NotNil) +} + +func (s *S) TestMapReduceLimit(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for _, i := range []int{1, 4, 6, 2, 2, 3, 4} { + coll.Insert(M{"n": i}) + } + + job := &mgo.MapReduce{ + Map: "function() { emit(this.n, 1); }", + Reduce: "function(key, values) { return Array.sum(values); }", + } + + var result []bson.M + _, err = coll.Find(nil).Limit(3).MapReduce(job, &result) + c.Assert(err, IsNil) + c.Assert(len(result), Equals, 3) +} + +func (s *S) TestBuildInfo(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + info, err := session.BuildInfo() + c.Assert(err, IsNil) + + var v []int + for i, a := range 
strings.Split(info.Version, ".") { + for _, token := range []string{"-rc", "-pre"} { + if i == 2 && strings.Contains(a, token) { + a = a[:strings.Index(a, token)] + info.VersionArray[len(info.VersionArray)-1] = 0 + } + } + n, err := strconv.Atoi(a) + c.Assert(err, IsNil) + v = append(v, n) + } + for len(v) < 4 { + v = append(v, 0) + } + + c.Assert(info.VersionArray, DeepEquals, v) + c.Assert(info.GitVersion, Matches, "[a-z0-9]+") + c.Assert(info.SysInfo, Matches, ".*[0-9:]+.*") + if info.Bits != 32 && info.Bits != 64 { + c.Fatalf("info.Bits is %d", info.Bits) + } + if info.MaxObjectSize < 8192 { + c.Fatalf("info.MaxObjectSize seems too small: %d", info.MaxObjectSize) + } +} + +func (s *S) TestZeroTimeRoundtrip(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + var d struct{ T time.Time } + conn := session.DB("mydb").C("mycoll") + err = conn.Insert(d) + c.Assert(err, IsNil) + + var result bson.M + err = conn.Find(nil).One(&result) + c.Assert(err, IsNil) + t, isTime := result["t"].(time.Time) + c.Assert(isTime, Equals, true) + c.Assert(t, Equals, time.Time{}) +} + +func (s *S) TestFsyncLock(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + clone := session.Clone() + defer clone.Close() + + err = session.FsyncLock() + c.Assert(err, IsNil) + + done := make(chan time.Time) + go func() { + time.Sleep(3e9) + now := time.Now() + err := session.FsyncUnlock() + c.Check(err, IsNil) + done <- now + }() + + err = clone.DB("mydb").C("mycoll").Insert(bson.M{"n": 1}) + unlocked := time.Now() + unlocking := <-done + c.Assert(err, IsNil) + + c.Assert(unlocked.After(unlocking), Equals, true) + c.Assert(unlocked.Sub(unlocking) < 1e9, Equals, true) +} + +func (s *S) TestFsync(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + // Not much to do here. Just a smoke check. 
+ err = session.Fsync(false) + c.Assert(err, IsNil) + err = session.Fsync(true) + c.Assert(err, IsNil) +} + +func (s *S) TestRepairCursor(c *C) { + if !s.versionAtLeast(2, 7) { + c.Skip("RepairCursor only works on 2.7+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + session.SetBatch(2) + + coll := session.DB("mydb").C("mycoll3") + err = coll.DropCollection() + + ns := []int{0, 10, 20, 30, 40, 50} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + repairIter := coll.Repair() + + c.Assert(repairIter.Err(), IsNil) + + result := struct{ N int }{} + resultCounts := map[int]int{} + for repairIter.Next(&result) { + resultCounts[result.N]++ + } + + c.Assert(repairIter.Next(&result), Equals, false) + c.Assert(repairIter.Err(), IsNil) + c.Assert(repairIter.Close(), IsNil) + + // Verify that the results of the repair cursor are valid. + // The repair cursor can return multiple copies + // of the same document, so to check correctness we only + // need to verify that at least 1 of each document was returned. + + for _, key := range ns { + c.Assert(resultCounts[key] > 0, Equals, true) + } +} + +func (s *S) TestPipeIter(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + coll.Insert(M{"n": n}) + } + + pipe := coll.Pipe([]M{{"$match": M{"n": M{"$gte": 42}}}}) + + // Ensure cursor logic is working by forcing a small batch. + pipe.Batch(2) + + // Smoke test for AllowDiskUse. 
+ pipe.AllowDiskUse() + + iter := pipe.Iter() + result := struct{ N int }{} + for i := 2; i < 7; i++ { + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.N, Equals, ns[i]) + } + + c.Assert(iter.Next(&result), Equals, false) + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestPipeAll(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err := coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + var result []struct{ N int } + err = coll.Pipe([]M{{"$match": M{"n": M{"$gte": 42}}}}).All(&result) + c.Assert(err, IsNil) + for i := 2; i < 7; i++ { + c.Assert(result[i-2].N, Equals, ns[i]) + } +} + +func (s *S) TestPipeOne(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + result := struct{ A, B int }{} + + pipe := coll.Pipe([]M{{"$project": M{"a": 1, "b": M{"$add": []interface{}{"$b", 1}}}}}) + err = pipe.One(&result) + c.Assert(err, IsNil) + c.Assert(result.A, Equals, 1) + c.Assert(result.B, Equals, 3) + + pipe = coll.Pipe([]M{{"$match": M{"a": 2}}}) + err = pipe.One(&result) + c.Assert(err, Equals, mgo.ErrNotFound) +} + +func (s *S) TestPipeExplain(c *C) { + if !s.versionAtLeast(2, 1) { + c.Skip("Pipe only works on 2.1+") + } + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + coll.Insert(M{"a": 1, "b": 2}) + + pipe := coll.Pipe([]M{{"$project": M{"a": 1, "b": M{"$add": []interface{}{"$b", 1}}}}}) + + // The explain command result changes across versions. 
+ var result struct{ Ok int } + err = pipe.Explain(&result) + c.Assert(err, IsNil) + c.Assert(result.Ok, Equals, 1) +} + +func (s *S) TestBatch1Bug(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 3; i++ { + err := coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + var ns []struct{ N int } + err = coll.Find(nil).Batch(1).All(&ns) + c.Assert(err, IsNil) + c.Assert(len(ns), Equals, 3) + + session.SetBatch(1) + err = coll.Find(nil).All(&ns) + c.Assert(err, IsNil) + c.Assert(len(ns), Equals, 3) +} + +func (s *S) TestInterfaceIterBug(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + + for i := 0; i < 3; i++ { + err := coll.Insert(M{"n": i}) + c.Assert(err, IsNil) + } + + var result interface{} + + i := 0 + iter := coll.Find(nil).Sort("n").Iter() + for iter.Next(&result) { + c.Assert(result.(bson.M)["n"], Equals, i) + i++ + } + c.Assert(iter.Close(), IsNil) +} + +func (s *S) TestFindIterCloseKillsCursor(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + cursors := serverCursorsOpen(session) + + coll := session.DB("mydb").C("mycoll") + ns := []int{40, 41, 42, 43, 44, 45, 46} + for _, n := range ns { + err = coll.Insert(M{"n": n}) + c.Assert(err, IsNil) + } + + iter := coll.Find(nil).Batch(2).Iter() + c.Assert(iter.Next(bson.M{}), Equals, true) + + c.Assert(iter.Close(), IsNil) + c.Assert(serverCursorsOpen(session), Equals, cursors) +} + +func (s *S) TestLogReplay(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + for i := 0; i < 5; i++ { + err = coll.Insert(M{"ts": time.Now()}) + c.Assert(err, IsNil) + } + + iter := coll.Find(nil).LogReplay().Iter() + if s.versionAtLeast(2, 6) { + // This used to fail in 2.4. 
Now it's just a smoke test. + c.Assert(iter.Err(), IsNil) + } else { + c.Assert(iter.Next(bson.M{}), Equals, false) + c.Assert(iter.Err(), ErrorMatches, "no ts field in query") + } +} + +func (s *S) TestSetCursorTimeout(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + err = coll.Insert(M{"n": 42}) + + // This is just a smoke test. Won't wait 10 minutes for an actual timeout. + + session.SetCursorTimeout(0) + + var result struct{ N int } + iter := coll.Find(nil).Iter() + c.Assert(iter.Next(&result), Equals, true) + c.Assert(result.N, Equals, 42) + c.Assert(iter.Next(&result), Equals, false) +} + +func (s *S) TestNewIterNoServer(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + data, err := bson.Marshal(bson.M{"a": 1}) + + coll := session.DB("mydb").C("mycoll") + iter := coll.NewIter(nil, []bson.Raw{{3, data}}, 42, nil) + + var result struct{ A int } + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.A, Equals, 1) + + ok = iter.Next(&result) + c.Assert(ok, Equals, false) + + c.Assert(iter.Err(), ErrorMatches, "server not available") +} + +func (s *S) TestNewIterNoServerPresetErr(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + data, err := bson.Marshal(bson.M{"a": 1}) + + coll := session.DB("mydb").C("mycoll") + iter := coll.NewIter(nil, []bson.Raw{{3, data}}, 42, fmt.Errorf("my error")) + + var result struct{ A int } + ok := iter.Next(&result) + c.Assert(ok, Equals, true) + c.Assert(result.A, Equals, 1) + + ok = iter.Next(&result) + c.Assert(ok, Equals, false) + + c.Assert(iter.Err(), ErrorMatches, "my error") +} + +// -------------------------------------------------------------------------- +// Some benchmarks that require a running database. 
+ +func (s *S) BenchmarkFindIterRaw(c *C) { + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + defer session.Close() + + coll := session.DB("mydb").C("mycoll") + doc := bson.D{ + {"f2", "a short string"}, + {"f3", bson.D{{"1", "one"}, {"2", 2.0}}}, + {"f4", []string{"a", "b", "c", "d", "e", "f", "g"}}, + } + + for i := 0; i < c.N+1; i++ { + err := coll.Insert(doc) + c.Assert(err, IsNil) + } + + session.SetBatch(c.N) + + var raw bson.Raw + iter := coll.Find(nil).Iter() + iter.Next(&raw) + c.ResetTimer() + i := 0 + for iter.Next(&raw) { + i++ + } + c.StopTimer() + c.Assert(iter.Err(), IsNil) + c.Assert(i, Equals, c.N) +} diff --git a/vendor/gopkg.in/mgo.v2/socket.go b/vendor/gopkg.in/mgo.v2/socket.go new file mode 100644 index 000000000..f6882d501 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/socket.go @@ -0,0 +1,677 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo + +import ( + "errors" + "net" + "sync" + "time" + + "gopkg.in/mgo.v2/bson" +) + +type replyFunc func(err error, reply *replyOp, docNum int, docData []byte) + +type mongoSocket struct { + sync.Mutex + server *mongoServer // nil when cached + conn net.Conn + timeout time.Duration + addr string // For debugging only. + nextRequestId uint32 + replyFuncs map[uint32]replyFunc + references int + creds []Credential + logout []Credential + cachedNonce string + gotNonce sync.Cond + dead error + serverInfo *mongoServerInfo +} + +type queryOpFlags uint32 + +const ( + _ queryOpFlags = 1 << iota + flagTailable + flagSlaveOk + flagLogReplay + flagNoCursorTimeout + flagAwaitData +) + +type queryOp struct { + collection string + query interface{} + skip int32 + limit int32 + selector interface{} + flags queryOpFlags + replyFunc replyFunc + + options queryWrapper + hasOptions bool + serverTags []bson.D +} + +type queryWrapper struct { + Query interface{} "$query" + OrderBy interface{} "$orderby,omitempty" + Hint interface{} "$hint,omitempty" + Explain bool "$explain,omitempty" + Snapshot bool "$snapshot,omitempty" + ReadPreference bson.D "$readPreference,omitempty" + MaxScan int "$maxScan,omitempty" + MaxTimeMS int "$maxTimeMS,omitempty" + Comment string "$comment,omitempty" +} + +func (op *queryOp) finalQuery(socket *mongoSocket) interface{} { + if op.flags&flagSlaveOk != 0 && len(op.serverTags) > 0 && socket.ServerInfo().Mongos { + 
op.hasOptions = true + op.options.ReadPreference = bson.D{{"mode", "secondaryPreferred"}, {"tags", op.serverTags}} + } + if op.hasOptions { + if op.query == nil { + var empty bson.D + op.options.Query = empty + } else { + op.options.Query = op.query + } + debugf("final query is %#v\n", &op.options) + return &op.options + } + return op.query +} + +type getMoreOp struct { + collection string + limit int32 + cursorId int64 + replyFunc replyFunc +} + +type replyOp struct { + flags uint32 + cursorId int64 + firstDoc int32 + replyDocs int32 +} + +type insertOp struct { + collection string // "database.collection" + documents []interface{} // One or more documents to insert + flags uint32 +} + +type updateOp struct { + collection string // "database.collection" + selector interface{} + update interface{} + flags uint32 +} + +type deleteOp struct { + collection string // "database.collection" + selector interface{} + flags uint32 +} + +type killCursorsOp struct { + cursorIds []int64 +} + +type requestInfo struct { + bufferPos int + replyFunc replyFunc +} + +func newSocket(server *mongoServer, conn net.Conn, timeout time.Duration) *mongoSocket { + socket := &mongoSocket{ + conn: conn, + addr: server.Addr, + server: server, + replyFuncs: make(map[uint32]replyFunc), + } + socket.gotNonce.L = &socket.Mutex + if err := socket.InitialAcquire(server.Info(), timeout); err != nil { + panic("newSocket: InitialAcquire returned error: " + err.Error()) + } + stats.socketsAlive(+1) + debugf("Socket %p to %s: initialized", socket, socket.addr) + socket.resetNonce() + go socket.readLoop() + return socket +} + +// Server returns the server that the socket is associated with. +// It returns nil while the socket is cached in its respective server. +func (socket *mongoSocket) Server() *mongoServer { + socket.Lock() + server := socket.server + socket.Unlock() + return server +} + +// ServerInfo returns details for the server at the time the socket +// was initially acquired. 
+func (socket *mongoSocket) ServerInfo() *mongoServerInfo { + socket.Lock() + serverInfo := socket.serverInfo + socket.Unlock() + return serverInfo +} + +// InitialAcquire obtains the first reference to the socket, either +// right after the connection is made or once a recycled socket is +// being put back in use. +func (socket *mongoSocket) InitialAcquire(serverInfo *mongoServerInfo, timeout time.Duration) error { + socket.Lock() + if socket.references > 0 { + panic("Socket acquired out of cache with references") + } + if socket.dead != nil { + dead := socket.dead + socket.Unlock() + return dead + } + socket.references++ + socket.serverInfo = serverInfo + socket.timeout = timeout + stats.socketsInUse(+1) + stats.socketRefs(+1) + socket.Unlock() + return nil +} + +// Acquire obtains an additional reference to the socket. +// The socket will only be recycled when it's released as many +// times as it's been acquired. +func (socket *mongoSocket) Acquire() (info *mongoServerInfo) { + socket.Lock() + if socket.references == 0 { + panic("Socket got non-initial acquire with references == 0") + } + // We'll track references to dead sockets as well. + // Caller is still supposed to release the socket. + socket.references++ + stats.socketRefs(+1) + serverInfo := socket.serverInfo + socket.Unlock() + return serverInfo +} + +// Release decrements a socket reference. The socket will be +// recycled once its released as many times as it's been acquired. +func (socket *mongoSocket) Release() { + socket.Lock() + if socket.references == 0 { + panic("socket.Release() with references == 0") + } + socket.references-- + stats.socketRefs(-1) + if socket.references == 0 { + stats.socketsInUse(-1) + server := socket.server + socket.Unlock() + socket.LogoutAll() + // If the socket is dead server is nil. + if server != nil { + server.RecycleSocket(socket) + } + } else { + socket.Unlock() + } +} + +// SetTimeout changes the timeout used on socket operations. 
+func (socket *mongoSocket) SetTimeout(d time.Duration) { + socket.Lock() + socket.timeout = d + socket.Unlock() +} + +type deadlineType int + +const ( + readDeadline deadlineType = 1 + writeDeadline deadlineType = 2 +) + +func (socket *mongoSocket) updateDeadline(which deadlineType) { + var when time.Time + if socket.timeout > 0 { + when = time.Now().Add(socket.timeout) + } + whichstr := "" + switch which { + case readDeadline | writeDeadline: + whichstr = "read/write" + socket.conn.SetDeadline(when) + case readDeadline: + whichstr = "read" + socket.conn.SetReadDeadline(when) + case writeDeadline: + whichstr = "write" + socket.conn.SetWriteDeadline(when) + default: + panic("invalid parameter to updateDeadline") + } + debugf("Socket %p to %s: updated %s deadline to %s ahead (%s)", socket, socket.addr, whichstr, socket.timeout, when) +} + +// Close terminates the socket use. +func (socket *mongoSocket) Close() { + socket.kill(errors.New("Closed explicitly"), false) +} + +func (socket *mongoSocket) kill(err error, abend bool) { + socket.Lock() + if socket.dead != nil { + debugf("Socket %p to %s: killed again: %s (previously: %s)", socket, socket.addr, err.Error(), socket.dead.Error()) + socket.Unlock() + return + } + logf("Socket %p to %s: closing: %s (abend=%v)", socket, socket.addr, err.Error(), abend) + socket.dead = err + socket.conn.Close() + stats.socketsAlive(-1) + replyFuncs := socket.replyFuncs + socket.replyFuncs = make(map[uint32]replyFunc) + server := socket.server + socket.server = nil + socket.gotNonce.Broadcast() + socket.Unlock() + for _, replyFunc := range replyFuncs { + logf("Socket %p to %s: notifying replyFunc of closed socket: %s", socket, socket.addr, err.Error()) + replyFunc(err, nil, -1, nil) + } + if abend { + server.AbendSocket(socket) + } +} + +func (socket *mongoSocket) SimpleQuery(op *queryOp) (data []byte, err error) { + var wait, change sync.Mutex + var replyDone bool + var replyData []byte + var replyErr error + wait.Lock() + 
op.replyFunc = func(err error, reply *replyOp, docNum int, docData []byte) { + change.Lock() + if !replyDone { + replyDone = true + replyErr = err + if err == nil { + replyData = docData + } + } + change.Unlock() + wait.Unlock() + } + err = socket.Query(op) + if err != nil { + return nil, err + } + wait.Lock() + change.Lock() + data = replyData + err = replyErr + change.Unlock() + return data, err +} + +func (socket *mongoSocket) Query(ops ...interface{}) (err error) { + + if lops := socket.flushLogout(); len(lops) > 0 { + ops = append(lops, ops...) + } + + buf := make([]byte, 0, 256) + + // Serialize operations synchronously to avoid interrupting + // other goroutines while we can't really be sending data. + // Also, record id positions so that we can compute request + // ids at once later with the lock already held. + requests := make([]requestInfo, len(ops)) + requestCount := 0 + + for _, op := range ops { + debugf("Socket %p to %s: serializing op: %#v", socket, socket.addr, op) + start := len(buf) + var replyFunc replyFunc + switch op := op.(type) { + + case *updateOp: + buf = addHeader(buf, 2001) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.collection) + buf = addInt32(buf, int32(op.flags)) + debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector) + buf, err = addBSON(buf, op.selector) + if err != nil { + return err + } + debugf("Socket %p to %s: serializing update document: %#v", socket, socket.addr, op.update) + buf, err = addBSON(buf, op.update) + if err != nil { + return err + } + + case *insertOp: + buf = addHeader(buf, 2002) + buf = addInt32(buf, int32(op.flags)) + buf = addCString(buf, op.collection) + for _, doc := range op.documents { + debugf("Socket %p to %s: serializing document for insertion: %#v", socket, socket.addr, doc) + buf, err = addBSON(buf, doc) + if err != nil { + return err + } + } + + case *queryOp: + buf = addHeader(buf, 2004) + buf = addInt32(buf, int32(op.flags)) + buf = 
addCString(buf, op.collection) + buf = addInt32(buf, op.skip) + buf = addInt32(buf, op.limit) + buf, err = addBSON(buf, op.finalQuery(socket)) + if err != nil { + return err + } + if op.selector != nil { + buf, err = addBSON(buf, op.selector) + if err != nil { + return err + } + } + replyFunc = op.replyFunc + + case *getMoreOp: + buf = addHeader(buf, 2005) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.collection) + buf = addInt32(buf, op.limit) + buf = addInt64(buf, op.cursorId) + replyFunc = op.replyFunc + + case *deleteOp: + buf = addHeader(buf, 2006) + buf = addInt32(buf, 0) // Reserved + buf = addCString(buf, op.collection) + buf = addInt32(buf, int32(op.flags)) + debugf("Socket %p to %s: serializing selector document: %#v", socket, socket.addr, op.selector) + buf, err = addBSON(buf, op.selector) + if err != nil { + return err + } + + case *killCursorsOp: + buf = addHeader(buf, 2007) + buf = addInt32(buf, 0) // Reserved + buf = addInt32(buf, int32(len(op.cursorIds))) + for _, cursorId := range op.cursorIds { + buf = addInt64(buf, cursorId) + } + + default: + panic("internal error: unknown operation type") + } + + setInt32(buf, start, int32(len(buf)-start)) + + if replyFunc != nil { + request := &requests[requestCount] + request.replyFunc = replyFunc + request.bufferPos = start + requestCount++ + } + } + + // Buffer is ready for the pipe. Lock, allocate ids, and enqueue. 
+ + socket.Lock() + if socket.dead != nil { + dead := socket.dead + socket.Unlock() + debugf("Socket %p to %s: failing query, already closed: %s", socket, socket.addr, socket.dead.Error()) + // XXX This seems necessary in case the session is closed concurrently + // with a query being performed, but it's not yet tested: + for i := 0; i != requestCount; i++ { + request := &requests[i] + if request.replyFunc != nil { + request.replyFunc(dead, nil, -1, nil) + } + } + return dead + } + + wasWaiting := len(socket.replyFuncs) > 0 + + // Reserve id 0 for requests which should have no responses. + requestId := socket.nextRequestId + 1 + if requestId == 0 { + requestId++ + } + socket.nextRequestId = requestId + uint32(requestCount) + for i := 0; i != requestCount; i++ { + request := &requests[i] + setInt32(buf, request.bufferPos+4, int32(requestId)) + socket.replyFuncs[requestId] = request.replyFunc + requestId++ + } + + debugf("Socket %p to %s: sending %d op(s) (%d bytes)", socket, socket.addr, len(ops), len(buf)) + stats.sentOps(len(ops)) + + socket.updateDeadline(writeDeadline) + _, err = socket.conn.Write(buf) + if !wasWaiting && requestCount > 0 { + socket.updateDeadline(readDeadline) + } + socket.Unlock() + return err +} + +func fill(r net.Conn, b []byte) error { + l := len(b) + n, err := r.Read(b) + for n != l && err == nil { + var ni int + ni, err = r.Read(b[n:]) + n += ni + } + return err +} + +// Estimated minimum cost per socket: 1 goroutine + memory for the largest +// document ever seen. +func (socket *mongoSocket) readLoop() { + p := make([]byte, 36) // 16 from header + 20 from OP_REPLY fixed fields + s := make([]byte, 4) + conn := socket.conn // No locking, conn never changes. + for { + // XXX Handle timeouts, , etc + err := fill(conn, p) + if err != nil { + socket.kill(err, true) + return + } + + totalLen := getInt32(p, 0) + responseTo := getInt32(p, 8) + opCode := getInt32(p, 12) + + // Don't use socket.server.Addr here. 
socket is not + // locked and socket.server may go away. + debugf("Socket %p to %s: got reply (%d bytes)", socket, socket.addr, totalLen) + + _ = totalLen + + if opCode != 1 { + socket.kill(errors.New("opcode != 1, corrupted data?"), true) + return + } + + reply := replyOp{ + flags: uint32(getInt32(p, 16)), + cursorId: getInt64(p, 20), + firstDoc: getInt32(p, 28), + replyDocs: getInt32(p, 32), + } + + stats.receivedOps(+1) + stats.receivedDocs(int(reply.replyDocs)) + + socket.Lock() + replyFunc, ok := socket.replyFuncs[uint32(responseTo)] + if ok { + delete(socket.replyFuncs, uint32(responseTo)) + } + socket.Unlock() + + if replyFunc != nil && reply.replyDocs == 0 { + replyFunc(nil, &reply, -1, nil) + } else { + for i := 0; i != int(reply.replyDocs); i++ { + err := fill(conn, s) + if err != nil { + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } + socket.kill(err, true) + return + } + + b := make([]byte, int(getInt32(s, 0))) + + // copy(b, s) in an efficient way. + b[0] = s[0] + b[1] = s[1] + b[2] = s[2] + b[3] = s[3] + + err = fill(conn, b[4:]) + if err != nil { + if replyFunc != nil { + replyFunc(err, nil, -1, nil) + } + socket.kill(err, true) + return + } + + if globalDebug && globalLogger != nil { + m := bson.M{} + if err := bson.Unmarshal(b, m); err == nil { + debugf("Socket %p to %s: received document: %#v", socket, socket.addr, m) + } + } + + if replyFunc != nil { + replyFunc(nil, &reply, i, b) + } + + // XXX Do bound checking against totalLen. + } + } + + socket.Lock() + if len(socket.replyFuncs) == 0 { + // Nothing else to read for now. Disable deadline. + socket.conn.SetReadDeadline(time.Time{}) + } else { + socket.updateDeadline(readDeadline) + } + socket.Unlock() + + // XXX Do bound checking against totalLen. + } +} + +var emptyHeader = []byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0} + +func addHeader(b []byte, opcode int) []byte { + i := len(b) + b = append(b, emptyHeader...) + // Enough for current opcodes. 
+ b[i+12] = byte(opcode) + b[i+13] = byte(opcode >> 8) + return b +} + +func addInt32(b []byte, i int32) []byte { + return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24)) +} + +func addInt64(b []byte, i int64) []byte { + return append(b, byte(i), byte(i>>8), byte(i>>16), byte(i>>24), + byte(i>>32), byte(i>>40), byte(i>>48), byte(i>>56)) +} + +func addCString(b []byte, s string) []byte { + b = append(b, []byte(s)...) + b = append(b, 0) + return b +} + +func addBSON(b []byte, doc interface{}) ([]byte, error) { + if doc == nil { + return append(b, 5, 0, 0, 0, 0), nil + } + data, err := bson.Marshal(doc) + if err != nil { + return b, err + } + return append(b, data...), nil +} + +func setInt32(b []byte, pos int, i int32) { + b[pos] = byte(i) + b[pos+1] = byte(i >> 8) + b[pos+2] = byte(i >> 16) + b[pos+3] = byte(i >> 24) +} + +func getInt32(b []byte, pos int) int32 { + return (int32(b[pos+0])) | + (int32(b[pos+1]) << 8) | + (int32(b[pos+2]) << 16) | + (int32(b[pos+3]) << 24) +} + +func getInt64(b []byte, pos int) int64 { + return (int64(b[pos+0])) | + (int64(b[pos+1]) << 8) | + (int64(b[pos+2]) << 16) | + (int64(b[pos+3]) << 24) | + (int64(b[pos+4]) << 32) | + (int64(b[pos+5]) << 40) | + (int64(b[pos+6]) << 48) | + (int64(b[pos+7]) << 56) +} diff --git a/vendor/gopkg.in/mgo.v2/stats.go b/vendor/gopkg.in/mgo.v2/stats.go new file mode 100644 index 000000000..59723e60c --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/stats.go @@ -0,0 +1,147 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. 
Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +package mgo + +import ( + "sync" +) + +var stats *Stats +var statsMutex sync.Mutex + +func SetStats(enabled bool) { + statsMutex.Lock() + if enabled { + if stats == nil { + stats = &Stats{} + } + } else { + stats = nil + } + statsMutex.Unlock() +} + +func GetStats() (snapshot Stats) { + statsMutex.Lock() + snapshot = *stats + statsMutex.Unlock() + return +} + +func ResetStats() { + statsMutex.Lock() + debug("Resetting stats") + old := stats + stats = &Stats{} + // These are absolute values: + stats.Clusters = old.Clusters + stats.SocketsInUse = old.SocketsInUse + stats.SocketsAlive = old.SocketsAlive + stats.SocketRefs = old.SocketRefs + statsMutex.Unlock() + return +} + +type Stats struct { + Clusters int + MasterConns int + SlaveConns int + SentOps int + ReceivedOps int + ReceivedDocs int + SocketsAlive int + SocketsInUse int + SocketRefs int +} + +func (stats *Stats) cluster(delta int) { + if stats != nil { + statsMutex.Lock() + stats.Clusters += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) conn(delta int, master bool) { + if stats != nil { + statsMutex.Lock() + if master { + stats.MasterConns += delta + } else { + stats.SlaveConns += delta + } + statsMutex.Unlock() + } +} + +func (stats *Stats) sentOps(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SentOps += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) receivedOps(delta int) { + if stats != nil { + statsMutex.Lock() + stats.ReceivedOps += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) receivedDocs(delta int) { + if stats != nil { + statsMutex.Lock() + stats.ReceivedDocs += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketsInUse(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketsInUse += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketsAlive(delta int) { + if stats != nil { + statsMutex.Lock() + stats.SocketsAlive += delta + statsMutex.Unlock() + } +} + +func (stats *Stats) socketRefs(delta int) { + if stats != 
nil { + statsMutex.Lock() + stats.SocketRefs += delta + statsMutex.Unlock() + } +} diff --git a/vendor/gopkg.in/mgo.v2/suite_test.go b/vendor/gopkg.in/mgo.v2/suite_test.go new file mode 100644 index 000000000..140e5a09a --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/suite_test.go @@ -0,0 +1,254 @@ +// mgo - MongoDB driver for Go +// +// Copyright (c) 2010-2012 - Gustavo Niemeyer +// +// All rights reserved. +// +// Redistribution and use in source and binary forms, with or without +// modification, are permitted provided that the following conditions are met: +// +// 1. Redistributions of source code must retain the above copyright notice, this +// list of conditions and the following disclaimer. +// 2. Redistributions in binary form must reproduce the above copyright notice, +// this list of conditions and the following disclaimer in the documentation +// and/or other materials provided with the distribution. +// +// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND +// ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED +// WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +// DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT OWNER OR CONTRIBUTORS BE LIABLE FOR +// ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES +// (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; +// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND +// ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT +// (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS +// SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +package mgo_test + +import ( + "errors" + "flag" + "fmt" + "net" + "os/exec" + "runtime" + "strconv" + "testing" + "time" + + . 
"gopkg.in/check.v1" + "gopkg.in/mgo.v2" + "gopkg.in/mgo.v2/bson" +) + +var fast = flag.Bool("fast", false, "Skip slow tests") + +type M bson.M + +type cLogger C + +func (c *cLogger) Output(calldepth int, s string) error { + ns := time.Now().UnixNano() + t := float64(ns%100e9) / 1e9 + ((*C)(c)).Logf("[LOG] %.05f %s", t, s) + return nil +} + +func TestAll(t *testing.T) { + TestingT(t) +} + +type S struct { + session *mgo.Session + stopped bool + build mgo.BuildInfo + frozen []string +} + +func (s *S) versionAtLeast(v ...int) (result bool) { + for i := range v { + if i == len(s.build.VersionArray) { + return false + } + if s.build.VersionArray[i] != v[i] { + return s.build.VersionArray[i] >= v[i] + } + } + return true +} + +var _ = Suite(&S{}) + +func (s *S) SetUpSuite(c *C) { + mgo.SetDebug(true) + mgo.SetStats(true) + s.StartAll() + + session, err := mgo.Dial("localhost:40001") + c.Assert(err, IsNil) + s.build, err = session.BuildInfo() + c.Check(err, IsNil) + session.Close() +} + +func (s *S) SetUpTest(c *C) { + err := run("mongo --nodb testdb/dropall.js") + if err != nil { + panic(err.Error()) + } + mgo.SetLogger((*cLogger)(c)) + mgo.ResetStats() +} + +func (s *S) TearDownTest(c *C) { + if s.stopped { + s.StartAll() + } + for _, host := range s.frozen { + if host != "" { + s.Thaw(host) + } + } + var stats mgo.Stats + for i := 0; ; i++ { + stats = mgo.GetStats() + if stats.SocketsInUse == 0 && stats.SocketsAlive == 0 { + break + } + if i == 20 { + c.Fatal("Test left sockets in a dirty state") + } + c.Logf("Waiting for sockets to die: %d in use, %d alive", stats.SocketsInUse, stats.SocketsAlive) + time.Sleep(500 * time.Millisecond) + } + for i := 0; ; i++ { + stats = mgo.GetStats() + if stats.Clusters == 0 { + break + } + if i == 60 { + c.Fatal("Test left clusters alive") + } + c.Logf("Waiting for clusters to die: %d alive", stats.Clusters) + time.Sleep(1 * time.Second) + } +} + +func (s *S) Stop(host string) { + // Give a moment for slaves to sync and avoid getting 
rollback issues. + panicOnWindows() + time.Sleep(2 * time.Second) + err := run("cd _testdb && supervisorctl stop " + supvName(host)) + if err != nil { + panic(err) + } + s.stopped = true +} + +func (s *S) pid(host string) int { + output, err := exec.Command("lsof", "-iTCP:"+hostPort(host), "-sTCP:LISTEN", "-Fp").CombinedOutput() + if err != nil { + panic(err) + } + pidstr := string(output[1 : len(output)-1]) + pid, err := strconv.Atoi(pidstr) + if err != nil { + panic("cannot convert pid to int: " + pidstr) + } + return pid +} + +func (s *S) Freeze(host string) { + err := stop(s.pid(host)) + if err != nil { + panic(err) + } + s.frozen = append(s.frozen, host) +} + +func (s *S) Thaw(host string) { + err := cont(s.pid(host)) + if err != nil { + panic(err) + } + for i, frozen := range s.frozen { + if frozen == host { + s.frozen[i] = "" + } + } +} + +func (s *S) StartAll() { + // Restart any stopped nodes. + run("cd _testdb && supervisorctl start all") + err := run("cd testdb && mongo --nodb wait.js") + if err != nil { + panic(err) + } + s.stopped = false +} + +func run(command string) error { + var output []byte + var err error + if runtime.GOOS == "windows" { + output, err = exec.Command("cmd", "/C", command).CombinedOutput() + } else { + output, err = exec.Command("/bin/sh", "-c", command).CombinedOutput() + } + + if err != nil { + msg := fmt.Sprintf("Failed to execute: %s: %s\n%s", command, err.Error(), string(output)) + return errors.New(msg) + } + return nil +} + +var supvNames = map[string]string{ + "40001": "db1", + "40002": "db2", + "40011": "rs1a", + "40012": "rs1b", + "40013": "rs1c", + "40021": "rs2a", + "40022": "rs2b", + "40023": "rs2c", + "40031": "rs3a", + "40032": "rs3b", + "40033": "rs3c", + "40041": "rs4a", + "40101": "cfg1", + "40102": "cfg2", + "40103": "cfg3", + "40201": "s1", + "40202": "s2", + "40203": "s3", +} + +// supvName returns the supervisord name for the given host address. 
+func supvName(host string) string { + host, port, err := net.SplitHostPort(host) + if err != nil { + panic(err) + } + name, ok := supvNames[port] + if !ok { + panic("Unknown host: " + host) + } + return name +} + +func hostPort(host string) string { + _, port, err := net.SplitHostPort(host) + if err != nil { + panic(err) + } + return port +} + +func panicOnWindows() { + if runtime.GOOS == "windows" { + panic("the test suite is not yet fully supported on Windows") + } +} diff --git a/vendor/gopkg.in/mgo.v2/syscall_test.go b/vendor/gopkg.in/mgo.v2/syscall_test.go new file mode 100644 index 000000000..b8bbd7b34 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/syscall_test.go @@ -0,0 +1,15 @@ +// +build !windows + +package mgo_test + +import ( + "syscall" +) + +func stop(pid int) (err error) { + return syscall.Kill(pid, syscall.SIGSTOP) +} + +func cont(pid int) (err error) { + return syscall.Kill(pid, syscall.SIGCONT) +} diff --git a/vendor/gopkg.in/mgo.v2/syscall_windows_test.go b/vendor/gopkg.in/mgo.v2/syscall_windows_test.go new file mode 100644 index 000000000..f2deaca86 --- /dev/null +++ b/vendor/gopkg.in/mgo.v2/syscall_windows_test.go @@ -0,0 +1,11 @@ +package mgo_test + +func stop(pid int) (err error) { + panicOnWindows() // Always does. + return nil +} + +func cont(pid int) (err error) { + panicOnWindows() // Always does. + return nil +} diff --git a/version_test.go b/version_test.go index bf467b3d6..c7c838b34 100644 --- a/version_test.go +++ b/version_test.go @@ -23,7 +23,11 @@ import ( . "gopkg.in/check.v1" ) -func (s *TestSuite) TestVersion(c *C) { +type VersionSuite struct{} + +var _ = Suite(&VersionSuite{}) + +func (s *VersionSuite) TestVersion(c *C) { _, err := time.Parse(minioVersion, http.TimeFormat) c.Assert(err, NotNil) }