// Copyright 2009 The Go Authors. All rights reserved. // Copyright 2012 The Gorilla Authors. All rights reserved. // Use of this source code is governed by a BSD-style // license that can be found in the LICENSE file. package rpc import ( "fmt" "net/http" "reflect" "strings" ) // ---------------------------------------------------------------------------- // Codec // ---------------------------------------------------------------------------- // Codec creates a CodecRequest to process each request. type Codec interface { NewRequest(*http.Request) CodecRequest } // CodecRequest decodes a request and encodes a response using a specific // serialization scheme. type CodecRequest interface { // Reads the request and returns the RPC method name. Method() (string, error) // Reads the request filling the RPC method args. ReadRequest(interface{}) error // Writes the response using the RPC method reply. WriteResponse(http.ResponseWriter, interface{}) // Writes an error produced by the server. WriteError(w http.ResponseWriter, status int, err error) } // ---------------------------------------------------------------------------- // Server // ---------------------------------------------------------------------------- // NewServer returns a new RPC server. func NewServer() *Server { return &Server{ codecs: make(map[string]Codec), services: new(serviceMap), } } // Server serves registered RPC services using registered codecs. type Server struct { codecs map[string]Codec services *serviceMap } // RegisterCodec adds a new codec to the server. // // Codecs are defined to process a given serialization scheme, e.g., JSON or // XML. A codec is chosen based on the "Content-Type" header from the request, // excluding the charset definition. func (s *Server) RegisterCodec(codec Codec, contentType string) { s.codecs[strings.ToLower(contentType)] = codec } // RegisterService adds a new service to the server. // // The name parameter is optional: if empty it will be inferred from // the receiver type name. // // Methods from the receiver will be extracted if these rules are satisfied: // // - The receiver is exported (begins with an upper case letter) or local // (defined in the package registering the service). // - The method name is exported. // - The method has three arguments: *http.Request, *args, *reply. // - All three arguments are pointers. // - The second and third arguments are exported or local. // - The method has return type error. // // All other methods are ignored. func (s *Server) RegisterService(receiver interface{}, name string) error { return s.services.register(receiver, name) } // HasMethod returns true if the given method is registered. // // The method uses a dotted notation as in "Service.Method". func (s *Server) HasMethod(method string) bool { if _, _, err := s.services.get(method); err == nil { return true } return false } // ServeHTTP func (s *Server) ServeHTTP(w http.ResponseWriter, r *http.Request) { if r.Method != "POST" { WriteError(w, 405, "rpc: POST method required, received "+r.Method) return } contentType := r.Header.Get("Content-Type") idx := strings.Index(contentType, ";") if idx != -1 { contentType = contentType[:idx] } var codec Codec if contentType == "" && len(s.codecs) == 1 { // If Content-Type is not set and only one codec has been registered, // then default to that codec. for _, c := range s.codecs { codec = c } } else if codec = s.codecs[strings.ToLower(contentType)]; codec == nil { WriteError(w, 415, "rpc: unrecognized Content-Type: "+contentType) return } // Create a new codec request. codecReq := codec.NewRequest(r) // Get service method to be called. method, errMethod := codecReq.Method() if errMethod != nil { codecReq.WriteError(w, 400, errMethod) return } serviceSpec, methodSpec, errGet := s.services.get(method) if errGet != nil { codecReq.WriteError(w, 400, errGet) return } // Decode the args. args := reflect.New(methodSpec.argsType) if errRead := codecReq.ReadRequest(args.Interface()); errRead != nil { codecReq.WriteError(w, 400, errRead) return } // Call the service method. reply := reflect.New(methodSpec.replyType) errValue := methodSpec.method.Func.Call([]reflect.Value{ serviceSpec.rcvr, reflect.ValueOf(r), args, reply, }) // Cast the result to error if needed. var errResult error errInter := errValue[0].Interface() if errInter != nil { errResult = errInter.(error) } // Prevents Internet Explorer from MIME-sniffing a response away // from the declared content-type w.Header().Set("x-content-type-options", "nosniff") // Prevents against XSS Atacks w.Header().Set("X-XSS-Protection", "\"1; mode=block\"") // Prevents against Clickjacking w.Header().Set("X-Frame-Options", "SAMEORIGIN") // Encode the response. if errResult == nil { codecReq.WriteResponse(w, reply.Interface()) } else { codecReq.WriteError(w, 400, errResult) } } func WriteError(w http.ResponseWriter, status int, msg string) { w.WriteHeader(status) w.Header().Set("Content-Type", "text/plain; charset=utf-8") fmt.Fprint(w, msg) }