diff --git a/transport/http/proto/client.go b/transport/http/proto/client.go new file mode 100644 index 0000000..cc729f9 --- /dev/null +++ b/transport/http/proto/client.go @@ -0,0 +1,37 @@ +package proto + +import ( + "bytes" + "context" + "errors" + "io/ioutil" + "net/http" + + httptransport "github.com/go-kit/kit/transport/http" + "github.com/golang/protobuf/proto" +) + +// EncodeProtoRequest is an EncodeRequestFunc that serializes the request as Protobuf. +// If the request implements Headerer, the provided headers will be applied +// to the request. If the given request does not implement proto.Message, an error will +// be returned. +func EncodeProtoRequest(_ context.Context, r *http.Request, preq interface{}) error { + r.Header.Set("Content-Type", "application/x-protobuf") + if headerer, ok := preq.(httptransport.Headerer); ok { + for k := range headerer.Headers() { + r.Header.Set(k, headerer.Headers().Get(k)) + } + } + req, ok := preq.(proto.Message) + if !ok { + return errors.New("response does not implement proto.Message") + } + + b, err := proto.Marshal(req) + if err != nil { + return err + } + r.ContentLength = int64(len(b)) + r.Body = ioutil.NopCloser(bytes.NewReader(b)) + return nil +} diff --git a/transport/http/proto/proto_pb_test.go b/transport/http/proto/proto_pb_test.go new file mode 100644 index 0000000..779d888 --- /dev/null +++ b/transport/http/proto/proto_pb_test.go @@ -0,0 +1,94 @@ +// Code generated by protoc-gen-go. DO NOT EDIT. +// source: proto_test.proto + +package proto + +import ( + fmt "fmt" + math "math" + + proto "github.com/golang/protobuf/proto" +) + +// Reference imports to suppress errors if they are not otherwise used. +var _ = proto.Marshal +var _ = fmt.Errorf +var _ = math.Inf + +// This is a compile-time assertion to ensure that this generated file +// is compatible with the proto package it is being compiled against. +// A compilation error at this line likely means your copy of the +// proto package needs to be updated. +const _ = proto.ProtoPackageIsVersion2 // please upgrade the proto package + +type Cat struct { + Age int32 `protobuf:"varint,1,opt,name=Age,proto3" json:"Age,omitempty"` + Breed string `protobuf:"bytes,2,opt,name=Breed,proto3" json:"Breed,omitempty"` + Name string `protobuf:"bytes,3,opt,name=Name,proto3" json:"Name,omitempty"` + XXX_NoUnkeyedLiteral struct{} `json:"-"` + XXX_unrecognized []byte `json:"-"` + XXX_sizecache int32 `json:"-"` +} + +func (m *Cat) Reset() { *m = Cat{} } +func (m *Cat) String() string { return proto.CompactTextString(m) } +func (*Cat) ProtoMessage() {} +func (*Cat) Descriptor() ([]byte, []int) { + return fileDescriptor_a794ba8d0e5440a3, []int{0} +} + +func (m *Cat) XXX_Unmarshal(b []byte) error { + return xxx_messageInfo_Cat.Unmarshal(m, b) +} +func (m *Cat) XXX_Marshal(b []byte, deterministic bool) ([]byte, error) { + return xxx_messageInfo_Cat.Marshal(b, m, deterministic) +} +func (m *Cat) XXX_Merge(src proto.Message) { + xxx_messageInfo_Cat.Merge(m, src) +} +func (m *Cat) XXX_Size() int { + return xxx_messageInfo_Cat.Size(m) +} +func (m *Cat) XXX_DiscardUnknown() { + xxx_messageInfo_Cat.DiscardUnknown(m) +} + +var xxx_messageInfo_Cat proto.InternalMessageInfo + +func (m *Cat) GetAge() int32 { + if m != nil { + return m.Age + } + return 0 +} + +func (m *Cat) GetBreed() string { + if m != nil { + return m.Breed + } + return "" +} + +func (m *Cat) GetName() string { + if m != nil { + return m.Name + } + return "" +} + +func init() { + proto.RegisterType((*Cat)(nil), "Cat") +} + +func init() { proto.RegisterFile("proto_test.proto", fileDescriptor_a794ba8d0e5440a3) } + +var fileDescriptor_a794ba8d0e5440a3 = []byte{ + // 98 bytes of a gzipped FileDescriptorProto + 0x1f, 0x8b, 0x08, 0x00, 0x00, 0x00, 0x00, 0x00, 0x02, 0xff, 0xe2, 0x12, 0x28, 0x28, 0xca, 0x2f, + 0xc9, 0x8f, 0x2f, 0x49, 0x2d, 0x2e, 0xd1, 0x03, 0x33, 0x95, 0x1c, 0xb9, 0x98, 0x9d, 0x13, 0x4b, + 0x84, 0x04, 0xb8, 0x98, 0x1d, 0xd3, 0x53, 0x25, 0x18, 0x15, 0x18, 0x35, 0x58, 0x83, 0x40, 0x4c, + 0x21, 0x11, 0x2e, 0x56, 0xa7, 0xa2, 0xd4, 0xd4, 0x14, 0x09, 0x26, 0x05, 0x46, 0x0d, 0xce, 0x20, + 0x08, 0x47, 0x48, 0x88, 0x8b, 0xc5, 0x2f, 0x31, 0x37, 0x55, 0x82, 0x19, 0x2c, 0x08, 0x66, 0x27, + 0xb1, 0x81, 0x4d, 0x32, 0x06, 0x04, 0x00, 0x00, 0xff, 0xff, 0x5f, 0x95, 0x83, 0x0a, 0x5d, 0x00, + 0x00, 0x00, +} diff --git a/transport/http/proto/proto_test.go b/transport/http/proto/proto_test.go new file mode 100644 index 0000000..7b59d71 --- /dev/null +++ b/transport/http/proto/proto_test.go @@ -0,0 +1,95 @@ +package proto + +import ( + "io/ioutil" + "net/http" + "net/http/httptest" + "testing" + + "github.com/golang/protobuf/proto" +) + +func TestEncodeProtoRequest(t *testing.T) { + cat := &Cat{Name: "Ziggy", Age: 13, Breed: "Lumpy"} + + r := httptest.NewRequest(http.MethodGet, "/cat", nil) + + err := EncodeProtoRequest(nil, r, cat) + if err != nil { + t.Errorf("expected no encoding errors but got: %s", err) + return + } + + const xproto = "application/x-protobuf" + if typ := r.Header.Get("Content-Type"); typ != xproto { + t.Errorf("expected content type of %q, got %q", xproto, typ) + return + } + + bod, err := ioutil.ReadAll(r.Body) + if err != nil { + t.Errorf("expected no read errors but got: %s", err) + return + } + defer r.Body.Close() + + var got Cat + err = proto.Unmarshal(bod, &got) + if err != nil { + t.Errorf("expected no proto errors but got: %s", err) + return + } + + if !proto.Equal(&got, cat) { + t.Errorf("expected cats to be equal but got:\n\n%#v\n\nwant:\n\n%#v", got, cat) + return + } +} + +func TestEncodeProtoResponse(t *testing.T) { + cat := &Cat{Name: "Ziggy", Age: 13, Breed: "Lumpy"} + + wr := httptest.NewRecorder() + + err := EncodeProtoResponse(nil, wr, cat) + if err != nil { + t.Errorf("expected no encoding errors but got: %s", err) + return + } + + w := wr.Result() + + const xproto = "application/x-protobuf" + if typ := w.Header.Get("Content-Type"); typ != xproto { + t.Errorf("expected content type of %q, got %q", xproto, typ) + return + } + + if w.StatusCode != http.StatusTeapot { + t.Errorf("expected status code of %d, got %d", http.StatusTeapot, w.StatusCode) + return + } + + bod, err := ioutil.ReadAll(w.Body) + if err != nil { + t.Errorf("expected no read errors but got: %s", err) + return + } + defer w.Body.Close() + + var got Cat + err = proto.Unmarshal(bod, &got) + if err != nil { + t.Errorf("expected no proto errors but got: %s", err) + return + } + + if !proto.Equal(&got, cat) { + t.Errorf("expected cats to be equal but got:\n\n%#v\n\nwant:\n\n%#v", got, cat) + return + } +} + +func (c *Cat) StatusCode() int { + return http.StatusTeapot +} diff --git a/transport/http/proto/proto_test.proto b/transport/http/proto/proto_test.proto new file mode 100644 index 0000000..018486d --- /dev/null +++ b/transport/http/proto/proto_test.proto @@ -0,0 +1,7 @@ +syntax = "proto3"; + +message Cat { + int32 Age = 1; + string Breed = 2; + string Name = 3; +} diff --git a/transport/http/proto/server.go b/transport/http/proto/server.go new file mode 100644 index 0000000..9990081 --- /dev/null +++ b/transport/http/proto/server.go @@ -0,0 +1,47 @@ +package proto + +import ( + "context" + "errors" + "net/http" + + httptransport "github.com/go-kit/kit/transport/http" + "github.com/golang/protobuf/proto" +) + +// EncodeProtoResponse is an EncodeResponseFunc that serializes the response as Protobuf. +// Many Proto-over-HTTP services can use it as a sensible default. If the response +// implements Headerer, the provided headers will be applied to the response. If the +// response implements StatusCoder, the provided StatusCode will be used instead of 200. +func EncodeProtoResponse(ctx context.Context, w http.ResponseWriter, pres interface{}) error { + res, ok := pres.(proto.Message) + if !ok { + return errors.New("response does not implement proto.Message") + } + w.Header().Set("Content-Type", "application/x-protobuf") + if headerer, ok := w.(httptransport.Headerer); ok { + for k := range headerer.Headers() { + w.Header().Set(k, headerer.Headers().Get(k)) + } + } + code := http.StatusOK + if sc, ok := pres.(httptransport.StatusCoder); ok { + code = sc.StatusCode() + } + w.WriteHeader(code) + if code == http.StatusNoContent { + return nil + } + if res == nil { + return nil + } + b, err := proto.Marshal(res) + if err != nil { + return err + } + _, err = w.Write(b) + if err != nil { + return err + } + return nil +}