New Upstream Release - golang-gopkg-vmihailenco-msgpack.v2
Ready changes
Summary
Merged new upstream version: 5.3.5 (was: 4.3.1).
Resulting package
Built on 2023-06-23T21:04 (took 12m51s)
The resulting binary package can be installed (if you have the apt repository enabled) by running:
apt install -t fresh-releases golang-gopkg-vmihailenco-msgpack.v2-dev
Diff
diff --git a/.github/FUNDING.yml b/.github/FUNDING.yml
new file mode 100644
index 0000000..caef4b3
--- /dev/null
+++ b/.github/FUNDING.yml
@@ -0,0 +1 @@
+custom: ['https://uptrace.dev']
diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md
new file mode 100644
index 0000000..ecb2d55
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/bug_report.md
@@ -0,0 +1,50 @@
+---
+name: Bug report
+about: Create a report to help us improve
+title: ''
+labels: ''
+assignees: ''
+---
+
+Issue tracker is used for reporting bugs and discussing new features. Please use
+[Discord](https://discord.gg/rWtp5Aj) or [stackoverflow](https://stackoverflow.com) for supporting
+issues.
+
+<!--- Provide a general summary of the issue in the Title above -->
+
+## Expected Behavior
+
+<!--- Tell us what should happen -->
+
+## Current Behavior
+
+<!--- Tell us what happens instead of the expected behavior -->
+
+## Possible Solution
+
+<!--- Not obligatory, but suggest a fix/reason for the bug, -->
+
+## Steps to Reproduce
+
+<!--- Provide a link to a live example, or an unambiguous set of steps to -->
+<!--- reproduce this bug. Include code to reproduce, if relevant -->
+
+1.
+2.
+3.
+4.
+
+## Context (Environment)
+
+<!--- How has this issue affected you? What are you trying to accomplish? -->
+<!--- Providing context helps us come up with a solution that is most useful in the real world -->
+
+<!--- Provide a general summary of the issue in the Title above -->
+
+## Detailed Description
+
+<!--- Provide a detailed description of the change or addition you are proposing -->
+
+## Possible Implementation
+
+<!--- Not obligatory, but suggest an idea for implementing addition or change -->
diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml
new file mode 100644
index 0000000..697f40a
--- /dev/null
+++ b/.github/ISSUE_TEMPLATE/config.yml
@@ -0,0 +1,5 @@
+blank_issues_enabled: false
+contact_links:
+ - name: Discord
+ url: https://discord.gg/rWtp5Aj
+ about: Ask a question at Discord
diff --git a/.github/workflows/build.yml b/.github/workflows/build.yml
new file mode 100644
index 0000000..c8910de
--- /dev/null
+++ b/.github/workflows/build.yml
@@ -0,0 +1,27 @@
+name: Go
+
+on:
+ push:
+ branches: [v5]
+ pull_request:
+ branches: [v5]
+
+jobs:
+ build:
+ name: build
+ runs-on: ubuntu-latest
+ strategy:
+ matrix:
+ go-version: [1.16.x, 1.17.x]
+
+ steps:
+ - name: Set up ${{ matrix.go-version }}
+ uses: actions/setup-go@v2
+ with:
+ go-version: ${{ matrix.go-version }}
+
+ - name: Checkout code
+ uses: actions/checkout@v2
+
+ - name: Test
+ run: make test
diff --git a/.github/workflows/commitlint.yml b/.github/workflows/commitlint.yml
new file mode 100644
index 0000000..67e6df3
--- /dev/null
+++ b/.github/workflows/commitlint.yml
@@ -0,0 +1,11 @@
+name: Lint Commit Messages
+on: [pull_request]
+
+jobs:
+ commitlint:
+ runs-on: ubuntu-latest
+ steps:
+ - uses: actions/checkout@v2
+ with:
+ fetch-depth: 0
+ - uses: wagoid/commitlint-github-action@v4
diff --git a/.golangci.yml b/.golangci.yml
deleted file mode 100644
index 98d6cb7..0000000
--- a/.golangci.yml
+++ /dev/null
@@ -1,12 +0,0 @@
-run:
- concurrency: 8
- deadline: 5m
- tests: false
-linters:
- enable-all: true
- disable:
- - gochecknoglobals
- - gocognit
- - godox
- - wsl
- - funlen
diff --git a/.prettierrc b/.prettierrc
new file mode 100644
index 0000000..8b7f044
--- /dev/null
+++ b/.prettierrc
@@ -0,0 +1,4 @@
+semi: false
+singleQuote: true
+proseWrap: always
+printWidth: 100
diff --git a/.travis.yml b/.travis.yml
index 82e5217..e2ce06c 100644
--- a/.travis.yml
+++ b/.travis.yml
@@ -2,9 +2,8 @@ sudo: false
language: go
go:
- - 1.11.x
- - 1.12.x
- - 1.13.x
+ - 1.15.x
+ - 1.16.x
- tip
matrix:
@@ -17,4 +16,5 @@ env:
go_import_path: github.com/vmihailenco/msgpack
before_install:
- - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go env GOPATH)/bin v1.21.0
+ - curl -sfL https://install.goreleaser.com/github.com/golangci/golangci-lint.sh | sh -s -- -b $(go
+ env GOPATH)/bin v1.31.0
diff --git a/CHANGELOG.md b/CHANGELOG.md
index fac9709..f6b19d5 100644
--- a/CHANGELOG.md
+++ b/CHANGELOG.md
@@ -1,8 +1,32 @@
+## [5.3.5](https://github.com/vmihailenco/msgpack/compare/v5.3.4...v5.3.5) (2021-10-22)
+
+
+
+## v5
+
+### Added
+
+- `DecodeMap` is split into `DecodeMap`, `DecodeTypedMap`, and `DecodeUntypedMap`.
+- New msgpack extensions API.
+
+### Changed
+
+- `Reset*` functions also reset flags.
+- `SetMapDecodeFunc` is renamed to `SetMapDecoder`.
+- `StructAsArray` is renamed to `UseArrayEncodedStructs`.
+- `SortMapKeys` is renamed to `SetSortMapKeys`.
+
+### Removed
+
+- `UseJSONTag` is removed. Use `SetCustomStructTag("json")` instead.
+
## v4
-- Encode, Decode, Marshal, and Unmarshal are changed to accept single argument. EncodeMulti and DecodeMulti are added as replacement.
+- Encode, Decode, Marshal, and Unmarshal are changed to accept single argument. EncodeMulti and
+ DecodeMulti are added as replacement.
- Added EncodeInt8/16/32/64 and EncodeUint8/16/32/64.
-- Encoder changed to preserve type of numbers instead of chosing most compact encoding. The old behavior can be achieved with Encoder.UseCompactEncoding.
+- Encoder changed to preserve type of numbers instead of chosing most compact encoding. The old
+ behavior can be achieved with Encoder.UseCompactEncoding.
## v3.3
@@ -16,9 +40,12 @@
- gopkg.in is not supported any more. Update import path to github.com/vmihailenco/msgpack.
- Msgpack maps are decoded into map[string]interface{} by default.
-- EncodeSliceLen is removed in favor of EncodeArrayLen. DecodeSliceLen is removed in favor of DecodeArrayLen.
+- EncodeSliceLen is removed in favor of EncodeArrayLen. DecodeSliceLen is removed in favor of
+ DecodeArrayLen.
- Embedded structs are automatically inlined where possible.
-- Time is encoded using extension as described in https://github.com/msgpack/msgpack/pull/209. Old format is supported as well.
-- EncodeInt8/16/32/64 is replaced with EncodeInt. EncodeUint8/16/32/64 is replaced with EncodeUint. There should be no performance differences.
+- Time is encoded using extension as described in https://github.com/msgpack/msgpack/pull/209. Old
+ format is supported as well.
+- EncodeInt8/16/32/64 is replaced with EncodeInt. EncodeUint8/16/32/64 is replaced with EncodeUint.
+ There should be no performance differences.
- DecodeInterface can now return int8/16/32 and uint8/16/32.
- PeekCode returns codes.Code instead of byte.
diff --git a/Makefile b/Makefile
index 57914e3..e9aade7 100644
--- a/Makefile
+++ b/Makefile
@@ -1,6 +1,6 @@
-all:
+test:
go test ./...
go test ./... -short -race
go test ./... -run=NONE -bench=. -benchmem
env GOOS=linux GOARCH=386 go test ./...
- golangci-lint run
+ go vet
diff --git a/README.md b/README.md
index fea3a62..66ad98b 100644
--- a/README.md
+++ b/README.md
@@ -1,72 +1,86 @@
# MessagePack encoding for Golang
-[![Build Status](https://travis-ci.org/vmihailenco/msgpack.svg?branch=v2)](https://travis-ci.org/vmihailenco/msgpack)
-[![GoDoc](https://godoc.org/github.com/vmihailenco/msgpack?status.svg)](https://godoc.org/github.com/vmihailenco/msgpack)
+[![Build Status](https://travis-ci.org/vmihailenco/msgpack.svg)](https://travis-ci.org/vmihailenco/msgpack)
+[![PkgGoDev](https://pkg.go.dev/badge/github.com/vmihailenco/msgpack/v5)](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5)
+[![Documentation](https://img.shields.io/badge/msgpack-documentation-informational)](https://msgpack.uptrace.dev/)
+[![Chat](https://discordapp.com/api/guilds/752070105847955518/widget.png)](https://discord.gg/rWtp5Aj)
+
+> :heart:
+> [**Uptrace.dev** - All-in-one tool to optimize performance and monitor errors & logs](https://uptrace.dev/?utm_source=gh-msgpack&utm_campaign=gh-msgpack-var2)
+
+- Join [Discord](https://discord.gg/rWtp5Aj) to ask questions.
+- [Documentation](https://msgpack.uptrace.dev)
+- [Reference](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5)
+- [Examples](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#pkg-examples)
+
+Other projects you may like:
+
+- [Bun](https://bun.uptrace.dev) - fast and simple SQL client for PostgreSQL, MySQL, and SQLite.
+- [BunRouter](https://bunrouter.uptrace.dev/) - fast and flexible HTTP router for Go.
+
+## Features
-Supports:
- Primitives, arrays, maps, structs, time.Time and interface{}.
-- Appengine *datastore.Key and datastore.Cursor.
-- [CustomEncoder](https://godoc.org/github.com/vmihailenco/msgpack#example-CustomEncoder)/CustomDecoder interfaces for custom encoding.
-- [Extensions](https://godoc.org/github.com/vmihailenco/msgpack#example-RegisterExt) to encode type information.
-- Renaming fields via `msgpack:"my_field_name"`.
-- Omitting individual empty fields via `msgpack:",omitempty"` tag or all [empty fields in a struct](https://godoc.org/github.com/vmihailenco/msgpack#example-Marshal--OmitEmpty).
-- [Map keys sorting](https://godoc.org/github.com/vmihailenco/msgpack#Encoder.SortMapKeys).
-- Encoding/decoding all [structs as arrays](https://godoc.org/github.com/vmihailenco/msgpack#Encoder.UseArrayForStructs) or [individual structs](https://godoc.org/github.com/vmihailenco/msgpack#example-Marshal--AsArray).
-- [Encoder.UseJSONTag](https://godoc.org/github.com/vmihailenco/msgpack#Encoder.UseJSONTag) with [Decoder.UseJSONTag](https://godoc.org/github.com/vmihailenco/msgpack#Decoder.UseJSONTag) can turn msgpack into drop-in replacement for JSON.
-- Simple but very fast and efficient [queries](https://godoc.org/github.com/vmihailenco/msgpack#example-Decoder-Query).
-
-API docs: https://godoc.org/github.com/vmihailenco/msgpack.
-Examples: https://godoc.org/github.com/vmihailenco/msgpack#pkg-examples.
+- Appengine \*datastore.Key and datastore.Cursor.
+- [CustomEncoder]/[CustomDecoder] interfaces for custom encoding.
+- [Extensions](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#example-RegisterExt) to encode
+ type information.
+- Renaming fields via `msgpack:"my_field_name"` and alias via `msgpack:"alias:another_name"`.
+- Omitting individual empty fields via `msgpack:",omitempty"` tag or all
+ [empty fields in a struct](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#example-Marshal-OmitEmpty).
+- [Map keys sorting](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#Encoder.SetSortMapKeys).
+- Encoding/decoding all
+ [structs as arrays](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#Encoder.UseArrayEncodedStructs)
+ or
+ [individual structs](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#example-Marshal-AsArray).
+- [Encoder.SetCustomStructTag] with [Decoder.SetCustomStructTag] can turn msgpack into drop-in
+ replacement for any tag.
+- Simple but very fast and efficient
+ [queries](https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#example-Decoder.Query).
+
+[customencoder]: https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#CustomEncoder
+[customdecoder]: https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#CustomDecoder
+[encoder.setcustomstructtag]:
+ https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#Encoder.SetCustomStructTag
+[decoder.setcustomstructtag]:
+ https://pkg.go.dev/github.com/vmihailenco/msgpack/v5#Decoder.SetCustomStructTag
## Installation
-This project uses [Go Modules](https://github.com/golang/go/wiki/Modules) and semantic import versioning since v4:
+msgpack supports 2 last Go versions and requires support for
+[Go modules](https://github.com/golang/go/wiki/Modules). So make sure to initialize a Go module:
-``` shell
+```shell
go mod init github.com/my/repo
-go get github.com/vmihailenco/msgpack/v4
+```
+
+And then install msgpack/v5 (note _v5_ in the import; omitting it is a popular mistake):
+
+```shell
+go get github.com/vmihailenco/msgpack/v5
```
## Quickstart
-``` go
-import "github.com/vmihailenco/msgpack/v4"
+```go
+import "github.com/vmihailenco/msgpack/v5"
func ExampleMarshal() {
- type Item struct {
- Foo string
- }
-
- b, err := msgpack.Marshal(&Item{Foo: "bar"})
- if err != nil {
- panic(err)
- }
-
- var item Item
- err = msgpack.Unmarshal(b, &item)
- if err != nil {
- panic(err)
- }
- fmt.Println(item.Foo)
- // Output: bar
+ type Item struct {
+ Foo string
+ }
+
+ b, err := msgpack.Marshal(&Item{Foo: "bar"})
+ if err != nil {
+ panic(err)
+ }
+
+ var item Item
+ err = msgpack.Unmarshal(b, &item)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(item.Foo)
+ // Output: bar
}
```
-
-## Benchmark
-
-```
-BenchmarkStructVmihailencoMsgpack-4 200000 12814 ns/op 2128 B/op 26 allocs/op
-BenchmarkStructUgorjiGoMsgpack-4 100000 17678 ns/op 3616 B/op 70 allocs/op
-BenchmarkStructUgorjiGoCodec-4 100000 19053 ns/op 7346 B/op 23 allocs/op
-BenchmarkStructJSON-4 20000 69438 ns/op 7864 B/op 26 allocs/op
-BenchmarkStructGOB-4 10000 104331 ns/op 14664 B/op 278 allocs/op
-```
-
-## Howto
-
-Please go through [examples](https://godoc.org/github.com/vmihailenco/msgpack#pkg-examples) to get an idea how to use this package.
-
-## See also
-
-- [Golang PostgreSQL ORM](https://github.com/go-pg/pg)
-- [Golang message task queue](https://github.com/vmihailenco/taskq)
diff --git a/bench_test.go b/bench_test.go
index 6693062..5e02f47 100644
--- a/bench_test.go
+++ b/bench_test.go
@@ -8,7 +8,7 @@ import (
"testing"
"time"
- "github.com/vmihailenco/msgpack/v4"
+ "github.com/vmihailenco/msgpack/v5"
)
func BenchmarkDiscard(b *testing.B) {
@@ -95,6 +95,26 @@ func BenchmarkInt32(b *testing.B) {
benchmarkEncodeDecode(b, int32(0), &dst)
}
+func BenchmarkFloat32(b *testing.B) {
+ var dst float32
+ benchmarkEncodeDecode(b, float32(0), &dst)
+}
+
+func BenchmarkFloat32_Max(b *testing.B) {
+ var dst float32
+ benchmarkEncodeDecode(b, float32(math.MaxFloat32), &dst)
+}
+
+func BenchmarkFloat64(b *testing.B) {
+ var dst float64
+ benchmarkEncodeDecode(b, float64(0), &dst)
+}
+
+func BenchmarkFloat64_Max(b *testing.B) {
+ var dst float64
+ benchmarkEncodeDecode(b, float64(math.MaxFloat64), &dst)
+}
+
func BenchmarkTime(b *testing.B) {
var dst time.Time
benchmarkEncodeDecode(b, time.Now(), &dst)
@@ -203,8 +223,10 @@ type benchmarkStruct2 struct {
UpdatedAt time.Time
}
-var _ msgpack.CustomEncoder = (*benchmarkStruct2)(nil)
-var _ msgpack.CustomDecoder = (*benchmarkStruct2)(nil)
+var (
+ _ msgpack.CustomEncoder = (*benchmarkStruct2)(nil)
+ _ msgpack.CustomDecoder = (*benchmarkStruct2)(nil)
+)
func (s *benchmarkStruct2) EncodeMsgpack(enc *msgpack.Encoder) error {
return enc.EncodeMulti(
@@ -346,8 +368,8 @@ func BenchmarkQuery(b *testing.B) {
var records []map[string]interface{}
for i := 0; i < 1000; i++ {
record := map[string]interface{}{
- "id": i,
- "attrs": map[string]interface{}{"phone": i},
+ "id": int64(i),
+ "attrs": map[string]interface{}{"phone": int64(i)},
}
records = append(records, record)
}
diff --git a/codes/codes.go b/codes/codes.go
deleted file mode 100644
index 28e0a5a..0000000
--- a/codes/codes.go
+++ /dev/null
@@ -1,90 +0,0 @@
-package codes
-
-type Code byte
-
-var (
- PosFixedNumHigh Code = 0x7f
- NegFixedNumLow Code = 0xe0
-
- Nil Code = 0xc0
-
- False Code = 0xc2
- True Code = 0xc3
-
- Float Code = 0xca
- Double Code = 0xcb
-
- Uint8 Code = 0xcc
- Uint16 Code = 0xcd
- Uint32 Code = 0xce
- Uint64 Code = 0xcf
-
- Int8 Code = 0xd0
- Int16 Code = 0xd1
- Int32 Code = 0xd2
- Int64 Code = 0xd3
-
- FixedStrLow Code = 0xa0
- FixedStrHigh Code = 0xbf
- FixedStrMask Code = 0x1f
- Str8 Code = 0xd9
- Str16 Code = 0xda
- Str32 Code = 0xdb
-
- Bin8 Code = 0xc4
- Bin16 Code = 0xc5
- Bin32 Code = 0xc6
-
- FixedArrayLow Code = 0x90
- FixedArrayHigh Code = 0x9f
- FixedArrayMask Code = 0xf
- Array16 Code = 0xdc
- Array32 Code = 0xdd
-
- FixedMapLow Code = 0x80
- FixedMapHigh Code = 0x8f
- FixedMapMask Code = 0xf
- Map16 Code = 0xde
- Map32 Code = 0xdf
-
- FixExt1 Code = 0xd4
- FixExt2 Code = 0xd5
- FixExt4 Code = 0xd6
- FixExt8 Code = 0xd7
- FixExt16 Code = 0xd8
- Ext8 Code = 0xc7
- Ext16 Code = 0xc8
- Ext32 Code = 0xc9
-)
-
-func IsFixedNum(c Code) bool {
- return c <= PosFixedNumHigh || c >= NegFixedNumLow
-}
-
-func IsFixedMap(c Code) bool {
- return c >= FixedMapLow && c <= FixedMapHigh
-}
-
-func IsFixedArray(c Code) bool {
- return c >= FixedArrayLow && c <= FixedArrayHigh
-}
-
-func IsFixedString(c Code) bool {
- return c >= FixedStrLow && c <= FixedStrHigh
-}
-
-func IsString(c Code) bool {
- return IsFixedString(c) || c == Str8 || c == Str16 || c == Str32
-}
-
-func IsBin(c Code) bool {
- return c == Bin8 || c == Bin16 || c == Bin32
-}
-
-func IsFixedExt(c Code) bool {
- return c >= FixExt1 && c <= FixExt16
-}
-
-func IsExt(c Code) bool {
- return IsFixedExt(c) || c == Ext8 || c == Ext16 || c == Ext32
-}
diff --git a/commitlint.config.js b/commitlint.config.js
new file mode 100644
index 0000000..4fedde6
--- /dev/null
+++ b/commitlint.config.js
@@ -0,0 +1 @@
+module.exports = { extends: ['@commitlint/config-conventional'] }
diff --git a/debian/changelog b/debian/changelog
index 5210943..57e102a 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+golang-gopkg-vmihailenco-msgpack.v2 (5.3.5-1) UNRELEASED; urgency=low
+
+ * New upstream release.
+
+ -- Debian Janitor <janitor@jelmer.uk> Fri, 23 Jun 2023 20:51:39 -0000
+
golang-gopkg-vmihailenco-msgpack.v2 (4.3.1-2) unstable; urgency=medium
* Team upload.
diff --git a/decode.go b/decode.go
index 9586934..5df40e5 100644
--- a/decode.go
+++ b/decode.go
@@ -7,9 +7,15 @@ import (
"fmt"
"io"
"reflect"
+ "sync"
"time"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
+)
+
+const (
+ looseInterfaceDecodingFlag uint32 = 1 << iota
+ disallowUnknownFieldsFlag
)
const (
@@ -23,17 +29,37 @@ type bufReader interface {
io.ByteScanner
}
-func newBufReader(r io.Reader) bufReader {
- if br, ok := r.(bufReader); ok {
- return br
- }
- return bufio.NewReader(r)
+//------------------------------------------------------------------------------
+
+var decPool = sync.Pool{
+ New: func() interface{} {
+ return NewDecoder(nil)
+ },
+}
+
+func GetDecoder() *Decoder {
+ return decPool.Get().(*Decoder)
}
+func PutDecoder(dec *Decoder) {
+ dec.r = nil
+ dec.s = nil
+ decPool.Put(dec)
+}
+
+//------------------------------------------------------------------------------
+
// Unmarshal decodes the MessagePack-encoded data and stores the result
// in the value pointed to by v.
func Unmarshal(data []byte, v interface{}) error {
- return NewDecoder(bytes.NewReader(data)).Decode(v)
+ dec := GetDecoder()
+
+ dec.Reset(bytes.NewReader(data))
+ err := dec.Decode(v)
+
+ PutDecoder(dec)
+
+ return err
}
// A Decoder reads and decodes MessagePack values from an input stream.
@@ -42,59 +68,103 @@ type Decoder struct {
s io.ByteScanner
buf []byte
- extLen int
- rec []byte // accumulates read data if not nil
+ rec []byte // accumulates read data if not nil
- useLoose bool
- useJSONTag bool
-
- decodeMapFunc func(*Decoder) (interface{}, error)
+ dict []string
+ flags uint32
+ structTag string
+ mapDecoder func(*Decoder) (interface{}, error)
}
// NewDecoder returns a new decoder that reads from r.
//
// The decoder introduces its own buffering and may read data from r
-// beyond the MessagePack values requested. Buffering can be disabled
+// beyond the requested msgpack values. Buffering can be disabled
// by passing a reader that implements io.ByteScanner interface.
func NewDecoder(r io.Reader) *Decoder {
d := new(Decoder)
- d.resetReader(r)
+ d.Reset(r)
return d
}
-func (d *Decoder) SetDecodeMapFunc(fn func(*Decoder) (interface{}, error)) {
- d.decodeMapFunc = fn
+// Reset discards any buffered data, resets all state, and switches the buffered
+// reader to read from r.
+func (d *Decoder) Reset(r io.Reader) {
+ d.ResetDict(r, nil)
}
-// UseDecodeInterfaceLoose causes decoder to use DecodeInterfaceLoose
+// ResetDict is like Reset, but also resets the dict.
+func (d *Decoder) ResetDict(r io.Reader, dict []string) {
+ d.resetReader(r)
+ d.flags = 0
+ d.structTag = ""
+ d.mapDecoder = nil
+ d.dict = dict
+}
+
+func (d *Decoder) WithDict(dict []string, fn func(*Decoder) error) error {
+ oldDict := d.dict
+ d.dict = dict
+ err := fn(d)
+ d.dict = oldDict
+ return err
+}
+
+func (d *Decoder) resetReader(r io.Reader) {
+ if br, ok := r.(bufReader); ok {
+ d.r = br
+ d.s = br
+ } else {
+ br := bufio.NewReader(r)
+ d.r = br
+ d.s = br
+ }
+}
+
+func (d *Decoder) SetMapDecoder(fn func(*Decoder) (interface{}, error)) {
+ d.mapDecoder = fn
+}
+
+// UseLooseInterfaceDecoding causes decoder to use DecodeInterfaceLoose
// to decode msgpack value into Go interface{}.
-func (d *Decoder) UseDecodeInterfaceLoose(flag bool) *Decoder {
- d.useLoose = flag
- return d
+func (d *Decoder) UseLooseInterfaceDecoding(on bool) {
+ if on {
+ d.flags |= looseInterfaceDecodingFlag
+ } else {
+ d.flags &= ^looseInterfaceDecodingFlag
+ }
}
-// UseJSONTag causes the Decoder to use json struct tag as fallback option
+// SetCustomStructTag causes the decoder to use the supplied tag as a fallback option
// if there is no msgpack tag.
-func (d *Decoder) UseJSONTag(flag bool) *Decoder {
- d.useJSONTag = flag
- return d
+func (d *Decoder) SetCustomStructTag(tag string) {
+ d.structTag = tag
}
-// Buffered returns a reader of the data remaining in the Decoder's buffer.
-// The reader is valid until the next call to Decode.
-func (d *Decoder) Buffered() io.Reader {
- return d.r
+// DisallowUnknownFields causes the Decoder to return an error when the destination
+// is a struct and the input contains object keys which do not match any
+// non-ignored, exported fields in the destination.
+func (d *Decoder) DisallowUnknownFields(on bool) {
+ if on {
+ d.flags |= disallowUnknownFieldsFlag
+ } else {
+ d.flags &= ^disallowUnknownFieldsFlag
+ }
}
-func (d *Decoder) Reset(r io.Reader) error {
- d.resetReader(r)
- return nil
+// UseInternedStrings enables support for decoding interned strings.
+func (d *Decoder) UseInternedStrings(on bool) {
+ if on {
+ d.flags |= useInternedStringsFlag
+ } else {
+ d.flags &= ^useInternedStringsFlag
+ }
}
-func (d *Decoder) resetReader(r io.Reader) {
- reader := newBufReader(r)
- d.r = reader
- d.s = reader
+// Buffered returns a reader of the data remaining in the Decoder's buffer.
+// The reader is valid until the next call to Decode.
+func (d *Decoder) Buffered() io.Reader {
+ return d.r
}
//nolint:gocyclo
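In v5 the `useLoose` and `useJSONTag` booleans give way to a single `flags uint32` bit set declared with `1 << iota`. Each option method sets or clears one bit, and `ResetDict` zeroes the whole set. This matches the changelog note that `Reset*` functions also reset flags. The sketch below shows the technique with a few caveats:

- The `options` type is hypothetical.
- `useInternedStringsFlag` is defined outside the hunks shown, so its value of 4 here is an assumption.
- `&^=` (AND NOT) is equivalent to the `&= ^flag` spelling used above.

```go
package main

import "fmt"

const (
	looseInterfaceDecodingFlag uint32 = 1 << iota // 1
	disallowUnknownFieldsFlag                     // 2
	useInternedStringsFlag                        // 4 (assumed position)
)

// options is a hypothetical holder for the bit set kept in Decoder.flags.
type options struct{ flags uint32 }

// set turns one flag on or off without disturbing the others.
func (o *options) set(flag uint32, on bool) {
	if on {
		o.flags |= flag
	} else {
		o.flags &^= flag
	}
}

// has reports whether a flag is currently on.
func (o *options) has(flag uint32) bool { return o.flags&flag != 0 }

func main() {
	var o options
	o.set(looseInterfaceDecodingFlag, true)
	o.set(useInternedStringsFlag, true)
	o.set(looseInterfaceDecodingFlag, false)
	fmt.Println(o.flags, o.has(useInternedStringsFlag)) // 4 true
}
```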
@@ -199,12 +269,22 @@ func (d *Decoder) Decode(v interface{}) error {
return errors.New("msgpack: Decode(nil)")
}
if vv.Kind() != reflect.Ptr {
- return fmt.Errorf("msgpack: Decode(nonsettable %T)", v)
+ return fmt.Errorf("msgpack: Decode(non-pointer %T)", v)
+ }
+ if vv.IsNil() {
+ return fmt.Errorf("msgpack: Decode(non-settable %T)", v)
}
+
vv = vv.Elem()
- if !vv.IsValid() {
- return fmt.Errorf("msgpack: Decode(nonsettable %T)", v)
+ if vv.Kind() == reflect.Interface {
+ if !vv.IsNil() {
+ vv = vv.Elem()
+ if vv.Kind() != reflect.Ptr {
+ return fmt.Errorf("msgpack: Decode(non-pointer %s)", vv.Type().String())
+ }
+ }
}
+
return d.DecodeValue(vv)
}
@@ -218,7 +298,7 @@ func (d *Decoder) DecodeMulti(v ...interface{}) error {
}
func (d *Decoder) decodeInterfaceCond() (interface{}, error) {
- if d.useLoose {
+ if d.flags&looseInterfaceDecodingFlag != 0 {
return d.DecodeInterfaceLoose()
}
return d.DecodeInterface()
@@ -234,7 +314,7 @@ func (d *Decoder) DecodeNil() error {
if err != nil {
return err
}
- if c != codes.Nil {
+ if c != msgpcode.Nil {
return fmt.Errorf("msgpack: invalid code=%x decoding nil", c)
}
return nil
@@ -260,16 +340,27 @@ func (d *Decoder) DecodeBool() (bool, error) {
return d.bool(c)
}
-func (d *Decoder) bool(c codes.Code) (bool, error) {
- if c == codes.False {
+func (d *Decoder) bool(c byte) (bool, error) {
+ if c == msgpcode.Nil {
return false, nil
}
- if c == codes.True {
+ if c == msgpcode.False {
+ return false, nil
+ }
+ if c == msgpcode.True {
return true, nil
}
return false, fmt.Errorf("msgpack: invalid code=%x decoding bool", c)
}
+func (d *Decoder) DecodeDuration() (time.Duration, error) {
+ n, err := d.DecodeInt64()
+ if err != nil {
+ return 0, err
+ }
+ return time.Duration(n), nil
+}
+
// DecodeInterface decodes value into interface. It returns following types:
// - nil,
// - bool,
@@ -290,63 +381,63 @@ func (d *Decoder) DecodeInterface() (interface{}, error) {
return nil, err
}
- if codes.IsFixedNum(c) {
+ if msgpcode.IsFixedNum(c) {
return int8(c), nil
}
- if codes.IsFixedMap(c) {
+ if msgpcode.IsFixedMap(c) {
err = d.s.UnreadByte()
if err != nil {
return nil, err
}
- return d.DecodeMap()
+ return d.decodeMapDefault()
}
- if codes.IsFixedArray(c) {
+ if msgpcode.IsFixedArray(c) {
return d.decodeSlice(c)
}
- if codes.IsFixedString(c) {
+ if msgpcode.IsFixedString(c) {
return d.string(c)
}
switch c {
- case codes.Nil:
+ case msgpcode.Nil:
return nil, nil
- case codes.False, codes.True:
+ case msgpcode.False, msgpcode.True:
return d.bool(c)
- case codes.Float:
+ case msgpcode.Float:
return d.float32(c)
- case codes.Double:
+ case msgpcode.Double:
return d.float64(c)
- case codes.Uint8:
+ case msgpcode.Uint8:
return d.uint8()
- case codes.Uint16:
+ case msgpcode.Uint16:
return d.uint16()
- case codes.Uint32:
+ case msgpcode.Uint32:
return d.uint32()
- case codes.Uint64:
+ case msgpcode.Uint64:
return d.uint64()
- case codes.Int8:
+ case msgpcode.Int8:
return d.int8()
- case codes.Int16:
+ case msgpcode.Int16:
return d.int16()
- case codes.Int32:
+ case msgpcode.Int32:
return d.int32()
- case codes.Int64:
+ case msgpcode.Int64:
return d.int64()
- case codes.Bin8, codes.Bin16, codes.Bin32:
+ case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
return d.bytes(c, nil)
- case codes.Str8, codes.Str16, codes.Str32:
+ case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
return d.string(c)
- case codes.Array16, codes.Array32:
+ case msgpcode.Array16, msgpcode.Array32:
return d.decodeSlice(c)
- case codes.Map16, codes.Map32:
+ case msgpcode.Map16, msgpcode.Map32:
err = d.s.UnreadByte()
if err != nil {
return nil, err
}
- return d.DecodeMap()
- case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
- codes.Ext8, codes.Ext16, codes.Ext32:
- return d.extInterface(c)
+ return d.decodeMapDefault()
+ case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
+ msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
+ return d.decodeInterfaceExt(c)
}
return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
@@ -356,55 +447,55 @@ func (d *Decoder) DecodeInterface() (interface{}, error) {
// - int8, int16, and int32 are converted to int64,
// - uint8, uint16, and uint32 are converted to uint64,
// - float32 is converted to float64.
+// - []byte is converted to string.
func (d *Decoder) DecodeInterfaceLoose() (interface{}, error) {
c, err := d.readCode()
if err != nil {
return nil, err
}
- if codes.IsFixedNum(c) {
+ if msgpcode.IsFixedNum(c) {
return int64(int8(c)), nil
}
- if codes.IsFixedMap(c) {
+ if msgpcode.IsFixedMap(c) {
err = d.s.UnreadByte()
if err != nil {
return nil, err
}
- return d.DecodeMap()
+ return d.decodeMapDefault()
}
- if codes.IsFixedArray(c) {
+ if msgpcode.IsFixedArray(c) {
return d.decodeSlice(c)
}
- if codes.IsFixedString(c) {
+ if msgpcode.IsFixedString(c) {
return d.string(c)
}
switch c {
- case codes.Nil:
+ case msgpcode.Nil:
return nil, nil
- case codes.False, codes.True:
+ case msgpcode.False, msgpcode.True:
return d.bool(c)
- case codes.Float, codes.Double:
+ case msgpcode.Float, msgpcode.Double:
return d.float64(c)
- case codes.Uint8, codes.Uint16, codes.Uint32, codes.Uint64:
+ case msgpcode.Uint8, msgpcode.Uint16, msgpcode.Uint32, msgpcode.Uint64:
return d.uint(c)
- case codes.Int8, codes.Int16, codes.Int32, codes.Int64:
+ case msgpcode.Int8, msgpcode.Int16, msgpcode.Int32, msgpcode.Int64:
return d.int(c)
- case codes.Bin8, codes.Bin16, codes.Bin32:
- return d.bytes(c, nil)
- case codes.Str8, codes.Str16, codes.Str32:
+ case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32,
+ msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
return d.string(c)
- case codes.Array16, codes.Array32:
+ case msgpcode.Array16, msgpcode.Array32:
return d.decodeSlice(c)
- case codes.Map16, codes.Map32:
+ case msgpcode.Map16, msgpcode.Map32:
err = d.s.UnreadByte()
if err != nil {
return nil, err
}
- return d.DecodeMap()
- case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
- codes.Ext8, codes.Ext16, codes.Ext32:
- return d.extInterface(c)
+ return d.decodeMapDefault()
+ case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
+ msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
+ return d.decodeInterfaceExt(c)
}
return 0, fmt.Errorf("msgpack: unknown code %x decoding interface{}", c)
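`DecodeInterfaceLoose` widens each number to the largest type in its family. New in v5, it also returns MessagePack bin data as a `string` rather than `[]byte`. The sketch below applies those documented rules to a value that a strict decode has already produced. `loosen` is a hypothetical helper, not part of msgpack/v5. Types the rules do not mention, such as `bool`, `int64`, or `string`, pass through unchanged.

```go
package main

import "fmt"

// loosen is a hypothetical helper applying the widening rules documented for
// DecodeInterfaceLoose to a value produced by a strict decode.
func loosen(v interface{}) interface{} {
	switch x := v.(type) {
	case int8:
		return int64(x)
	case int16:
		return int64(x)
	case int32:
		return int64(x)
	case uint8:
		return uint64(x)
	case uint16:
		return uint64(x)
	case uint32:
		return uint64(x)
	case float32:
		return float64(x)
	case []byte:
		return string(x) // new in v5: bin is surfaced as string
	default:
		return v
	}
}

func main() {
	for _, v := range []interface{}{int8(-1), uint16(300), float32(0.5), []byte("hi"), true} {
		l := loosen(v)
		fmt.Printf("%T(%v) -> %T(%v)\n", v, v, l, l)
	}
}
```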
@@ -417,63 +508,78 @@ func (d *Decoder) Skip() error {
return err
}
- if codes.IsFixedNum(c) {
+ if msgpcode.IsFixedNum(c) {
return nil
}
- if codes.IsFixedMap(c) {
+ if msgpcode.IsFixedMap(c) {
return d.skipMap(c)
}
- if codes.IsFixedArray(c) {
+ if msgpcode.IsFixedArray(c) {
return d.skipSlice(c)
}
- if codes.IsFixedString(c) {
+ if msgpcode.IsFixedString(c) {
return d.skipBytes(c)
}
switch c {
- case codes.Nil, codes.False, codes.True:
+ case msgpcode.Nil, msgpcode.False, msgpcode.True:
return nil
- case codes.Uint8, codes.Int8:
+ case msgpcode.Uint8, msgpcode.Int8:
return d.skipN(1)
- case codes.Uint16, codes.Int16:
+ case msgpcode.Uint16, msgpcode.Int16:
return d.skipN(2)
- case codes.Uint32, codes.Int32, codes.Float:
+ case msgpcode.Uint32, msgpcode.Int32, msgpcode.Float:
return d.skipN(4)
- case codes.Uint64, codes.Int64, codes.Double:
+ case msgpcode.Uint64, msgpcode.Int64, msgpcode.Double:
return d.skipN(8)
- case codes.Bin8, codes.Bin16, codes.Bin32:
+ case msgpcode.Bin8, msgpcode.Bin16, msgpcode.Bin32:
return d.skipBytes(c)
- case codes.Str8, codes.Str16, codes.Str32:
+ case msgpcode.Str8, msgpcode.Str16, msgpcode.Str32:
return d.skipBytes(c)
- case codes.Array16, codes.Array32:
+ case msgpcode.Array16, msgpcode.Array32:
return d.skipSlice(c)
- case codes.Map16, codes.Map32:
+ case msgpcode.Map16, msgpcode.Map32:
return d.skipMap(c)
- case codes.FixExt1, codes.FixExt2, codes.FixExt4, codes.FixExt8, codes.FixExt16,
- codes.Ext8, codes.Ext16, codes.Ext32:
+ case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4, msgpcode.FixExt8, msgpcode.FixExt16,
+ msgpcode.Ext8, msgpcode.Ext16, msgpcode.Ext32:
return d.skipExt(c)
}
return fmt.Errorf("msgpack: unknown code %x", c)
}
+func (d *Decoder) DecodeRaw() (RawMessage, error) {
+ d.rec = make([]byte, 0)
+ if err := d.Skip(); err != nil {
+ return nil, err
+ }
+ msg := RawMessage(d.rec)
+ d.rec = nil
+ return msg, nil
+}
+
// PeekCode returns the next MessagePack code without advancing the reader.
-// Subpackage msgpack/codes contains list of available codes.
-func (d *Decoder) PeekCode() (codes.Code, error) {
+// Subpackage msgpack/codes defines the list of available msgpcode.
+func (d *Decoder) PeekCode() (byte, error) {
c, err := d.s.ReadByte()
if err != nil {
return 0, err
}
- return codes.Code(c), d.s.UnreadByte()
+ return c, d.s.UnreadByte()
+}
+
+// ReadFull reads exactly len(buf) bytes into the buf.
+func (d *Decoder) ReadFull(buf []byte) error {
+ _, err := readN(d.r, buf, len(buf))
+ return err
}
func (d *Decoder) hasNilCode() bool {
code, err := d.PeekCode()
- return err == nil && code == codes.Nil
+ return err == nil && code == msgpcode.Nil
}
-func (d *Decoder) readCode() (codes.Code, error) {
- d.extLen = 0
+func (d *Decoder) readCode() (byte, error) {
c, err := d.s.ReadByte()
if err != nil {
return 0, err
@@ -481,7 +587,7 @@ func (d *Decoder) readCode() (codes.Code, error) {
if d.rec != nil {
d.rec = append(d.rec, c)
}
- return codes.Code(c), nil
+ return c, nil
}
func (d *Decoder) readFull(b []byte) error {
@@ -490,7 +596,6 @@ func (d *Decoder) readFull(b []byte) error {
return err
}
if d.rec != nil {
- //TODO: read directly into d.rec?
d.rec = append(d.rec, b...)
}
return nil
@@ -503,7 +608,7 @@ func (d *Decoder) readN(n int) ([]byte, error) {
return nil, err
}
if d.rec != nil {
- //TODO: read directly into d.rec?
+ // TODO: read directly into d.rec?
d.rec = append(d.rec, d.buf...)
}
return d.buf, nil
@@ -533,10 +638,7 @@ func readN(r io.Reader, b []byte, n int) ([]byte, error) {
var pos int
for {
- alloc := n - len(b)
- if alloc > bytesAllocLimit {
- alloc = bytesAllocLimit
- }
+ alloc := min(n-len(b), bytesAllocLimit)
b = append(b, make([]byte, alloc)...)
_, err := io.ReadFull(r, b[pos:])
diff --git a/decode_map.go b/decode_map.go
index 2592b92..52e0526 100644
--- a/decode_map.go
+++ b/decode_map.go
@@ -5,9 +5,11 @@ import (
"fmt"
"reflect"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
+var errArrayStruct = errors.New("msgpack: number of fields in array-encoded struct has changed")
+
var (
mapStringStringPtrType = reflect.TypeOf((*map[string]string)(nil))
mapStringStringType = mapStringStringPtrType.Elem()
@@ -19,13 +21,13 @@ var (
)
func decodeMapValue(d *Decoder, v reflect.Value) error {
- size, err := d.DecodeMapLen()
+ n, err := d.DecodeMapLen()
if err != nil {
return err
}
typ := v.Type()
- if size == -1 {
+ if n == -1 {
v.Set(reflect.Zero(typ))
return nil
}
@@ -33,33 +35,18 @@ func decodeMapValue(d *Decoder, v reflect.Value) error {
if v.IsNil() {
v.Set(reflect.MakeMap(typ))
}
- if size == 0 {
+ if n == 0 {
return nil
}
- return decodeMapValueSize(d, v, size)
+ return d.decodeTypedMapValue(v, n)
}
-func decodeMapValueSize(d *Decoder, v reflect.Value, size int) error {
- typ := v.Type()
- keyType := typ.Key()
- valueType := typ.Elem()
-
- for i := 0; i < size; i++ {
- mk := reflect.New(keyType).Elem()
- if err := d.DecodeValue(mk); err != nil {
- return err
- }
-
- mv := reflect.New(valueType).Elem()
- if err := d.DecodeValue(mv); err != nil {
- return err
- }
-
- v.SetMapIndex(mk, mv)
+func (d *Decoder) decodeMapDefault() (interface{}, error) {
+ if d.mapDecoder != nil {
+ return d.mapDecoder(d)
}
-
- return nil
+ return d.DecodeMap()
}
// DecodeMapLen decodes map length. Length is -1 when map is nil.
@@ -69,7 +56,7 @@ func (d *Decoder) DecodeMapLen() (int, error) {
return 0, err
}
- if codes.IsExt(c) {
+ if msgpcode.IsExt(c) {
if err = d.skipExtHeader(c); err != nil {
return 0, err
}
@@ -82,37 +69,22 @@ func (d *Decoder) DecodeMapLen() (int, error) {
return d.mapLen(c)
}
-func (d *Decoder) mapLen(c codes.Code) (int, error) {
- size, err := d._mapLen(c)
- err = expandInvalidCodeMapLenError(c, err)
- return size, err
-}
-
-func (d *Decoder) _mapLen(c codes.Code) (int, error) {
- if c == codes.Nil {
+func (d *Decoder) mapLen(c byte) (int, error) {
+ if c == msgpcode.Nil {
return -1, nil
}
- if c >= codes.FixedMapLow && c <= codes.FixedMapHigh {
- return int(c & codes.FixedMapMask), nil
+ if c >= msgpcode.FixedMapLow && c <= msgpcode.FixedMapHigh {
+ return int(c & msgpcode.FixedMapMask), nil
}
- if c == codes.Map16 {
+ if c == msgpcode.Map16 {
size, err := d.uint16()
return int(size), err
}
- if c == codes.Map32 {
+ if c == msgpcode.Map32 {
size, err := d.uint32()
return int(size), err
}
- return 0, errInvalidCode
-}
-
-var errInvalidCode = errors.New("invalid code")
-
-func expandInvalidCodeMapLenError(c codes.Code, err error) error {
- if err == errInvalidCode {
- return fmt.Errorf("msgpack: invalid code=%x decoding map length", c)
- }
- return err
+ return 0, unexpectedCodeError{code: c, hint: "map length"}
}
func decodeMapStringStringValue(d *Decoder, v reflect.Value) error {
@@ -157,59 +129,79 @@ func decodeMapStringInterfaceValue(d *Decoder, v reflect.Value) error {
}
func (d *Decoder) decodeMapStringInterfacePtr(ptr *map[string]interface{}) error {
- n, err := d.DecodeMapLen()
+ m, err := d.DecodeMap()
if err != nil {
return err
}
- if n == -1 {
- *ptr = nil
- return nil
+ *ptr = m
+ return nil
+}
+
+func (d *Decoder) DecodeMap() (map[string]interface{}, error) {
+ n, err := d.DecodeMapLen()
+ if err != nil {
+ return nil, err
}
- m := *ptr
- if m == nil {
- *ptr = make(map[string]interface{}, min(n, maxMapSize))
- m = *ptr
+ if n == -1 {
+ return nil, nil
}
+ m := make(map[string]interface{}, min(n, maxMapSize))
+
for i := 0; i < n; i++ {
mk, err := d.DecodeString()
if err != nil {
- return err
+ return nil, err
}
mv, err := d.decodeInterfaceCond()
if err != nil {
- return err
+ return nil, err
}
m[mk] = mv
}
- return nil
+ return m, nil
}
-func (d *Decoder) DecodeMap() (interface{}, error) {
- if d.decodeMapFunc != nil {
- return d.decodeMapFunc(d)
- }
-
- size, err := d.DecodeMapLen()
+func (d *Decoder) DecodeUntypedMap() (map[interface{}]interface{}, error) {
+ n, err := d.DecodeMapLen()
if err != nil {
return nil, err
}
- if size == -1 {
+
+ if n == -1 {
return nil, nil
}
- if size == 0 {
- return make(map[string]interface{}), nil
+
+ m := make(map[interface{}]interface{}, min(n, maxMapSize))
+
+ for i := 0; i < n; i++ {
+ mk, err := d.decodeInterfaceCond()
+ if err != nil {
+ return nil, err
+ }
+
+ mv, err := d.decodeInterfaceCond()
+ if err != nil {
+ return nil, err
+ }
+
+ m[mk] = mv
}
- code, err := d.PeekCode()
+ return m, nil
+}
+
+// DecodeTypedMap decodes a typed map. Typed map is a map that has a fixed type for keys and values.
+// Key and value types may be different.
+func (d *Decoder) DecodeTypedMap() (interface{}, error) {
+ n, err := d.DecodeMapLen()
if err != nil {
return nil, err
}
-
- if codes.IsString(code) || codes.IsBin(code) {
- return d.decodeMapStringInterfaceSize(size)
+ if n <= 0 {
+ return nil, nil
}
key, err := d.decodeInterfaceCond()
@@ -225,37 +217,45 @@ func (d *Decoder) DecodeMap() (interface{}, error) {
keyType := reflect.TypeOf(key)
valueType := reflect.TypeOf(value)
+ if !keyType.Comparable() {
+ return nil, fmt.Errorf("msgpack: unsupported map key: %s", keyType.String())
+ }
+
mapType := reflect.MapOf(keyType, valueType)
mapValue := reflect.MakeMap(mapType)
-
mapValue.SetMapIndex(reflect.ValueOf(key), reflect.ValueOf(value))
- size--
- err = decodeMapValueSize(d, mapValue, size)
- if err != nil {
+ n--
+ if err := d.decodeTypedMapValue(mapValue, n); err != nil {
return nil, err
}
return mapValue.Interface(), nil
}
-func (d *Decoder) decodeMapStringInterfaceSize(size int) (map[string]interface{}, error) {
- m := make(map[string]interface{}, min(size, maxMapSize))
- for i := 0; i < size; i++ {
- mk, err := d.DecodeString()
- if err != nil {
- return nil, err
+func (d *Decoder) decodeTypedMapValue(v reflect.Value, n int) error {
+ typ := v.Type()
+ keyType := typ.Key()
+ valueType := typ.Elem()
+
+ for i := 0; i < n; i++ {
+ mk := reflect.New(keyType).Elem()
+ if err := d.DecodeValue(mk); err != nil {
+ return err
}
- mv, err := d.decodeInterfaceCond()
- if err != nil {
- return nil, err
+
+ mv := reflect.New(valueType).Elem()
+ if err := d.DecodeValue(mv); err != nil {
+ return err
}
- m[mk] = mv
+
+ v.SetMapIndex(mk, mv)
}
- return m, nil
+
+ return nil
}
-func (d *Decoder) skipMap(c codes.Code) error {
+func (d *Decoder) skipMap(c byte) error {
n, err := d.mapLen(c)
if err != nil {
return err
@@ -277,63 +277,61 @@ func decodeStructValue(d *Decoder, v reflect.Value) error {
return err
}
- var isArray bool
+ n, err := d.mapLen(c)
+ if err == nil {
+ return d.decodeStruct(v, n)
+ }
- n, err := d._mapLen(c)
- if err != nil {
- var err2 error
- n, err2 = d.arrayLen(c)
- if err2 != nil {
- return expandInvalidCodeMapLenError(c, err)
- }
- isArray = true
+ var err2 error
+ n, err2 = d.arrayLen(c)
+ if err2 != nil {
+ return err
}
- if n == -1 {
- if err = mustSet(v); err != nil {
- return err
- }
+
+ if n <= 0 {
v.Set(reflect.Zero(v.Type()))
return nil
}
- var fields *fields
- if d.useJSONTag {
- fields = jsonStructs.Fields(v.Type())
- } else {
- fields = structs.Fields(v.Type())
+ fields := structs.Fields(v.Type(), d.structTag)
+ if n != len(fields.List) {
+ return errArrayStruct
}
- if isArray {
- for i, f := range fields.List {
- if i >= n {
- break
- }
- if err := f.DecodeValue(d, v); err != nil {
- return err
- }
- }
- // Skip extra values.
- for i := len(fields.List); i < n; i++ {
- if err := d.Skip(); err != nil {
- return err
- }
+ for _, f := range fields.List {
+ if err := f.DecodeValue(d, v); err != nil {
+ return err
}
+ }
+
+ return nil
+}
+
+func (d *Decoder) decodeStruct(v reflect.Value, n int) error {
+ if n == -1 {
+ v.Set(reflect.Zero(v.Type()))
return nil
}
+ fields := structs.Fields(v.Type(), d.structTag)
for i := 0; i < n; i++ {
- name, err := d.DecodeString()
+ name, err := d.decodeStringTemp()
if err != nil {
return err
}
+
if f := fields.Map[name]; f != nil {
if err := f.DecodeValue(d, v); err != nil {
return err
}
- } else {
- if err := d.Skip(); err != nil {
- return err
- }
+ continue
+ }
+
+ if d.flags&disallowUnknownFieldsFlag != 0 {
+ return fmt.Errorf("msgpack: unknown field %q", name)
+ }
+ if err := d.Skip(); err != nil {
+ return err
}
}
diff --git a/decode_number.go b/decode_number.go
index f6b9151..45d6a74 100644
--- a/decode_number.go
+++ b/decode_number.go
@@ -5,7 +5,7 @@ import (
"math"
"reflect"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
func (d *Decoder) skipN(n int) error {
@@ -18,7 +18,7 @@ func (d *Decoder) uint8() (uint8, error) {
if err != nil {
return 0, err
}
- return uint8(c), nil
+ return c, nil
}
func (d *Decoder) int8() (int8, error) {
@@ -87,33 +87,33 @@ func (d *Decoder) DecodeUint64() (uint64, error) {
return d.uint(c)
}
-func (d *Decoder) uint(c codes.Code) (uint64, error) {
- if c == codes.Nil {
+func (d *Decoder) uint(c byte) (uint64, error) {
+ if c == msgpcode.Nil {
return 0, nil
}
- if codes.IsFixedNum(c) {
+ if msgpcode.IsFixedNum(c) {
return uint64(int8(c)), nil
}
switch c {
- case codes.Uint8:
+ case msgpcode.Uint8:
n, err := d.uint8()
return uint64(n), err
- case codes.Int8:
+ case msgpcode.Int8:
n, err := d.int8()
return uint64(n), err
- case codes.Uint16:
+ case msgpcode.Uint16:
n, err := d.uint16()
return uint64(n), err
- case codes.Int16:
+ case msgpcode.Int16:
n, err := d.int16()
return uint64(n), err
- case codes.Uint32:
+ case msgpcode.Uint32:
n, err := d.uint32()
return uint64(n), err
- case codes.Int32:
+ case msgpcode.Int32:
n, err := d.int32()
return uint64(n), err
- case codes.Uint64, codes.Int64:
+ case msgpcode.Uint64, msgpcode.Int64:
return d.uint64()
}
return 0, fmt.Errorf("msgpack: invalid code=%x decoding uint64", c)
@@ -129,33 +129,33 @@ func (d *Decoder) DecodeInt64() (int64, error) {
return d.int(c)
}
-func (d *Decoder) int(c codes.Code) (int64, error) {
- if c == codes.Nil {
+func (d *Decoder) int(c byte) (int64, error) {
+ if c == msgpcode.Nil {
return 0, nil
}
- if codes.IsFixedNum(c) {
+ if msgpcode.IsFixedNum(c) {
return int64(int8(c)), nil
}
switch c {
- case codes.Uint8:
+ case msgpcode.Uint8:
n, err := d.uint8()
return int64(n), err
- case codes.Int8:
+ case msgpcode.Int8:
n, err := d.uint8()
return int64(int8(n)), err
- case codes.Uint16:
+ case msgpcode.Uint16:
n, err := d.uint16()
return int64(n), err
- case codes.Int16:
+ case msgpcode.Int16:
n, err := d.uint16()
return int64(int16(n)), err
- case codes.Uint32:
+ case msgpcode.Uint32:
n, err := d.uint32()
return int64(n), err
- case codes.Int32:
+ case msgpcode.Int32:
n, err := d.uint32()
return int64(int32(n)), err
- case codes.Uint64, codes.Int64:
+ case msgpcode.Uint64, msgpcode.Int64:
n, err := d.uint64()
return int64(n), err
}
@@ -170,8 +170,8 @@ func (d *Decoder) DecodeFloat32() (float32, error) {
return d.float32(c)
}
-func (d *Decoder) float32(c codes.Code) (float32, error) {
- if c == codes.Float {
+func (d *Decoder) float32(c byte) (float32, error) {
+ if c == msgpcode.Float {
n, err := d.uint32()
if err != nil {
return 0, err
@@ -195,15 +195,15 @@ func (d *Decoder) DecodeFloat64() (float64, error) {
return d.float64(c)
}
-func (d *Decoder) float64(c codes.Code) (float64, error) {
+func (d *Decoder) float64(c byte) (float64, error) {
switch c {
- case codes.Float:
+ case msgpcode.Float:
n, err := d.float32(c)
if err != nil {
return 0, err
}
return float64(n), nil
- case codes.Double:
+ case msgpcode.Double:
n, err := d.uint64()
if err != nil {
return 0, err
@@ -263,9 +263,6 @@ func decodeFloat32Value(d *Decoder, v reflect.Value) error {
if err != nil {
return err
}
- if err = mustSet(v); err != nil {
- return err
- }
v.SetFloat(float64(f))
return nil
}
@@ -275,9 +272,6 @@ func decodeFloat64Value(d *Decoder, v reflect.Value) error {
if err != nil {
return err
}
- if err = mustSet(v); err != nil {
- return err
- }
v.SetFloat(f)
return nil
}
@@ -287,9 +281,6 @@ func decodeInt64Value(d *Decoder, v reflect.Value) error {
if err != nil {
return err
}
- if err = mustSet(v); err != nil {
- return err
- }
v.SetInt(n)
return nil
}
@@ -299,9 +290,6 @@ func decodeUint64Value(d *Decoder, v reflect.Value) error {
if err != nil {
return err
}
- if err = mustSet(v); err != nil {
- return err
- }
v.SetUint(n)
return nil
}
diff --git a/decode_query.go b/decode_query.go
index 80cd80e..c302ed1 100644
--- a/decode_query.go
+++ b/decode_query.go
@@ -5,7 +5,7 @@ import (
"strconv"
"strings"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
type queryResult struct {
@@ -57,9 +57,9 @@ func (d *Decoder) query(q *queryResult) error {
}
switch {
- case code == codes.Map16 || code == codes.Map32 || codes.IsFixedMap(code):
+ case code == msgpcode.Map16 || code == msgpcode.Map32 || msgpcode.IsFixedMap(code):
err = d.queryMapKey(q)
- case code == codes.Array16 || code == codes.Array32 || codes.IsFixedArray(code):
+ case code == msgpcode.Array16 || code == msgpcode.Array32 || msgpcode.IsFixedArray(code):
err = d.queryArrayIndex(q)
default:
err = fmt.Errorf("msgpack: unsupported code=%x decoding key=%q", code, q.key)
@@ -77,12 +77,12 @@ func (d *Decoder) queryMapKey(q *queryResult) error {
}
for i := 0; i < n; i++ {
- k, err := d.bytesNoCopy()
+ key, err := d.decodeStringTemp()
if err != nil {
return err
}
- if string(k) == q.key {
+ if key == q.key {
if err := d.query(q); err != nil {
return err
}
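`query` dispatches on the container type. Map codes go to `queryMapKey`, which now compares keys via `decodeStringTemp` instead of `bytesNoCopy`, and array codes go to `queryArrayIndex`. The sketch below shows the same dispatch over already-decoded `map[string]interface{}` and `[]interface{}` values using a dot-separated path. It is a simplification: the library walks the encoded stream without materializing it, and any other path features it supports are omitted. `query` here is a hypothetical function.

```go
package main

import (
	"fmt"
	"strconv"
	"strings"
)

// query follows a dot-separated path such as "users.0.name": map segments
// are looked up by key, array segments by integer index.
func query(v interface{}, path string) (interface{}, error) {
	if path == "" {
		return v, nil
	}
	key, rest, _ := strings.Cut(path, ".")
	switch node := v.(type) {
	case map[string]interface{}:
		child, ok := node[key]
		if !ok {
			return nil, fmt.Errorf("key %q not found", key)
		}
		return query(child, rest)
	case []interface{}:
		i, err := strconv.Atoi(key)
		if err != nil {
			return nil, fmt.Errorf("bad array index %q: %w", key, err)
		}
		if i < 0 || i >= len(node) {
			return nil, fmt.Errorf("array index %d out of range", i)
		}
		return query(node[i], rest)
	default:
		return nil, fmt.Errorf("unsupported value %T decoding key=%q", v, key)
	}
}

func main() {
	data := map[string]interface{}{
		"users": []interface{}{
			map[string]interface{}{"name": "ann"},
			map[string]interface{}{"name": "bob"},
		},
	}
	name, err := query(data, "users.1.name")
	fmt.Println(name, err)
}
```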
diff --git a/decode_slice.go b/decode_slice.go
index adf17ae..db6f7c5 100644
--- a/decode_slice.go
+++ b/decode_slice.go
@@ -4,7 +4,7 @@ import (
"fmt"
"reflect"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
var sliceStringPtrType = reflect.TypeOf((*[]string)(nil))
@@ -18,17 +18,17 @@ func (d *Decoder) DecodeArrayLen() (int, error) {
return d.arrayLen(c)
}
-func (d *Decoder) arrayLen(c codes.Code) (int, error) {
- if c == codes.Nil {
+func (d *Decoder) arrayLen(c byte) (int, error) {
+ if c == msgpcode.Nil {
return -1, nil
- } else if c >= codes.FixedArrayLow && c <= codes.FixedArrayHigh {
- return int(c & codes.FixedArrayMask), nil
+ } else if c >= msgpcode.FixedArrayLow && c <= msgpcode.FixedArrayHigh {
+ return int(c & msgpcode.FixedArrayMask), nil
}
switch c {
- case codes.Array16:
+ case msgpcode.Array16:
n, err := d.uint16()
return int(n), err
- case codes.Array32:
+ case msgpcode.Array32:
n, err := d.uint32()
return int(n), err
}
@@ -154,7 +154,7 @@ func (d *Decoder) DecodeSlice() ([]interface{}, error) {
return d.decodeSlice(c)
}
-func (d *Decoder) decodeSlice(c codes.Code) ([]interface{}, error) {
+func (d *Decoder) decodeSlice(c byte) ([]interface{}, error) {
n, err := d.arrayLen(c)
if err != nil {
return nil, err
@@ -175,7 +175,7 @@ func (d *Decoder) decodeSlice(c codes.Code) ([]interface{}, error) {
return s, nil
}
-func (d *Decoder) skipSlice(c codes.Code) error {
+func (d *Decoder) skipSlice(c byte) error {
n, err := d.arrayLen(c)
if err != nil {
return err
diff --git a/decode_string.go b/decode_string.go
index 20730c2..e837e08 100644
--- a/decode_string.go
+++ b/decode_string.go
@@ -4,30 +4,38 @@ import (
"fmt"
"reflect"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
-func (d *Decoder) bytesLen(c codes.Code) (int, error) {
- if c == codes.Nil {
+func (d *Decoder) bytesLen(c byte) (int, error) {
+ if c == msgpcode.Nil {
return -1, nil
- } else if codes.IsFixedString(c) {
- return int(c & codes.FixedStrMask), nil
}
+
+ if msgpcode.IsFixedString(c) {
+ return int(c & msgpcode.FixedStrMask), nil
+ }
+
switch c {
- case codes.Str8, codes.Bin8:
+ case msgpcode.Str8, msgpcode.Bin8:
n, err := d.uint8()
return int(n), err
- case codes.Str16, codes.Bin16:
+ case msgpcode.Str16, msgpcode.Bin16:
n, err := d.uint16()
return int(n), err
- case codes.Str32, codes.Bin32:
+ case msgpcode.Str32, msgpcode.Bin32:
n, err := d.uint32()
return int(n), err
}
- return 0, fmt.Errorf("msgpack: invalid code=%x decoding bytes length", c)
+
+ return 0, fmt.Errorf("msgpack: invalid code=%x decoding string/bytes length", c)
}
func (d *Decoder) DecodeString() (string, error) {
+ if intern := d.flags&useInternedStringsFlag != 0; intern || len(d.dict) > 0 {
+ return d.decodeInternedString(intern)
+ }
+
c, err := d.readCode()
if err != nil {
return "", err
@@ -35,11 +43,15 @@ func (d *Decoder) DecodeString() (string, error) {
return d.string(c)
}
-func (d *Decoder) string(c codes.Code) (string, error) {
+func (d *Decoder) string(c byte) (string, error) {
n, err := d.bytesLen(c)
if err != nil {
return "", err
}
+ return d.stringWithLen(n)
+}
+
+func (d *Decoder) stringWithLen(n int) (string, error) {
if n <= 0 {
return "", nil
}
@@ -52,9 +64,6 @@ func decodeStringValue(d *Decoder, v reflect.Value) error {
if err != nil {
return err
}
- if err = mustSet(v); err != nil {
- return err
- }
v.SetString(s)
return nil
}
@@ -75,7 +84,7 @@ func (d *Decoder) DecodeBytes() ([]byte, error) {
return d.bytes(c, nil)
}
-func (d *Decoder) bytes(c codes.Code, b []byte) ([]byte, error) {
+func (d *Decoder) bytes(c byte, b []byte) ([]byte, error) {
n, err := d.bytesLen(c)
if err != nil {
return nil, err
@@ -86,19 +95,30 @@ func (d *Decoder) bytes(c codes.Code, b []byte) ([]byte, error) {
return readN(d.r, b, n)
}
-func (d *Decoder) bytesNoCopy() ([]byte, error) {
+func (d *Decoder) decodeStringTemp() (string, error) {
+ if intern := d.flags&useInternedStringsFlag != 0; intern || len(d.dict) > 0 {
+ return d.decodeInternedString(intern)
+ }
+
c, err := d.readCode()
if err != nil {
- return nil, err
+ return "", err
}
+
n, err := d.bytesLen(c)
if err != nil {
- return nil, err
+ return "", err
}
if n == -1 {
- return nil, nil
+ return "", nil
}
- return d.readN(n)
+
+ b, err := d.readN(n)
+ if err != nil {
+ return "", err
+ }
+
+ return bytesToString(b), nil
}
func (d *Decoder) decodeBytesPtr(ptr *[]byte) error {
@@ -109,7 +129,7 @@ func (d *Decoder) decodeBytesPtr(ptr *[]byte) error {
return d.bytesPtr(c, ptr)
}
-func (d *Decoder) bytesPtr(c codes.Code, ptr *[]byte) error {
+func (d *Decoder) bytesPtr(c byte, ptr *[]byte) error {
n, err := d.bytesLen(c)
if err != nil {
return err
@@ -123,7 +143,7 @@ func (d *Decoder) bytesPtr(c codes.Code, ptr *[]byte) error {
return err
}
-func (d *Decoder) skipBytes(c codes.Code) error {
+func (d *Decoder) skipBytes(c byte) error {
n, err := d.bytesLen(c)
if err != nil {
return err
@@ -145,9 +165,6 @@ func decodeBytesValue(d *Decoder, v reflect.Value) error {
return err
}
- if err = mustSet(v); err != nil {
- return err
- }
v.SetBytes(b)
return nil
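`decodeStringTemp` replaces `bytesNoCopy` and returns `bytesToString(b)` over the slice from `d.readN`, so the string may share memory with a buffer that a later read reuses. That is fine for short-lived uses such as the `fields.Map[name]` lookup in `decodeStruct`, but such a string should not be stored. The body of `bytesToString` is not part of this diff. The sketch below assumes a zero-copy conversion built on `unsafe.String` (Go 1.20+). `fields` and `lookupField` are hypothetical stand-ins for the decoder's struct-field index and key handling.

```go
package main

import (
	"fmt"
	"unsafe"
)

// bytesToString returns a string that shares b's memory (no copy). Per the
// unsafe.String contract, b must not be modified while the string is in use.
func bytesToString(b []byte) string {
	if len(b) == 0 {
		return ""
	}
	return unsafe.String(&b[0], len(b))
}

// fields stands in for the struct field index that decodeStruct consults.
var fields = map[string]int{"ID": 0, "Name": 1}

// lookupField copies name into scratch (standing in for the decoder reading
// a key into its reusable buffer) and looks it up through a zero-copy view.
// The view is only used for the lookup and never retained, so scratch can be
// overwritten afterwards.
func lookupField(scratch []byte, name string) (int, bool, []byte) {
	scratch = append(scratch[:0], name...)
	idx, ok := fields[bytesToString(scratch)]
	return idx, ok, scratch
}

func main() {
	scratch := make([]byte, 0, 16)
	for _, key := range []string{"Name", "Email", "ID"} {
		var idx int
		var ok bool
		idx, ok, scratch = lookupField(scratch, key)
		fmt.Println(key, idx, ok)
	}
}
```

For plain map lookups, the Go compiler already avoids allocating in `m[string(b)]`, so that form is a safe alternative when an unsafe helper is not otherwise needed.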
diff --git a/decode_value.go b/decode_value.go
index 5a5b5ff..d2ff2ae 100644
--- a/decode_value.go
+++ b/decode_value.go
@@ -7,8 +7,10 @@ import (
"reflect"
)
-var interfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
-var stringType = reflect.TypeOf((*string)(nil)).Elem()
+var (
+ interfaceType = reflect.TypeOf((*interface{})(nil)).Elem()
+ stringType = reflect.TypeOf((*string)(nil)).Elem()
+)
var valueDecoders []decoderFunc
@@ -43,13 +45,6 @@ func init() {
}
}
-func mustSet(v reflect.Value) error {
- if !v.CanSet() {
- return fmt.Errorf("msgpack: Decode(nonsettable %s)", v.Type())
- }
- return nil
-}
-
func getDecoder(typ reflect.Type) decoderFunc {
if v, ok := typeDecMap.Load(typ); ok {
return v.(decoderFunc)
@@ -62,33 +57,45 @@ func getDecoder(typ reflect.Type) decoderFunc {
func _getDecoder(typ reflect.Type) decoderFunc {
kind := typ.Kind()
+ if kind == reflect.Ptr {
+ if _, ok := typeDecMap.Load(typ.Elem()); ok {
+ return ptrValueDecoder(typ)
+ }
+ }
+
if typ.Implements(customDecoderType) {
- return decodeCustomValue
+ return nilAwareDecoder(typ, decodeCustomValue)
}
if typ.Implements(unmarshalerType) {
- return unmarshalValue
+ return nilAwareDecoder(typ, unmarshalValue)
}
if typ.Implements(binaryUnmarshalerType) {
- return unmarshalBinaryValue
+ return nilAwareDecoder(typ, unmarshalBinaryValue)
+ }
+ if typ.Implements(textUnmarshalerType) {
+ return nilAwareDecoder(typ, unmarshalTextValue)
}
// Addressable struct field value.
if kind != reflect.Ptr {
ptr := reflect.PtrTo(typ)
if ptr.Implements(customDecoderType) {
- return decodeCustomValueAddr
+ return addrDecoder(nilAwareDecoder(typ, decodeCustomValue))
}
if ptr.Implements(unmarshalerType) {
- return unmarshalValueAddr
+ return addrDecoder(nilAwareDecoder(typ, unmarshalValue))
}
if ptr.Implements(binaryUnmarshalerType) {
- return unmarshalBinaryValueAddr
+ return addrDecoder(nilAwareDecoder(typ, unmarshalBinaryValue))
+ }
+ if ptr.Implements(textUnmarshalerType) {
+ return addrDecoder(nilAwareDecoder(typ, unmarshalTextValue))
}
}
switch kind {
case reflect.Ptr:
- return ptrDecoderFunc(typ)
+ return ptrValueDecoder(typ)
case reflect.Slice:
elem := typ.Elem()
if elem.Kind() == reflect.Uint8 {
@@ -115,83 +122,50 @@ func _getDecoder(typ reflect.Type) decoderFunc {
return valueDecoders[kind]
}
-func ptrDecoderFunc(typ reflect.Type) decoderFunc {
+func ptrValueDecoder(typ reflect.Type) decoderFunc {
decoder := getDecoder(typ.Elem())
return func(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
- if err := mustSet(v); err != nil {
- return err
- }
if !v.IsNil() {
v.Set(reflect.Zero(v.Type()))
}
return d.DecodeNil()
}
if v.IsNil() {
- if err := mustSet(v); err != nil {
- return err
- }
v.Set(reflect.New(v.Type().Elem()))
}
return decoder(d, v.Elem())
}
}
-func decodeCustomValueAddr(d *Decoder, v reflect.Value) error {
- if !v.CanAddr() {
- return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
- }
- return decodeCustomValue(d, v.Addr())
-}
-
-func decodeCustomValue(d *Decoder, v reflect.Value) error {
- if d.hasNilCode() {
- return d.decodeNilValue(v)
- }
-
- if v.IsNil() {
- v.Set(reflect.New(v.Type().Elem()))
+func addrDecoder(fn decoderFunc) decoderFunc {
+ return func(d *Decoder, v reflect.Value) error {
+ if !v.CanAddr() {
+ return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
+ }
+ return fn(d, v.Addr())
}
-
- decoder := v.Interface().(CustomDecoder)
- return decoder.DecodeMsgpack(d)
}
-func unmarshalValueAddr(d *Decoder, v reflect.Value) error {
- if !v.CanAddr() {
- return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
+func nilAwareDecoder(typ reflect.Type, fn decoderFunc) decoderFunc {
+ if nilable(typ.Kind()) {
+ return func(d *Decoder, v reflect.Value) error {
+ if d.hasNilCode() {
+ return d.decodeNilValue(v)
+ }
+ if v.IsNil() {
+ v.Set(reflect.New(v.Type().Elem()))
+ }
+ return fn(d, v)
+ }
}
- return unmarshalValue(d, v.Addr())
-}
-func unmarshalValue(d *Decoder, v reflect.Value) error {
- if d.extLen == 0 || d.extLen == 1 {
+ return func(d *Decoder, v reflect.Value) error {
if d.hasNilCode() {
return d.decodeNilValue(v)
}
+ return fn(d, v)
}
-
- if v.IsNil() {
- v.Set(reflect.New(v.Type().Elem()))
- }
-
- if d.extLen != 0 {
- b, err := d.readN(d.extLen)
- if err != nil {
- return err
- }
- d.rec = b
- } else {
- d.rec = make([]byte, 0, 64)
- if err := d.Skip(); err != nil {
- return err
- }
- }
-
- unmarshaler := v.Interface().(Unmarshaler)
- err := unmarshaler.UnmarshalMsgpack(d.rec)
- d.rec = nil
- return err
}
func decodeBoolValue(d *Decoder, v reflect.Value) error {
@@ -199,9 +173,6 @@ func decodeBoolValue(d *Decoder, v reflect.Value) error {
if err != nil {
return err
}
- if err = mustSet(v); err != nil {
- return err
- }
v.SetBool(flag)
return nil
}
@@ -210,16 +181,7 @@ func decodeInterfaceValue(d *Decoder, v reflect.Value) error {
if v.IsNil() {
return d.interfaceValue(v)
}
-
- elem := v.Elem()
- if !elem.CanAddr() {
- if d.hasNilCode() {
- v.Set(reflect.Zero(v.Type()))
- return d.DecodeNil()
- }
- }
-
- return d.DecodeValue(elem)
+ return d.DecodeValue(v.Elem())
}
func (d *Decoder) interfaceValue(v reflect.Value) error {
@@ -248,22 +210,26 @@ func decodeUnsupportedValue(d *Decoder, v reflect.Value) error {
//------------------------------------------------------------------------------
-func unmarshalBinaryValueAddr(d *Decoder, v reflect.Value) error {
- if !v.CanAddr() {
- return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
- }
- return unmarshalBinaryValue(d, v.Addr())
+func decodeCustomValue(d *Decoder, v reflect.Value) error {
+ decoder := v.Interface().(CustomDecoder)
+ return decoder.DecodeMsgpack(d)
}
-func unmarshalBinaryValue(d *Decoder, v reflect.Value) error {
- if d.hasNilCode() {
- return d.decodeNilValue(v)
- }
+func unmarshalValue(d *Decoder, v reflect.Value) error {
+ var b []byte
- if v.IsNil() {
- v.Set(reflect.New(v.Type().Elem()))
+ d.rec = make([]byte, 0, 64)
+ if err := d.Skip(); err != nil {
+ return err
}
+ b = d.rec
+ d.rec = nil
+
+ unmarshaler := v.Interface().(Unmarshaler)
+ return unmarshaler.UnmarshalMsgpack(b)
+}
+func unmarshalBinaryValue(d *Decoder, v reflect.Value) error {
data, err := d.DecodeBytes()
if err != nil {
return err
@@ -272,3 +238,13 @@ func unmarshalBinaryValue(d *Decoder, v reflect.Value) error {
unmarshaler := v.Interface().(encoding.BinaryUnmarshaler)
return unmarshaler.UnmarshalBinary(data)
}
+
+func unmarshalTextValue(d *Decoder, v reflect.Value) error {
+ data, err := d.DecodeBytes()
+ if err != nil {
+ return err
+ }
+
+ unmarshaler := v.Interface().(encoding.TextUnmarshaler)
+ return unmarshaler.UnmarshalText(data)
+}
diff --git a/encode.go b/encode.go
index a5c35c8..0ef6212 100644
--- a/encode.go
+++ b/encode.go
@@ -4,107 +4,195 @@ import (
"bytes"
"io"
"reflect"
+ "sync"
"time"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
+)
+
+const (
+ sortMapKeysFlag uint32 = 1 << iota
+ arrayEncodedStructsFlag
+ useCompactIntsFlag
+ useCompactFloatsFlag
+ useInternedStringsFlag
+ omitEmptyFlag
)
type writer interface {
io.Writer
WriteByte(byte) error
- WriteString(string) (int, error)
}
type byteWriter struct {
io.Writer
-
- buf []byte
- bootstrap [64]byte
}
-func newByteWriter(w io.Writer) *byteWriter {
- bw := &byteWriter{
+func newByteWriter(w io.Writer) byteWriter {
+ return byteWriter{
Writer: w,
}
- bw.buf = bw.bootstrap[:]
- return bw
}
-func (w *byteWriter) WriteByte(c byte) error {
- w.buf = w.buf[:1]
- w.buf[0] = c
- _, err := w.Write(w.buf)
+func (bw byteWriter) WriteByte(c byte) error {
+ _, err := bw.Write([]byte{c})
return err
}
-func (w *byteWriter) WriteString(s string) (int, error) {
- w.buf = append(w.buf[:0], s...)
- return w.Write(w.buf)
+//------------------------------------------------------------------------------
+
+var encPool = sync.Pool{
+ New: func() interface{} {
+ return NewEncoder(nil)
+ },
+}
+
+func GetEncoder() *Encoder {
+ return encPool.Get().(*Encoder)
+}
+
+func PutEncoder(enc *Encoder) {
+ enc.w = nil
+ encPool.Put(enc)
}
// Marshal returns the MessagePack encoding of v.
func Marshal(v interface{}) ([]byte, error) {
+ enc := GetEncoder()
+
var buf bytes.Buffer
- err := NewEncoder(&buf).Encode(v)
- return buf.Bytes(), err
+ enc.Reset(&buf)
+
+ err := enc.Encode(v)
+ b := buf.Bytes()
+
+ PutEncoder(enc)
+
+ if err != nil {
+ return nil, err
+ }
+ return b, err
}
type Encoder struct {
w writer
- buf []byte
- // timeBuf is lazily allocated in encodeTime() to
- // avoid allocations when time.Time value are encoded
- //
- // buf can't be reused for time encoding, as buf is used
- // to encode msgpack extLen
+ buf []byte
timeBuf []byte
- sortMapKeys bool
- structAsArray bool
- useJSONTag bool
- useCompact bool
+ dict map[string]int
+
+ flags uint32
+ structTag string
}
// NewEncoder returns a new encoder that writes to w.
func NewEncoder(w io.Writer) *Encoder {
- bw, ok := w.(writer)
- if !ok {
- bw = newByteWriter(w)
- }
- return &Encoder{
- w: bw,
+ e := &Encoder{
buf: make([]byte, 9),
}
+ e.Reset(w)
+ return e
+}
+
+// Writer returns the Encoder's writer.
+func (e *Encoder) Writer() io.Writer {
+ return e.w
+}
+
+// Reset discards any buffered data, resets all state, and switches the writer to write to w.
+func (e *Encoder) Reset(w io.Writer) {
+ e.ResetDict(w, nil)
+}
+
+// ResetDict is like Reset, but also resets the dict.
+func (e *Encoder) ResetDict(w io.Writer, dict map[string]int) {
+ e.resetWriter(w)
+ e.flags = 0
+ e.structTag = ""
+ e.dict = dict
+}
+
+func (e *Encoder) WithDict(dict map[string]int, fn func(*Encoder) error) error {
+ oldDict := e.dict
+ e.dict = dict
+ err := fn(e)
+ e.dict = oldDict
+ return err
+}
+
+func (e *Encoder) resetWriter(w io.Writer) {
+ if bw, ok := w.(writer); ok {
+ e.w = bw
+ } else {
+ e.w = newByteWriter(w)
+ }
}
-// SortMapKeys causes the Encoder to encode map keys in increasing order.
+// SetSortMapKeys causes the Encoder to encode map keys in increasing order.
// Supported map types are:
// - map[string]string
// - map[string]interface{}
-func (e *Encoder) SortMapKeys(flag bool) *Encoder {
- e.sortMapKeys = flag
+func (e *Encoder) SetSortMapKeys(on bool) *Encoder {
+ if on {
+ e.flags |= sortMapKeysFlag
+ } else {
+ e.flags &= ^sortMapKeysFlag
+ }
return e
}
-// StructAsArray causes the Encoder to encode Go structs as msgpack arrays.
-func (e *Encoder) StructAsArray(flag bool) *Encoder {
- e.structAsArray = flag
- return e
+// SetCustomStructTag causes the Encoder to use a custom struct tag as
+// fallback option if there is no msgpack tag.
+func (e *Encoder) SetCustomStructTag(tag string) {
+ e.structTag = tag
}
-// UseJSONTag causes the Encoder to use json struct tag as fallback option
-// if there is no msgpack tag.
-func (e *Encoder) UseJSONTag(flag bool) *Encoder {
- e.useJSONTag = flag
- return e
+// SetOmitEmpty causes the Encoder to omit empty values by default.
+func (e *Encoder) SetOmitEmpty(on bool) {
+ if on {
+ e.flags |= omitEmptyFlag
+ } else {
+ e.flags &= ^omitEmptyFlag
+ }
+}
+
+// UseArrayEncodedStructs causes the Encoder to encode Go structs as msgpack arrays.
+func (e *Encoder) UseArrayEncodedStructs(on bool) {
+ if on {
+ e.flags |= arrayEncodedStructsFlag
+ } else {
+ e.flags &= ^arrayEncodedStructsFlag
+ }
}
// UseCompactEncoding causes the Encoder to chose the most compact encoding.
// For example, it allows to encode small Go int64 as msgpack int8 saving 7 bytes.
-func (e *Encoder) UseCompactEncoding(flag bool) *Encoder {
- e.useCompact = flag
- return e
+func (e *Encoder) UseCompactInts(on bool) {
+ if on {
+ e.flags |= useCompactIntsFlag
+ } else {
+ e.flags &= ^useCompactIntsFlag
+ }
+}
+
+// UseCompactFloats causes the Encoder to choose a compact integer encoding
+// for floats that can be represented as integers.
+func (e *Encoder) UseCompactFloats(on bool) {
+ if on {
+ e.flags |= useCompactFloatsFlag
+ } else {
+ e.flags &= ^useCompactFloatsFlag
+ }
+}
+
+// UseInternedStrings causes the Encoder to intern strings.
+func (e *Encoder) UseInternedStrings(on bool) {
+ if on {
+ e.flags |= useInternedStringsFlag
+ } else {
+ e.flags &= ^useInternedStringsFlag
+ }
}
func (e *Encoder) Encode(v interface{}) error {
@@ -116,11 +204,11 @@ func (e *Encoder) Encode(v interface{}) error {
case []byte:
return e.EncodeBytes(v)
case int:
- return e.encodeInt64Cond(int64(v))
+ return e.EncodeInt(int64(v))
case int64:
return e.encodeInt64Cond(v)
case uint:
- return e.encodeUint64Cond(uint64(v))
+ return e.EncodeUint(uint64(v))
case uint64:
return e.encodeUint64Cond(v)
case bool:
@@ -152,18 +240,22 @@ func (e *Encoder) EncodeValue(v reflect.Value) error {
}
func (e *Encoder) EncodeNil() error {
- return e.writeCode(codes.Nil)
+ return e.writeCode(msgpcode.Nil)
}
func (e *Encoder) EncodeBool(value bool) error {
if value {
- return e.writeCode(codes.True)
+ return e.writeCode(msgpcode.True)
}
- return e.writeCode(codes.False)
+ return e.writeCode(msgpcode.False)
+}
+
+func (e *Encoder) EncodeDuration(d time.Duration) error {
+ return e.EncodeInt(int64(d))
}
-func (e *Encoder) writeCode(c codes.Code) error {
- return e.w.WriteByte(byte(c))
+func (e *Encoder) writeCode(c byte) error {
+ return e.w.WriteByte(c)
}
func (e *Encoder) write(b []byte) error {
@@ -172,6 +264,6 @@ func (e *Encoder) write(b []byte) error {
}
func (e *Encoder) writeString(s string) error {
- _, err := e.w.WriteString(s)
+ _, err := e.w.Write(stringToBytes(s))
return err
}
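In v5, the Encoder's separate boolean options are folded into a single `flags` bitmask. The setters (`SetOmitEmpty`, `UseArrayEncodedStructs`, `UseCompactInts`, `UseCompactFloats`, `UseInternedStrings`) also no longer return `*Encoder`. As a result, v4-style chaining such as `msgpack.NewEncoder(&buf).StructAsArray(true)` has to be split into two statements, which is what the updated `example_test.go` does.

The standalone sketch below isolates the set/clear/test idiom. The `options` type and the flag constants are hypothetical stand-ins, because the diff does not show upstream's flag definitions.

```go
package main

// Hypothetical flag values for illustration only; upstream's real constants
// are not part of this diff.
const (
	flagOmitEmpty uint32 = 1 << iota
	flagArrayEncodedStructs
	flagCompactInts
	flagCompactFloats
)

// options mimics the single bitmask that replaced v4's separate bool fields.
type options struct {
	flags uint32
}

// set turns one flag on or off without touching the others.
// x &^= f (AND NOT) is the idiomatic spelling of x &= ^f used in the diff.
func (o *options) set(flag uint32, on bool) {
	if on {
		o.flags |= flag
	} else {
		o.flags &^= flag
	}
}

// has reports whether a flag is set, like e.flags&useCompactIntsFlag != 0.
func (o *options) has(flag uint32) bool {
	return o.flags&flag != 0
}
```

Go's `&^=` operator is equivalent to the `e.flags &= ^flag` form used upstream. Either form clears exactly one bit and leaves the others alone.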
diff --git a/encode_map.go b/encode_map.go
index 6dd635f..ba4c61b 100644
--- a/encode_map.go
+++ b/encode_map.go
@@ -1,10 +1,11 @@
package msgpack
import (
+ "math"
"reflect"
"sort"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
func encodeMapValue(e *Encoder, v reflect.Value) error {
@@ -16,11 +17,12 @@ func encodeMapValue(e *Encoder, v reflect.Value) error {
return err
}
- for _, key := range v.MapKeys() {
- if err := e.EncodeValue(key); err != nil {
+ iter := v.MapRange()
+ for iter.Next() {
+ if err := e.EncodeValue(iter.Key()); err != nil {
return err
}
- if err := e.EncodeValue(v.MapIndex(key)); err != nil {
+ if err := e.EncodeValue(iter.Value()); err != nil {
return err
}
}
@@ -38,7 +40,7 @@ func encodeMapStringStringValue(e *Encoder, v reflect.Value) error {
}
m := v.Convert(mapStringStringType).Interface().(map[string]string)
- if e.sortMapKeys {
+ if e.flags&sortMapKeysFlag != 0 {
return e.encodeSortedMapStringString(m)
}
@@ -58,16 +60,20 @@ func encodeMapStringInterfaceValue(e *Encoder, v reflect.Value) error {
if v.IsNil() {
return e.EncodeNil()
}
-
- if err := e.EncodeMapLen(v.Len()); err != nil {
- return err
- }
-
m := v.Convert(mapStringInterfaceType).Interface().(map[string]interface{})
- if e.sortMapKeys {
- return e.encodeSortedMapStringInterface(m)
+ if e.flags&sortMapKeysFlag != 0 {
+ return e.EncodeMapSorted(m)
}
+ return e.EncodeMap(m)
+}
+func (e *Encoder) EncodeMap(m map[string]interface{}) error {
+ if m == nil {
+ return e.EncodeNil()
+ }
+ if err := e.EncodeMapLen(len(m)); err != nil {
+ return err
+ }
for mk, mv := range m {
if err := e.EncodeString(mk); err != nil {
return err
@@ -76,23 +82,30 @@ func encodeMapStringInterfaceValue(e *Encoder, v reflect.Value) error {
return err
}
}
-
return nil
}
-func (e *Encoder) encodeSortedMapStringString(m map[string]string) error {
+func (e *Encoder) EncodeMapSorted(m map[string]interface{}) error {
+ if m == nil {
+ return e.EncodeNil()
+ }
+ if err := e.EncodeMapLen(len(m)); err != nil {
+ return err
+ }
+
keys := make([]string, 0, len(m))
+
for k := range m {
keys = append(keys, k)
}
+
sort.Strings(keys)
for _, k := range keys {
- err := e.EncodeString(k)
- if err != nil {
+ if err := e.EncodeString(k); err != nil {
return err
}
- if err = e.EncodeString(m[k]); err != nil {
+ if err := e.Encode(m[k]); err != nil {
return err
}
}
@@ -100,7 +113,7 @@ func (e *Encoder) encodeSortedMapStringString(m map[string]string) error {
return nil
}
-func (e *Encoder) encodeSortedMapStringInterface(m map[string]interface{}) error {
+func (e *Encoder) encodeSortedMapStringString(m map[string]string) error {
keys := make([]string, 0, len(m))
for k := range m {
keys = append(keys, k)
@@ -112,7 +125,7 @@ func (e *Encoder) encodeSortedMapStringInterface(m map[string]interface{}) error
if err != nil {
return err
}
- if err = e.Encode(m[k]); err != nil {
+ if err = e.EncodeString(m[k]); err != nil {
return err
}
}
@@ -122,26 +135,20 @@ func (e *Encoder) encodeSortedMapStringInterface(m map[string]interface{}) error
func (e *Encoder) EncodeMapLen(l int) error {
if l < 16 {
- return e.writeCode(codes.FixedMapLow | codes.Code(l))
+ return e.writeCode(msgpcode.FixedMapLow | byte(l))
}
- if l < 65536 {
- return e.write2(codes.Map16, uint16(l))
+ if l <= math.MaxUint16 {
+ return e.write2(msgpcode.Map16, uint16(l))
}
- return e.write4(codes.Map32, uint32(l))
+ return e.write4(msgpcode.Map32, uint32(l))
}
func encodeStructValue(e *Encoder, strct reflect.Value) error {
- var structFields *fields
- if e.useJSONTag {
- structFields = jsonStructs.Fields(strct.Type())
- } else {
- structFields = structs.Fields(strct.Type())
- }
-
- if e.structAsArray || structFields.AsArray {
+ structFields := structs.Fields(strct.Type(), e.structTag)
+ if e.flags&arrayEncodedStructsFlag != 0 || structFields.AsArray {
return encodeStructValueAsArray(e, strct, structFields.List)
}
- fields := structFields.OmitEmpty(strct)
+ fields := structFields.OmitEmpty(strct, e.flags&omitEmptyFlag != 0)
if err := e.EncodeMapLen(len(fields)); err != nil {
return err
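The `encode_map.go` changes do three things:

- Map iteration switches to `reflect.Value.MapRange`.
- Exported `EncodeMap` and `EncodeMapSorted` are added for `map[string]interface{}`.
- The map16 limit is written as `l <= math.MaxUint16`, the same boundary as the old `l < 65536`.

The stdlib-only sketch below is not the library's code. It shows why sorting matters: with keys in order, equal maps always produce identical bytes. It is restricted to `map[string]string` with keys and values shorter than 32 bytes, so every string fits a single-byte fixstr header. The helper names are hypothetical.

```go
package main

import (
	"fmt"
	"math"
	"sort"
)

// appendMapLen writes a msgpack map header: fixmap (0x80|n) below 16 entries,
// map16 (0xde) up to math.MaxUint16, map32 (0xdf) beyond that.
func appendMapLen(b []byte, l int) []byte {
	switch {
	case l < 16:
		return append(b, 0x80|byte(l))
	case l <= math.MaxUint16:
		return append(b, 0xde, byte(l>>8), byte(l))
	default:
		return append(b, 0xdf, byte(l>>24), byte(l>>16), byte(l>>8), byte(l))
	}
}

// appendFixStr writes a fixstr (0xa0|len); this sketch rejects strings of
// 32 bytes or more, which would need str8/str16/str32.
func appendFixStr(b []byte, s string) ([]byte, error) {
	if len(s) >= 32 {
		return nil, fmt.Errorf("sketch: %d-byte string needs str8/16/32", len(s))
	}
	return append(append(b, 0xa0|byte(len(s))), s...), nil
}

// encodeSortedStringMap encodes m with keys in ascending order, so equal maps
// always yield identical bytes (the purpose of the sorted-keys option).
func encodeSortedStringMap(m map[string]string) ([]byte, error) {
	if m == nil {
		return []byte{0xc0}, nil // msgpack nil, as EncodeMap does for a nil map
	}
	keys := make([]string, 0, len(m))
	for k := range m {
		keys = append(keys, k)
	}
	sort.Strings(keys)

	b := appendMapLen(nil, len(m))
	var err error
	for _, k := range keys {
		if b, err = appendFixStr(b, k); err != nil {
			return nil, err
		}
		if b, err = appendFixStr(b, m[k]); err != nil {
			return nil, err
		}
	}
	return b, nil
}
```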
diff --git a/encode_number.go b/encode_number.go
index e92ad88..63c311b 100644
--- a/encode_number.go
+++ b/encode_number.go
@@ -4,16 +4,16 @@ import (
"math"
"reflect"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
// EncodeUint8 encodes an uint8 in 2 bytes preserving type of the number.
func (e *Encoder) EncodeUint8(n uint8) error {
- return e.write1(codes.Uint8, n)
+ return e.write1(msgpcode.Uint8, n)
}
func (e *Encoder) encodeUint8Cond(n uint8) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeUint(uint64(n))
}
return e.EncodeUint8(n)
@@ -21,11 +21,11 @@ func (e *Encoder) encodeUint8Cond(n uint8) error {
// EncodeUint16 encodes an uint16 in 3 bytes preserving type of the number.
func (e *Encoder) EncodeUint16(n uint16) error {
- return e.write2(codes.Uint16, n)
+ return e.write2(msgpcode.Uint16, n)
}
func (e *Encoder) encodeUint16Cond(n uint16) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeUint(uint64(n))
}
return e.EncodeUint16(n)
@@ -33,11 +33,11 @@ func (e *Encoder) encodeUint16Cond(n uint16) error {
// EncodeUint32 encodes an uint16 in 5 bytes preserving type of the number.
func (e *Encoder) EncodeUint32(n uint32) error {
- return e.write4(codes.Uint32, n)
+ return e.write4(msgpcode.Uint32, n)
}
func (e *Encoder) encodeUint32Cond(n uint32) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeUint(uint64(n))
}
return e.EncodeUint32(n)
@@ -45,11 +45,11 @@ func (e *Encoder) encodeUint32Cond(n uint32) error {
// EncodeUint64 encodes an uint16 in 9 bytes preserving type of the number.
func (e *Encoder) EncodeUint64(n uint64) error {
- return e.write8(codes.Uint64, n)
+ return e.write8(msgpcode.Uint64, n)
}
func (e *Encoder) encodeUint64Cond(n uint64) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeUint(n)
}
return e.EncodeUint64(n)
@@ -57,11 +57,11 @@ func (e *Encoder) encodeUint64Cond(n uint64) error {
// EncodeInt8 encodes an int8 in 2 bytes preserving type of the number.
func (e *Encoder) EncodeInt8(n int8) error {
- return e.write1(codes.Int8, uint8(n))
+ return e.write1(msgpcode.Int8, uint8(n))
}
func (e *Encoder) encodeInt8Cond(n int8) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeInt(int64(n))
}
return e.EncodeInt8(n)
@@ -69,11 +69,11 @@ func (e *Encoder) encodeInt8Cond(n int8) error {
// EncodeInt16 encodes an int16 in 3 bytes preserving type of the number.
func (e *Encoder) EncodeInt16(n int16) error {
- return e.write2(codes.Int16, uint16(n))
+ return e.write2(msgpcode.Int16, uint16(n))
}
func (e *Encoder) encodeInt16Cond(n int16) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeInt(int64(n))
}
return e.EncodeInt16(n)
@@ -81,11 +81,11 @@ func (e *Encoder) encodeInt16Cond(n int16) error {
// EncodeInt32 encodes an int32 in 5 bytes preserving type of the number.
func (e *Encoder) EncodeInt32(n int32) error {
- return e.write4(codes.Int32, uint32(n))
+ return e.write4(msgpcode.Int32, uint32(n))
}
func (e *Encoder) encodeInt32Cond(n int32) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeInt(int64(n))
}
return e.EncodeInt32(n)
@@ -93,11 +93,11 @@ func (e *Encoder) encodeInt32Cond(n int32) error {
// EncodeInt64 encodes an int64 in 9 bytes preserving type of the number.
func (e *Encoder) EncodeInt64(n int64) error {
- return e.write8(codes.Int64, uint64(n))
+ return e.write8(msgpcode.Int64, uint64(n))
}
func (e *Encoder) encodeInt64Cond(n int64) error {
- if e.useCompact {
+ if e.flags&useCompactIntsFlag != 0 {
return e.EncodeInt(n)
}
return e.EncodeInt64(n)
@@ -127,7 +127,7 @@ func (e *Encoder) EncodeInt(n int64) error {
if n >= 0 {
return e.EncodeUint(uint64(n))
}
- if n >= int64(int8(codes.NegFixedNumLow)) {
+ if n >= int64(int8(msgpcode.NegFixedNumLow)) {
return e.w.WriteByte(byte(n))
}
if n >= math.MinInt8 {
@@ -143,31 +143,45 @@ func (e *Encoder) EncodeInt(n int64) error {
}
func (e *Encoder) EncodeFloat32(n float32) error {
- return e.write4(codes.Float, math.Float32bits(n))
+ if e.flags&useCompactFloatsFlag != 0 {
+ if float32(int64(n)) == n {
+ return e.EncodeInt(int64(n))
+ }
+ }
+ return e.write4(msgpcode.Float, math.Float32bits(n))
}
func (e *Encoder) EncodeFloat64(n float64) error {
- return e.write8(codes.Double, math.Float64bits(n))
+ if e.flags&useCompactFloatsFlag != 0 {
+ // Both NaN and Inf convert to int64(-0x8000000000000000)
+ // If n is NaN then it never compares true with any other value
+ // If n is Inf then it doesn't convert from int64 back to +/-Inf
+ // In both cases the comparison works.
+ if float64(int64(n)) == n {
+ return e.EncodeInt(int64(n))
+ }
+ }
+ return e.write8(msgpcode.Double, math.Float64bits(n))
}
-func (e *Encoder) write1(code codes.Code, n uint8) error {
+func (e *Encoder) write1(code byte, n uint8) error {
e.buf = e.buf[:2]
- e.buf[0] = byte(code)
+ e.buf[0] = code
e.buf[1] = n
return e.write(e.buf)
}
-func (e *Encoder) write2(code codes.Code, n uint16) error {
+func (e *Encoder) write2(code byte, n uint16) error {
e.buf = e.buf[:3]
- e.buf[0] = byte(code)
+ e.buf[0] = code
e.buf[1] = byte(n >> 8)
e.buf[2] = byte(n)
return e.write(e.buf)
}
-func (e *Encoder) write4(code codes.Code, n uint32) error {
+func (e *Encoder) write4(code byte, n uint32) error {
e.buf = e.buf[:5]
- e.buf[0] = byte(code)
+ e.buf[0] = code
e.buf[1] = byte(n >> 24)
e.buf[2] = byte(n >> 16)
e.buf[3] = byte(n >> 8)
@@ -175,9 +189,9 @@ func (e *Encoder) write4(code codes.Code, n uint32) error {
return e.write(e.buf)
}
-func (e *Encoder) write8(code codes.Code, n uint64) error {
+func (e *Encoder) write8(code byte, n uint64) error {
e.buf = e.buf[:9]
- e.buf[0] = byte(code)
+ e.buf[0] = code
e.buf[1] = byte(n >> 56)
e.buf[2] = byte(n >> 48)
e.buf[3] = byte(n >> 40)
@@ -189,6 +203,14 @@ func (e *Encoder) write8(code codes.Code, n uint64) error {
return e.write(e.buf)
}
+func encodeUintValue(e *Encoder, v reflect.Value) error {
+ return e.EncodeUint(v.Uint())
+}
+
+func encodeIntValue(e *Encoder, v reflect.Value) error {
+ return e.EncodeInt(v.Int())
+}
+
func encodeUint8CondValue(e *Encoder, v reflect.Value) error {
return e.encodeUint8Cond(uint8(v.Uint()))
}
diff --git a/encode_slice.go b/encode_slice.go
index 69a9618..ca46ead 100644
--- a/encode_slice.go
+++ b/encode_slice.go
@@ -1,12 +1,13 @@
package msgpack
import (
+ "math"
"reflect"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
-var sliceStringType = reflect.TypeOf(([]string)(nil))
+var stringSliceType = reflect.TypeOf(([]string)(nil))
func encodeStringValue(e *Encoder, v reflect.Value) error {
return e.EncodeString(v.String())
@@ -42,29 +43,36 @@ func grow(b []byte, n int) []byte {
func (e *Encoder) EncodeBytesLen(l int) error {
if l < 256 {
- return e.write1(codes.Bin8, uint8(l))
+ return e.write1(msgpcode.Bin8, uint8(l))
}
- if l < 65536 {
- return e.write2(codes.Bin16, uint16(l))
+ if l <= math.MaxUint16 {
+ return e.write2(msgpcode.Bin16, uint16(l))
}
- return e.write4(codes.Bin32, uint32(l))
+ return e.write4(msgpcode.Bin32, uint32(l))
}
-func (e *Encoder) encodeStrLen(l int) error {
+func (e *Encoder) encodeStringLen(l int) error {
if l < 32 {
- return e.writeCode(codes.FixedStrLow | codes.Code(l))
+ return e.writeCode(msgpcode.FixedStrLow | byte(l))
}
if l < 256 {
- return e.write1(codes.Str8, uint8(l))
+ return e.write1(msgpcode.Str8, uint8(l))
}
- if l < 65536 {
- return e.write2(codes.Str16, uint16(l))
+ if l <= math.MaxUint16 {
+ return e.write2(msgpcode.Str16, uint16(l))
}
- return e.write4(codes.Str32, uint32(l))
+ return e.write4(msgpcode.Str32, uint32(l))
}
func (e *Encoder) EncodeString(v string) error {
- if err := e.encodeStrLen(len(v)); err != nil {
+ if intern := e.flags&useInternedStringsFlag != 0; intern || len(e.dict) > 0 {
+ return e.encodeInternedString(v, intern)
+ }
+ return e.encodeNormalString(v)
+}
+
+func (e *Encoder) encodeNormalString(v string) error {
+ if err := e.encodeStringLen(len(v)); err != nil {
return err
}
return e.writeString(v)
@@ -82,16 +90,16 @@ func (e *Encoder) EncodeBytes(v []byte) error {
func (e *Encoder) EncodeArrayLen(l int) error {
if l < 16 {
- return e.writeCode(codes.FixedArrayLow | codes.Code(l))
+ return e.writeCode(msgpcode.FixedArrayLow | byte(l))
}
- if l < 65536 {
- return e.write2(codes.Array16, uint16(l))
+ if l <= math.MaxUint16 {
+ return e.write2(msgpcode.Array16, uint16(l))
}
- return e.write4(codes.Array32, uint32(l))
+ return e.write4(msgpcode.Array32, uint32(l))
}
func encodeStringSliceValue(e *Encoder, v reflect.Value) error {
- ss := v.Convert(sliceStringType).Interface().([]string)
+ ss := v.Convert(stringSliceType).Interface().([]string)
return e.encodeStringSlice(ss)
}
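`encodeStringLen` (renamed from `encodeStrLen`) and `EncodeBytesLen` choose a header from the payload length. Here is a stdlib-only sketch of the same thresholds, using format bytes from the msgpack specification; the diff refers to them as `msgpcode.*`, and the function names are hypothetical. It covers only the plain `encodeNormalString` path. It does not cover the interned-string path that `EncodeString` now takes when `UseInternedStrings` is on or a dictionary is set.

```go
package main

import "math"

// msgpack format bytes from the specification.
const (
	fixStrLow = 0xa0
	str8      = 0xd9
	str16     = 0xda
	str32     = 0xdb
	bin8      = 0xc4
	bin16     = 0xc5
	bin32     = 0xc6
)

// appendLen writes a 2-byte (code16) or 4-byte (code32) big-endian length.
func appendLen(b []byte, code16, code32 byte, l int) []byte {
	if l <= math.MaxUint16 {
		return append(b, code16, byte(l>>8), byte(l))
	}
	return append(b, code32, byte(l>>24), byte(l>>16), byte(l>>8), byte(l))
}

// appendStringLen mirrors encodeStringLen: fixstr below 32 bytes, then str8/16/32.
func appendStringLen(b []byte, l int) []byte {
	switch {
	case l < 32:
		return append(b, fixStrLow|byte(l))
	case l <= math.MaxUint8:
		return append(b, str8, byte(l))
	default:
		return appendLen(b, str16, str32, l)
	}
}

// appendBytesLen mirrors EncodeBytesLen: bin has no fixed form, so even an
// empty []byte costs a two-byte bin8 header.
func appendBytesLen(b []byte, l int) []byte {
	if l <= math.MaxUint8 {
		return append(b, bin8, byte(l))
	}
	return appendLen(b, bin16, bin32, l)
}
```

Note the asymmetry. A string shorter than 32 bytes gets a one-byte fixstr header, but `EncodeBytes` on a short slice always spends at least two header bytes.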
diff --git a/encode_value.go b/encode_value.go
index 2dbcfbe..48cf489 100644
--- a/encode_value.go
+++ b/encode_value.go
@@ -12,12 +12,12 @@ var valueEncoders []encoderFunc
func init() {
valueEncoders = []encoderFunc{
reflect.Bool: encodeBoolValue,
- reflect.Int: encodeInt64CondValue,
+ reflect.Int: encodeIntValue,
reflect.Int8: encodeInt8CondValue,
reflect.Int16: encodeInt16CondValue,
reflect.Int32: encodeInt32CondValue,
reflect.Int64: encodeInt64CondValue,
- reflect.Uint: encodeUint64CondValue,
+ reflect.Uint: encodeUintValue,
reflect.Uint8: encodeUint8CondValue,
reflect.Uint16: encodeUint16CondValue,
reflect.Uint32: encodeUint32CondValue,
@@ -49,6 +49,14 @@ func getEncoder(typ reflect.Type) encoderFunc {
}
func _getEncoder(typ reflect.Type) encoderFunc {
+ kind := typ.Kind()
+
+ if kind == reflect.Ptr {
+ if _, ok := typeEncMap.Load(typ.Elem()); ok {
+ return ptrEncoderFunc(typ)
+ }
+ }
+
if typ.Implements(customEncoderType) {
return encodeCustomValue
}
@@ -58,8 +66,9 @@ func _getEncoder(typ reflect.Type) encoderFunc {
if typ.Implements(binaryMarshalerType) {
return marshalBinaryValue
}
-
- kind := typ.Kind()
+ if typ.Implements(textMarshalerType) {
+ return marshalTextValue
+ }
// Addressable struct field value.
if kind != reflect.Ptr {
@@ -71,7 +80,10 @@ func _getEncoder(typ reflect.Type) encoderFunc {
return marshalValuePtr
}
if ptr.Implements(binaryMarshalerType) {
- return marshalBinaryValuePtr
+ return marshalBinaryValueAddr
+ }
+ if ptr.Implements(textMarshalerType) {
+ return marshalTextValueAddr
}
}
@@ -127,7 +139,7 @@ func encodeCustomValuePtr(e *Encoder, v reflect.Value) error {
}
func encodeCustomValue(e *Encoder, v reflect.Value) error {
- if nilable(v) && v.IsNil() {
+ if nilable(v.Kind()) && v.IsNil() {
return e.EncodeNil()
}
@@ -143,7 +155,7 @@ func marshalValuePtr(e *Encoder, v reflect.Value) error {
}
func marshalValue(e *Encoder, v reflect.Value) error {
- if nilable(v) && v.IsNil() {
+ if nilable(v.Kind()) && v.IsNil() {
return e.EncodeNil()
}
@@ -178,8 +190,8 @@ func encodeUnsupportedValue(e *Encoder, v reflect.Value) error {
return fmt.Errorf("msgpack: Encode(unsupported %s)", v.Type())
}
-func nilable(v reflect.Value) bool {
- switch v.Kind() {
+func nilable(kind reflect.Kind) bool {
+ switch kind {
case reflect.Chan, reflect.Func, reflect.Interface, reflect.Map, reflect.Ptr, reflect.Slice:
return true
}
@@ -188,7 +200,7 @@ func nilable(v reflect.Value) bool {
//------------------------------------------------------------------------------
-func marshalBinaryValuePtr(e *Encoder, v reflect.Value) error {
+func marshalBinaryValueAddr(e *Encoder, v reflect.Value) error {
if !v.CanAddr() {
return fmt.Errorf("msgpack: Encode(non-addressable %T)", v.Interface())
}
@@ -196,7 +208,7 @@ func marshalBinaryValuePtr(e *Encoder, v reflect.Value) error {
}
func marshalBinaryValue(e *Encoder, v reflect.Value) error {
- if nilable(v) && v.IsNil() {
+ if nilable(v.Kind()) && v.IsNil() {
return e.EncodeNil()
}
@@ -208,3 +220,26 @@ func marshalBinaryValue(e *Encoder, v reflect.Value) error {
return e.EncodeBytes(data)
}
+
+//------------------------------------------------------------------------------
+
+func marshalTextValueAddr(e *Encoder, v reflect.Value) error {
+ if !v.CanAddr() {
+ return fmt.Errorf("msgpack: Encode(non-addressable %T)", v.Interface())
+ }
+ return marshalTextValue(e, v.Addr())
+}
+
+func marshalTextValue(e *Encoder, v reflect.Value) error {
+ if nilable(v.Kind()) && v.IsNil() {
+ return e.EncodeNil()
+ }
+
+ marshaler := v.Interface().(encoding.TextMarshaler)
+ data, err := marshaler.MarshalText()
+ if err != nil {
+ return err
+ }
+
+ return e.EncodeBytes(data)
+}
diff --git a/example_CustomEncoder_test.go b/example_CustomEncoder_test.go
index 5a3ad85..77458cb 100644
--- a/example_CustomEncoder_test.go
+++ b/example_CustomEncoder_test.go
@@ -3,7 +3,7 @@ package msgpack_test
import (
"fmt"
- "github.com/vmihailenco/msgpack/v4"
+ "github.com/vmihailenco/msgpack/v5"
)
type customStruct struct {
diff --git a/example_registerExt_test.go b/example_registerExt_test.go
index 8bc842d..ae3b50e 100644
--- a/example_registerExt_test.go
+++ b/example_registerExt_test.go
@@ -5,20 +5,24 @@ import (
"fmt"
"time"
- "github.com/vmihailenco/msgpack/v4"
+ "github.com/vmihailenco/msgpack/v5"
)
-func init() {
- msgpack.RegisterExt(1, (*EventTime)(nil))
-}
-
// https://github.com/fluent/fluentd/wiki/Forward-Protocol-Specification-v1#eventtime-ext-format
type EventTime struct {
time.Time
}
-var _ msgpack.Marshaler = (*EventTime)(nil)
-var _ msgpack.Unmarshaler = (*EventTime)(nil)
+type OneMoreSecondEventTime struct {
+ EventTime
+}
+
+var (
+ _ msgpack.Marshaler = (*EventTime)(nil)
+ _ msgpack.Unmarshaler = (*EventTime)(nil)
+ _ msgpack.Marshaler = (*OneMoreSecondEventTime)(nil)
+ _ msgpack.Unmarshaler = (*OneMoreSecondEventTime)(nil)
+)
func (tm *EventTime) MarshalMsgpack() ([]byte, error) {
b := make([]byte, 8)
@@ -37,26 +41,143 @@ func (tm *EventTime) UnmarshalMsgpack(b []byte) error {
return nil
}
+func (tm *OneMoreSecondEventTime) MarshalMsgpack() ([]byte, error) {
+ b := make([]byte, 8)
+ binary.BigEndian.PutUint32(b, uint32(tm.Unix()+1))
+ binary.BigEndian.PutUint32(b[4:], uint32(tm.Nanosecond()))
+ return b, nil
+}
+
+func (tm *OneMoreSecondEventTime) UnmarshalMsgpack(b []byte) error {
+ if len(b) != 8 {
+ return fmt.Errorf("invalid data length: got %d, wanted 8", len(b))
+ }
+ sec := binary.BigEndian.Uint32(b)
+ usec := binary.BigEndian.Uint32(b[4:])
+ tm.Time = time.Unix(int64(sec+1), int64(usec))
+ return nil
+}
+
func ExampleRegisterExt() {
- b, err := msgpack.Marshal(&EventTime{time.Unix(123456789, 123)})
- if err != nil {
- panic(err)
+ t := time.Unix(123456789, 123)
+
+ {
+ msgpack.RegisterExt(1, (*EventTime)(nil))
+ b, err := msgpack.Marshal(&EventTime{t})
+ if err != nil {
+ panic(err)
+ }
+
+ var v interface{}
+ err = msgpack.Unmarshal(b, &v)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(v.(*EventTime).UTC())
+
+ tm := new(EventTime)
+ err = msgpack.Unmarshal(b, &tm)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(tm.UTC())
}
- var v interface{}
- err = msgpack.Unmarshal(b, &v)
- if err != nil {
- panic(err)
+ {
+ msgpack.RegisterExt(1, (*EventTime)(nil))
+ b, err := msgpack.Marshal(&EventTime{t})
+ if err != nil {
+ panic(err)
+ }
+
+ // override ext
+ msgpack.RegisterExt(1, (*OneMoreSecondEventTime)(nil))
+
+ var v interface{}
+ err = msgpack.Unmarshal(b, &v)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(v.(*OneMoreSecondEventTime).UTC())
}
- fmt.Println(v.(*EventTime).UTC())
- tm := new(EventTime)
- err = msgpack.Unmarshal(b, tm)
- if err != nil {
- panic(err)
+ {
+ msgpack.RegisterExt(1, (*OneMoreSecondEventTime)(nil))
+ b, err := msgpack.Marshal(&OneMoreSecondEventTime{
+ EventTime{t},
+ })
+ if err != nil {
+ panic(err)
+ }
+
+ // override ext
+ msgpack.RegisterExt(1, (*EventTime)(nil))
+ var v interface{}
+ err = msgpack.Unmarshal(b, &v)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(v.(*EventTime).UTC())
}
- fmt.Println(tm.UTC())
// Output: 1973-11-29 21:33:09.000000123 +0000 UTC
// 1973-11-29 21:33:09.000000123 +0000 UTC
+ // 1973-11-29 21:33:10.000000123 +0000 UTC
+ // 1973-11-29 21:33:10.000000123 +0000 UTC
+}
+
+func ExampleUnregisterExt() {
+ t := time.Unix(123456789, 123)
+
+ {
+ msgpack.RegisterExt(1, (*EventTime)(nil))
+ b, err := msgpack.Marshal(&EventTime{t})
+ if err != nil {
+ panic(err)
+ }
+
+ msgpack.UnregisterExt(1)
+
+ var v interface{}
+ err = msgpack.Unmarshal(b, &v)
+ wanted := "msgpack: unknown ext id=1"
+ if err.Error() != wanted {
+ panic(err)
+ }
+
+ msgpack.RegisterExt(1, (*OneMoreSecondEventTime)(nil))
+ err = msgpack.Unmarshal(b, &v)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(v.(*OneMoreSecondEventTime).UTC())
+ }
+
+ {
+ msgpack.RegisterExt(1, (*OneMoreSecondEventTime)(nil))
+ b, err := msgpack.Marshal(&OneMoreSecondEventTime{
+ EventTime{t},
+ })
+ if err != nil {
+ panic(err)
+ }
+
+ msgpack.UnregisterExt(1)
+ var v interface{}
+ err = msgpack.Unmarshal(b, &v)
+ wanted := "msgpack: unknown ext id=1"
+ if err.Error() != wanted {
+ panic(err)
+ }
+
+ msgpack.RegisterExt(1, (*EventTime)(nil))
+ err = msgpack.Unmarshal(b, &v)
+ if err != nil {
+ panic(err)
+ }
+ fmt.Println(v.(*EventTime).UTC())
+ }
+
+ // Output: 1973-11-29 21:33:10.000000123 +0000 UTC
+ // 1973-11-29 21:33:10.000000123 +0000 UTC
}
diff --git a/example_test.go b/example_test.go
index e75f184..1408a59 100644
--- a/example_test.go
+++ b/example_test.go
@@ -4,7 +4,7 @@ import (
"bytes"
"fmt"
- "github.com/vmihailenco/msgpack/v4"
+ "github.com/vmihailenco/msgpack/v5"
)
func ExampleMarshal() {
@@ -47,7 +47,7 @@ func ExampleMarshal_mapStringInterface() {
// hello = world
}
-func ExampleDecoder_SetDecodeMapFunc() {
+func ExampleDecoder_SetMapDecoder() {
buf := new(bytes.Buffer)
enc := msgpack.NewEncoder(buf)
@@ -60,7 +60,7 @@ func ExampleDecoder_SetDecodeMapFunc() {
dec := msgpack.NewDecoder(buf)
// Causes decoder to produce map[string]string instead of map[string]interface{}.
- dec.SetDecodeMapFunc(func(d *msgpack.Decoder) (interface{}, error) {
+ dec.SetMapDecoder(func(d *msgpack.Decoder) (interface{}, error) {
n, err := d.DecodeMapLen()
if err != nil {
return nil, err
@@ -117,14 +117,16 @@ func ExampleDecoder_Query() {
// 2nd phone is 54321
}
-func ExampleEncoder_UseArrayForStructs() {
+func ExampleEncoder_UseArrayEncodedStructs() {
type Item struct {
Foo string
Bar string
}
var buf bytes.Buffer
- enc := msgpack.NewEncoder(&buf).StructAsArray(true)
+ enc := msgpack.NewEncoder(&buf)
+ enc.UseArrayEncodedStructs(true)
+
err := enc.Encode(&Item{Foo: "foo", Bar: "bar"})
if err != nil {
panic(err)
@@ -141,7 +143,7 @@ func ExampleEncoder_UseArrayForStructs() {
func ExampleMarshal_asArray() {
type Item struct {
- _msgpack struct{} `msgpack:",asArray"`
+ _msgpack struct{} `msgpack:",as_array"`
Foo string
Bar string
}
@@ -195,3 +197,25 @@ func ExampleMarshal_omitEmpty() {
// Output: item: "\x82\xa3Foo\xa5hello\xa3Bar\xa0"
// item2: "\x81\xa3Foo\xa5hello"
}
+
+func ExampleMarshal_escapedNames() {
+ og := map[string]interface{}{
+ "something:special": uint(123),
+ "hello, world": "hello!",
+ }
+ raw, err := msgpack.Marshal(og)
+ if err != nil {
+ panic(err)
+ }
+
+ type Item struct {
+ SomethingSpecial uint `msgpack:"'something:special'"`
+ HelloWorld string `msgpack:"'hello, world'"`
+ }
+ var item Item
+ if err := msgpack.Unmarshal(raw, &item); err != nil {
+ panic(err)
+ }
+ fmt.Printf("%#v\n", item)
+ //output: msgpack_test.Item{SomethingSpecial:0x7b, HelloWorld:"hello!"}
+}
diff --git a/ext.go b/ext.go
index 9b48f44..76e1160 100644
--- a/ext.go
+++ b/ext.go
@@ -1,191 +1,250 @@
package msgpack
import (
- "bytes"
"fmt"
+ "math"
"reflect"
- "sync"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
type extInfo struct {
Type reflect.Type
- Decoder decoderFunc
+ Decoder func(d *Decoder, v reflect.Value, extLen int) error
}
-var extTypes = make(map[int8]extInfo)
+var extTypes = make(map[int8]*extInfo)
-var bufferPool = &sync.Pool{
- New: func() interface{} {
- return new(bytes.Buffer)
- },
+type MarshalerUnmarshaler interface {
+ Marshaler
+ Unmarshaler
}
-// RegisterExt records a type, identified by a value for that type,
-// under the provided id. That id will identify the concrete type of a value
-// sent or received as an interface variable. Only types that will be
-// transferred as implementations of interface values need to be registered.
-// Expecting to be used only during initialization, it panics if the mapping
-// between types and ids is not a bijection.
-func RegisterExt(id int8, value interface{}) {
- typ := reflect.TypeOf(value)
- if typ.Kind() == reflect.Ptr {
- typ = typ.Elem()
- }
- ptr := reflect.PtrTo(typ)
-
- if _, ok := extTypes[id]; ok {
- panic(fmt.Errorf("msgpack: ext with id=%d is already registered", id))
- }
+func RegisterExt(extID int8, value MarshalerUnmarshaler) {
+ RegisterExtEncoder(extID, value, func(e *Encoder, v reflect.Value) ([]byte, error) {
+ marshaler := v.Interface().(Marshaler)
+ return marshaler.MarshalMsgpack()
+ })
+ RegisterExtDecoder(extID, value, func(d *Decoder, v reflect.Value, extLen int) error {
+ b, err := d.readN(extLen)
+ if err != nil {
+ return err
+ }
+ return v.Interface().(Unmarshaler).UnmarshalMsgpack(b)
+ })
+}
- registerExt(id, ptr, getEncoder(ptr), getDecoder(ptr))
- registerExt(id, typ, getEncoder(typ), getDecoder(typ))
+func UnregisterExt(extID int8) {
+ unregisterExtEncoder(extID)
+ unregisterExtDecoder(extID)
}
-func registerExt(id int8, typ reflect.Type, enc encoderFunc, dec decoderFunc) {
- if enc != nil {
- typeEncMap.Store(typ, makeExtEncoder(id, enc))
- }
- if dec != nil {
- extTypes[id] = extInfo{
- Type: typ,
- Decoder: dec,
- }
- typeDecMap.Store(typ, makeExtDecoder(id, dec))
+func RegisterExtEncoder(
+ extID int8,
+ value interface{},
+ encoder func(enc *Encoder, v reflect.Value) ([]byte, error),
+) {
+ unregisterExtEncoder(extID)
+
+ typ := reflect.TypeOf(value)
+ extEncoder := makeExtEncoder(extID, typ, encoder)
+ typeEncMap.Store(extID, typ)
+ typeEncMap.Store(typ, extEncoder)
+ if typ.Kind() == reflect.Ptr {
+ typeEncMap.Store(typ.Elem(), makeExtEncoderAddr(extEncoder))
}
}
-func (e *Encoder) EncodeExtHeader(typeID int8, length int) error {
- if err := e.encodeExtLen(length); err != nil {
- return err
+func unregisterExtEncoder(extID int8) {
+ t, ok := typeEncMap.Load(extID)
+ if !ok {
+ return
}
- if err := e.w.WriteByte(byte(typeID)); err != nil {
- return err
+ typeEncMap.Delete(extID)
+ typ := t.(reflect.Type)
+ typeEncMap.Delete(typ)
+ if typ.Kind() == reflect.Ptr {
+ typeEncMap.Delete(typ.Elem())
}
- return nil
}
-func makeExtEncoder(typeID int8, enc encoderFunc) encoderFunc {
- return func(e *Encoder, v reflect.Value) error {
- buf := bufferPool.Get().(*bytes.Buffer)
- defer bufferPool.Put(buf)
- buf.Reset()
+func makeExtEncoder(
+ extID int8,
+ typ reflect.Type,
+ encoder func(enc *Encoder, v reflect.Value) ([]byte, error),
+) encoderFunc {
+ nilable := typ.Kind() == reflect.Ptr
- oldw := e.w
- e.w = buf
- err := enc(e, v)
- e.w = oldw
+ return func(e *Encoder, v reflect.Value) error {
+ if nilable && v.IsNil() {
+ return e.EncodeNil()
+ }
+ b, err := encoder(e, v)
if err != nil {
return err
}
- err = e.EncodeExtHeader(typeID, buf.Len())
- if err != nil {
+ if err := e.EncodeExtHeader(extID, len(b)); err != nil {
return err
}
- return e.write(buf.Bytes())
+
+ return e.write(b)
}
}
-func makeExtDecoder(typeID int8, dec decoderFunc) decoderFunc {
- return func(d *Decoder, v reflect.Value) error {
- c, err := d.PeekCode()
- if err != nil {
- return err
+func makeExtEncoderAddr(extEncoder encoderFunc) encoderFunc {
+ return func(e *Encoder, v reflect.Value) error {
+ if !v.CanAddr() {
+ return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}
+ return extEncoder(e, v.Addr())
+ }
+}
- if !codes.IsExt(c) {
- return dec(d, v)
- }
+func RegisterExtDecoder(
+ extID int8,
+ value interface{},
+ decoder func(dec *Decoder, v reflect.Value, extLen int) error,
+) {
+ unregisterExtDecoder(extID)
+
+ typ := reflect.TypeOf(value)
+ extDecoder := makeExtDecoder(extID, typ, decoder)
+ extTypes[extID] = &extInfo{
+ Type: typ,
+ Decoder: decoder,
+ }
+
+ typeDecMap.Store(extID, typ)
+ typeDecMap.Store(typ, extDecoder)
+ if typ.Kind() == reflect.Ptr {
+ typeDecMap.Store(typ.Elem(), makeExtDecoderAddr(extDecoder))
+ }
+}
- id, extLen, err := d.DecodeExtHeader()
+func unregisterExtDecoder(extID int8) {
+ t, ok := typeDecMap.Load(extID)
+ if !ok {
+ return
+ }
+ typeDecMap.Delete(extID)
+ delete(extTypes, extID)
+ typ := t.(reflect.Type)
+ typeDecMap.Delete(typ)
+ if typ.Kind() == reflect.Ptr {
+ typeDecMap.Delete(typ.Elem())
+ }
+}
+
+func makeExtDecoder(
+ wantedExtID int8,
+ typ reflect.Type,
+ decoder func(d *Decoder, v reflect.Value, extLen int) error,
+) decoderFunc {
+ return nilAwareDecoder(typ, func(d *Decoder, v reflect.Value) error {
+ extID, extLen, err := d.DecodeExtHeader()
if err != nil {
return err
}
+ if extID != wantedExtID {
+ return fmt.Errorf("msgpack: got ext type=%d, wanted %d", extID, wantedExtID)
+ }
+ return decoder(d, v, extLen)
+ })
+}
- if id != typeID {
- return fmt.Errorf("msgpack: got ext type=%d, wanted %d", id, typeID)
+func makeExtDecoderAddr(extDecoder decoderFunc) decoderFunc {
+ return func(d *Decoder, v reflect.Value) error {
+ if !v.CanAddr() {
+ return fmt.Errorf("msgpack: Decode(nonaddressable %T)", v.Interface())
}
+ return extDecoder(d, v.Addr())
+ }
+}
- d.extLen = extLen
- return dec(d, v)
+func (e *Encoder) EncodeExtHeader(extID int8, extLen int) error {
+ if err := e.encodeExtLen(extLen); err != nil {
+ return err
+ }
+ if err := e.w.WriteByte(byte(extID)); err != nil {
+ return err
}
+ return nil
}
func (e *Encoder) encodeExtLen(l int) error {
switch l {
case 1:
- return e.writeCode(codes.FixExt1)
+ return e.writeCode(msgpcode.FixExt1)
case 2:
- return e.writeCode(codes.FixExt2)
+ return e.writeCode(msgpcode.FixExt2)
case 4:
- return e.writeCode(codes.FixExt4)
+ return e.writeCode(msgpcode.FixExt4)
case 8:
- return e.writeCode(codes.FixExt8)
+ return e.writeCode(msgpcode.FixExt8)
case 16:
- return e.writeCode(codes.FixExt16)
+ return e.writeCode(msgpcode.FixExt16)
}
- if l < 256 {
- return e.write1(codes.Ext8, uint8(l))
+ if l <= math.MaxUint8 {
+ return e.write1(msgpcode.Ext8, uint8(l))
}
- if l < 65536 {
- return e.write2(codes.Ext16, uint16(l))
+ if l <= math.MaxUint16 {
+ return e.write2(msgpcode.Ext16, uint16(l))
}
- return e.write4(codes.Ext32, uint32(l))
+ return e.write4(msgpcode.Ext32, uint32(l))
}
-func (d *Decoder) parseExtLen(c codes.Code) (int, error) {
- switch c {
- case codes.FixExt1:
- return 1, nil
- case codes.FixExt2:
- return 2, nil
- case codes.FixExt4:
- return 4, nil
- case codes.FixExt8:
- return 8, nil
- case codes.FixExt16:
- return 16, nil
- case codes.Ext8:
- n, err := d.uint8()
- return int(n), err
- case codes.Ext16:
- n, err := d.uint16()
- return int(n), err
- case codes.Ext32:
- n, err := d.uint32()
- return int(n), err
- default:
- return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext length", c)
+func (d *Decoder) DecodeExtHeader() (extID int8, extLen int, err error) {
+ c, err := d.readCode()
+ if err != nil {
+ return
}
+ return d.extHeader(c)
}
-func (d *Decoder) decodeExtHeader(c codes.Code) (int8, int, error) {
- length, err := d.parseExtLen(c)
+func (d *Decoder) extHeader(c byte) (int8, int, error) {
+ extLen, err := d.parseExtLen(c)
if err != nil {
return 0, 0, err
}
- typeID, err := d.readCode()
+ extID, err := d.readCode()
if err != nil {
return 0, 0, err
}
- return int8(typeID), length, nil
+ return int8(extID), extLen, nil
}
-func (d *Decoder) DecodeExtHeader() (typeID int8, length int, err error) {
- c, err := d.readCode()
- if err != nil {
- return
+func (d *Decoder) parseExtLen(c byte) (int, error) {
+ switch c {
+ case msgpcode.FixExt1:
+ return 1, nil
+ case msgpcode.FixExt2:
+ return 2, nil
+ case msgpcode.FixExt4:
+ return 4, nil
+ case msgpcode.FixExt8:
+ return 8, nil
+ case msgpcode.FixExt16:
+ return 16, nil
+ case msgpcode.Ext8:
+ n, err := d.uint8()
+ return int(n), err
+ case msgpcode.Ext16:
+ n, err := d.uint16()
+ return int(n), err
+ case msgpcode.Ext32:
+ n, err := d.uint32()
+ return int(n), err
+ default:
+ return 0, fmt.Errorf("msgpack: invalid code=%x decoding ext len", c)
}
- return d.decodeExtHeader(c)
}
-func (d *Decoder) extInterface(c codes.Code) (interface{}, error) {
- extID, extLen, err := d.decodeExtHeader(c)
+func (d *Decoder) decodeInterfaceExt(c byte) (interface{}, error) {
+ extID, extLen, err := d.extHeader(c)
if err != nil {
return nil, err
}
@@ -195,19 +254,19 @@ func (d *Decoder) extInterface(c codes.Code) (interface{}, error) {
return nil, fmt.Errorf("msgpack: unknown ext id=%d", extID)
}
- v := reflect.New(info.Type)
+ v := reflect.New(info.Type).Elem()
+ if nilable(v.Kind()) && v.IsNil() {
+ v.Set(reflect.New(info.Type.Elem()))
+ }
- d.extLen = extLen
- err = info.Decoder(d, v.Elem())
- d.extLen = 0
- if err != nil {
+ if err := info.Decoder(d, v, extLen); err != nil {
return nil, err
}
return v.Interface(), nil
}
-func (d *Decoder) skipExt(c codes.Code) error {
+func (d *Decoder) skipExt(c byte) error {
n, err := d.parseExtLen(c)
if err != nil {
return err
@@ -215,7 +274,7 @@ func (d *Decoder) skipExt(c codes.Code) error {
return d.skipN(n + 1)
}
-func (d *Decoder) skipExtHeader(c codes.Code) error {
+func (d *Decoder) skipExtHeader(c byte) error {
// Read ext type.
_, err := d.readCode()
if err != nil {
@@ -231,13 +290,13 @@ func (d *Decoder) skipExtHeader(c codes.Code) error {
return nil
}
-func extHeaderLen(c codes.Code) int {
+func extHeaderLen(c byte) int {
switch c {
- case codes.Ext8:
+ case msgpcode.Ext8:
return 1
- case codes.Ext16:
+ case msgpcode.Ext16:
return 2
- case codes.Ext32:
+ case msgpcode.Ext32:
return 4
}
return 0
diff --git a/ext_test.go b/ext_test.go
index a2fa053..2c23bb1 100644
--- a/ext_test.go
+++ b/ext_test.go
@@ -3,77 +3,82 @@ package msgpack_test
import (
"bytes"
"encoding/hex"
- "reflect"
"testing"
"time"
- "github.com/vmihailenco/msgpack/v4"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/stretchr/testify/require"
+ "github.com/vmihailenco/msgpack/v5"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
func init() {
msgpack.RegisterExt(9, (*ExtTest)(nil))
}
-func TestRegisterExtPanic(t *testing.T) {
- defer func() {
- r := recover()
- if r == nil {
- t.Fatalf("panic expected")
- }
- got := r.(error).Error()
- wanted := "msgpack: ext with id=9 is already registered"
- if got != wanted {
- t.Fatalf("got %q, wanted %q", got, wanted)
- }
- }()
- msgpack.RegisterExt(9, (*ExtTest)(nil))
-}
-
type ExtTest struct {
S string
}
-var _ msgpack.CustomEncoder = (*ExtTest)(nil)
-var _ msgpack.CustomDecoder = (*ExtTest)(nil)
+var (
+ _ msgpack.Marshaler = (*ExtTest)(nil)
+ _ msgpack.Unmarshaler = (*ExtTest)(nil)
+)
-func (ext ExtTest) EncodeMsgpack(e *msgpack.Encoder) error {
- return e.EncodeString("hello " + ext.S)
+func (ext ExtTest) MarshalMsgpack() ([]byte, error) {
+ return msgpack.Marshal("hello " + ext.S)
}
-func (ext *ExtTest) DecodeMsgpack(d *msgpack.Decoder) error {
- var err error
- ext.S, err = d.DecodeString()
- return err
+func (ext *ExtTest) UnmarshalMsgpack(b []byte) error {
+ return msgpack.Unmarshal(b, &ext.S)
}
func TestEncodeDecodeExtHeader(t *testing.T) {
v := &ExtTest{"world"}
- // Marshal using EncodeExtHeader
- var b bytes.Buffer
- enc := msgpack.NewEncoder(&b)
- err := v.EncodeMsgpack(enc)
- if err != nil {
- t.Fatal(err)
- }
+ payload, err := v.MarshalMsgpack()
+ require.Nil(t, err)
- payload := make([]byte, len(b.Bytes()))
- copy(payload, b.Bytes())
-
- b.Reset()
- enc = msgpack.NewEncoder(&b)
+ var buf bytes.Buffer
+ enc := msgpack.NewEncoder(&buf)
err = enc.EncodeExtHeader(9, len(payload))
+ require.Nil(t, err)
+
+ _, err = buf.Write(payload)
+ require.Nil(t, err)
+
+ var dst interface{}
+ err = msgpack.Unmarshal(buf.Bytes(), &dst)
+ require.Nil(t, err)
+
+ v = dst.(*ExtTest)
+ wanted := "hello world"
+ require.Equal(t, v.S, wanted)
+
+ dec := msgpack.NewDecoder(&buf)
+ extID, extLen, err := dec.DecodeExtHeader()
+ require.Nil(t, err)
+ require.Equal(t, int8(9), extID)
+ require.Equal(t, len(payload), extLen)
+
+ data := make([]byte, extLen)
+ err = dec.ReadFull(data)
+ require.Nil(t, err)
+
+ v = &ExtTest{}
+ err = v.UnmarshalMsgpack(data)
+ require.Nil(t, err)
+ require.Equal(t, wanted, v.S)
+}
+
+func TestExt(t *testing.T) {
+ v := &ExtTest{"world"}
+ b, err := msgpack.Marshal(v)
if err != nil {
t.Fatal(err)
}
- if _, err := b.Write(payload); err != nil {
- t.Fatal(err)
- }
- // Unmarshal using generic function
var dst interface{}
- err = msgpack.Unmarshal(b.Bytes(), &dst)
+ err = msgpack.Unmarshal(b, &dst)
if err != nil {
t.Fatal(err)
}
@@ -88,67 +93,18 @@ func TestEncodeDecodeExtHeader(t *testing.T) {
t.Fatalf("got %q, wanted %q", v.S, wanted)
}
- // Unmarshal using DecodeExtHeader
- d := msgpack.NewDecoder(&b)
- typeId, length, err := d.DecodeExtHeader()
- if err != nil {
- t.Fatal(err)
- }
-
- if typeId != 9 {
- t.Fatalf("got %d, wanted 9", 9)
- }
- if length != len(payload) {
- t.Fatalf("got %d, wanted %d", length, len(payload))
- }
-
- v = &ExtTest{}
- err = v.DecodeMsgpack(d)
+ ext := new(ExtTest)
+ err = msgpack.Unmarshal(b, &ext)
if err != nil {
t.Fatal(err)
}
-
- if v.S != wanted {
- t.Fatalf("got %q, wanted %q", v.S, wanted)
- }
-}
-
-func TestExt(t *testing.T) {
- for _, v := range []interface{}{ExtTest{"world"}, &ExtTest{"world"}} {
- b, err := msgpack.Marshal(v)
- if err != nil {
- t.Fatal(err)
- }
-
- var dst interface{}
- err = msgpack.Unmarshal(b, &dst)
- if err != nil {
- t.Fatal(err)
- }
-
- v, ok := dst.(*ExtTest)
- if !ok {
- t.Fatalf("got %#v, wanted ExtTest", dst)
- }
-
- wanted := "hello world"
- if v.S != wanted {
- t.Fatalf("got %q, wanted %q", v.S, wanted)
- }
-
- ext := new(ExtTest)
- err = msgpack.Unmarshal(b, ext)
- if err != nil {
- t.Fatal(err)
- }
- if ext.S != wanted {
- t.Fatalf("got %q, wanted %q", ext.S, wanted)
- }
+ if ext.S != wanted {
+ t.Fatalf("got %q, wanted %q", ext.S, wanted)
}
}
func TestUnknownExt(t *testing.T) {
- b := []byte{byte(codes.FixExt1), 2, 0}
+ b := []byte{byte(msgpcode.FixExt1), 2, 0}
var dst interface{}
err := msgpack.Unmarshal(b, &dst)
@@ -162,28 +118,6 @@ func TestUnknownExt(t *testing.T) {
}
}
-func TestDecodeExtWithMap(t *testing.T) {
- type S struct {
- I int
- }
- msgpack.RegisterExt(2, S{})
-
- b, err := msgpack.Marshal(&S{I: 42})
- if err != nil {
- t.Fatal(err)
- }
-
- var got map[string]interface{}
- if err := msgpack.Unmarshal(b, &got); err != nil {
- t.Fatal(err)
- }
-
- wanted := map[string]interface{}{"I": int64(42)}
- if !reflect.DeepEqual(got, wanted) {
- t.Fatalf("got %#v, but wanted %#v", got, wanted)
- }
-}
-
func TestSliceOfTime(t *testing.T) {
in := []interface{}{time.Now()}
b, err := msgpack.Marshal(in)
@@ -197,7 +131,7 @@ func TestSliceOfTime(t *testing.T) {
t.Fatal(err)
}
- outTime := *out[0].(*time.Time)
+ outTime := out[0].(time.Time)
inTime := in[0].(time.Time)
if outTime.Unix() != inTime.Unix() {
t.Fatalf("got %v, wanted %v", outTime, inTime)
@@ -208,6 +142,10 @@ type customPayload struct {
payload []byte
}
+func (cp *customPayload) MarshalMsgpack() ([]byte, error) {
+ return cp.payload, nil
+}
+
func (cp *customPayload) UnmarshalMsgpack(b []byte) error {
cp.payload = b
return nil
diff --git a/appengine.go b/extra/msgpappengine/appengine.go
similarity index 52%
rename from appengine.go
rename to extra/msgpappengine/appengine.go
index e8e91e5..8d93057 100644
--- a/appengine.go
+++ b/extra/msgpappengine/appengine.go
@@ -1,31 +1,30 @@
-// +build appengine
-
-package msgpack
+package msgpappengine
import (
"reflect"
+ "github.com/vmihailenco/msgpack/v5"
ds "google.golang.org/appengine/datastore"
)
func init() {
- Register((*ds.Key)(nil), encodeDatastoreKeyValue, decodeDatastoreKeyValue)
- Register((*ds.Cursor)(nil), encodeDatastoreCursorValue, decodeDatastoreCursorValue)
+ msgpack.Register((*ds.Key)(nil), encodeDatastoreKeyValue, decodeDatastoreKeyValue)
+ msgpack.Register((*ds.Cursor)(nil), encodeDatastoreCursorValue, decodeDatastoreCursorValue)
}
-func EncodeDatastoreKey(e *Encoder, key *ds.Key) error {
+func EncodeDatastoreKey(e *msgpack.Encoder, key *ds.Key) error {
if key == nil {
return e.EncodeNil()
}
return e.EncodeString(key.Encode())
}
-func encodeDatastoreKeyValue(e *Encoder, v reflect.Value) error {
+func encodeDatastoreKeyValue(e *msgpack.Encoder, v reflect.Value) error {
key := v.Interface().(*ds.Key)
return EncodeDatastoreKey(e, key)
}
-func DecodeDatastoreKey(d *Decoder) (*ds.Key, error) {
+func DecodeDatastoreKey(d *msgpack.Decoder) (*ds.Key, error) {
v, err := d.DecodeString()
if err != nil {
return nil, err
@@ -36,7 +35,7 @@ func DecodeDatastoreKey(d *Decoder) (*ds.Key, error) {
return ds.DecodeKey(v)
}
-func decodeDatastoreKeyValue(d *Decoder, v reflect.Value) error {
+func decodeDatastoreKeyValue(d *msgpack.Decoder, v reflect.Value) error {
key, err := DecodeDatastoreKey(d)
if err != nil {
return err
@@ -45,12 +44,12 @@ func decodeDatastoreKeyValue(d *Decoder, v reflect.Value) error {
return nil
}
-func encodeDatastoreCursorValue(e *Encoder, v reflect.Value) error {
+func encodeDatastoreCursorValue(e *msgpack.Encoder, v reflect.Value) error {
cursor := v.Interface().(ds.Cursor)
return e.Encode(cursor.String())
}
-func decodeDatastoreCursorValue(d *Decoder, v reflect.Value) error {
+func decodeDatastoreCursorValue(d *msgpack.Decoder, v reflect.Value) error {
s, err := d.DecodeString()
if err != nil {
return err
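The datastore encoders are still registered from an `init` function with `msgpack.Register((*ds.Key)(nil), ...)`. The file now lives in its own package (and module, per the new `go.mod`) instead of behind the `appengine` build tag. As a result, that registration only happens in programs that import the package.

The sketch below shows the general idea of a registry keyed by `reflect.Type`, where a typed nil pointer is enough to name the type. `registry`, `register`, `encode` and `key` are hypothetical illustrations and are not msgpack's API.

```go
package main

import (
	"fmt"
	"reflect"
)

// key is a hypothetical stand-in for *ds.Key.
type key struct{ name string }

type encodeFunc func(v reflect.Value) (string, error)

// registry maps a concrete type to its encoder (illustrative only).
var registry = map[reflect.Type]encodeFunc{}

// register keys enc by the dynamic type of value; a typed nil pointer such
// as (*key)(nil) registers *key without allocating a key.
func register(value interface{}, enc encodeFunc) {
	registry[reflect.TypeOf(value)] = enc
}

func init() {
	// init runs only if this package is part of the program, so moving
	// registrations into a separate package makes them opt-in.
	register((*key)(nil), func(v reflect.Value) (string, error) {
		if v.IsNil() {
			return "nil", nil
		}
		return "key:" + v.Interface().(*key).name, nil
	})
}

func encode(x interface{}) (string, error) {
	v := reflect.ValueOf(x)
	if enc, ok := registry[v.Type()]; ok {
		return enc(v)
	}
	return "", fmt.Errorf("no encoder registered for %s", v.Type())
}

func main() {
	fmt.Println(encode(&key{name: "users/42"})) // key:users/42 <nil>
	_, err := encode(key{name: "users/42"})
	fmt.Println(err) // no encoder registered for main.key
}
```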
diff --git a/extra/msgpappengine/go.mod b/extra/msgpappengine/go.mod
new file mode 100644
index 0000000..f2b012d
--- /dev/null
+++ b/extra/msgpappengine/go.mod
@@ -0,0 +1,10 @@
+module github.com/vmihailenco/msgpack/extra/appengine
+
+go 1.15
+
+replace github.com/vmihailenco/msgpack/v5 => ../..
+
+require (
+ github.com/vmihailenco/msgpack/v5 v5.3.5
+ google.golang.org/appengine v1.6.7
+)
diff --git a/extra/msgpappengine/go.sum b/extra/msgpappengine/go.sum
new file mode 100644
index 0000000..8356dcf
--- /dev/null
+++ b/extra/msgpappengine/go.sum
@@ -0,0 +1,24 @@
+github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/golang/protobuf v1.3.1 h1:YF8+flBXS5eO826T4nzqPrxfhQThhXl0YzfuUPu4SBg=
+github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
+golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
+golang.org/x/net v0.0.0-20190603091049-60506f45cf65 h1:+rhAzEzT3f4JtomfC371qB+0Ola2caSKcY69NUBZrRQ=
+golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
+golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
+golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
+golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
+golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
+google.golang.org/appengine v1.6.7 h1:FZR1q0exgwxzPzp/aF+VccGrSfxfPpkBqjIIEq3ru6c=
+google.golang.org/appengine v1.6.7/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/go.mod b/go.mod
index 52cf6c2..f630a54 100644
--- a/go.mod
+++ b/go.mod
@@ -1,12 +1,8 @@
-module github.com/vmihailenco/msgpack/v4
+module github.com/vmihailenco/msgpack/v5
+
+go 1.11
require (
- github.com/golang/protobuf v1.3.2 // indirect
- github.com/kr/pretty v0.1.0 // indirect
- github.com/vmihailenco/tagparser v0.1.1
- golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 // indirect
- google.golang.org/appengine v1.6.5
- gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127
+ github.com/stretchr/testify v1.6.1
+ github.com/vmihailenco/tagparser/v2 v2.0.0
)
-
-go 1.11
diff --git a/go.sum b/go.sum
index 85168cf..a2bef4a 100644
--- a/go.sum
+++ b/go.sum
@@ -1,22 +1,13 @@
-github.com/golang/protobuf v1.3.1/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/golang/protobuf v1.3.2 h1:6nsPYzhq5kReh6QImI3k5qWzO4PEbvbIW2cwSfR/6xs=
-github.com/golang/protobuf v1.3.2/go.mod h1:6lQm79b+lXiMfvg/cZm0SGofjICqVBUtrP5yJMmIC1U=
-github.com/kr/pretty v0.1.0 h1:L/CwN0zerZDmRFUapSPitk6f+Q3+0za1rQkzVuMiMFI=
-github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
-github.com/kr/pty v1.1.1/go.mod h1:pFQYn66WHrOpPYNljwOMqo10TkYh1fy3cYio2l3bCsQ=
-github.com/kr/text v0.1.0 h1:45sCR5RtlFHMR4UwH9sdQ5TC8v0qDQCHnXt+kaKSTVE=
-github.com/kr/text v0.1.0/go.mod h1:4Jbv+DJW3UT/LiOwJeYQe1efqtUx/iVham/4vfdArNI=
-github.com/vmihailenco/tagparser v0.1.1 h1:quXMXlA39OCbd2wAdTsGDlK9RkOk6Wuw+x37wVyIuWY=
-github.com/vmihailenco/tagparser v0.1.1/go.mod h1:OeAg3pn3UbLjkWt+rN9oFYB6u/cQgqMEUPoW2WPyhdI=
-golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
-golang.org/x/net v0.0.0-20190603091049-60506f45cf65/go.mod h1:HSz+uSET+XFnRR8LxR5pz3Of3rY3CfYBVs4xY44aLks=
-golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553 h1:efeOvDhwQ29Dj3SdAV/MJf8oukgn+8D8WgaCaRMchF8=
-golang.org/x/net v0.0.0-20191209160850-c0dbc17a3553/go.mod h1:z5CRVTTTmAJ677TzLLGU+0bjPO0LkuOLi4/5GtJWs/s=
-golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
-golang.org/x/text v0.3.0/go.mod h1:NqM8EUOU14njkJ3fqMW+pc6Ldnwhi/IjpwHt7yyuwOQ=
-golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
-golang.org/x/tools v0.0.0-20180917221912-90fa682c2a6e/go.mod h1:n7NCudcB/nEzxVGmLbDWY5pfWTLqBcC2KZ6jyYvM4mQ=
-google.golang.org/appengine v1.6.5 h1:tycE03LOZYQNhDpS27tcQdAzLCVMaj7QT2SXxebnpCM=
-google.golang.org/appengine v1.6.5/go.mod h1:8WjMMxjGQR8xUklV/ARdw2HLXBOI7O7uCIDZVag1xfc=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127 h1:qIbj1fsPNlZgppZ+VLlY7N33q108Sa+fhmuc+sWQYwY=
-gopkg.in/check.v1 v1.0.0-20180628173108-788fd7840127/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+github.com/davecgh/go-spew v1.1.0 h1:ZDRjVQ15GmhC3fiQ8ni8+OwkZQO4DARzQgrnXU1Liz8=
+github.com/davecgh/go-spew v1.1.0/go.mod h1:J7Y8YcW2NihsgmVo/mv3lAwl/skON4iLHjSsI+c5H38=
+github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZbAQM=
+github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
+github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
+github.com/stretchr/testify v1.6.1 h1:hDPOHmpOpP40lSULcqw7IrRb/u7w6RpDC9399XyoNd0=
+github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
+github.com/vmihailenco/tagparser/v2 v2.0.0 h1:y09buUbR+b5aycVFQs/g70pqKVZNBmxwAhO7/IwNM9g=
+github.com/vmihailenco/tagparser/v2 v2.0.0/go.mod h1:Wri+At7QHww0WTrCBeu4J6bNtoV6mEfg5OIWRZA9qds=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405 h1:yhCVgyC4o1eVCa2tZl7eS0r+SDo693bJlVdllGtEeKM=
+gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c h1:dUUwHk2QECo/6vqA44rthZ8ie2QXMNeKRTHCNY2nXvo=
+gopkg.in/yaml.v3 v3.0.0-20200313102051-9f266ea9e77c/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
diff --git a/intern.go b/intern.go
new file mode 100644
index 0000000..be0316a
--- /dev/null
+++ b/intern.go
@@ -0,0 +1,238 @@
+package msgpack
+
+import (
+ "fmt"
+ "math"
+ "reflect"
+
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
+)
+
+const (
+ minInternedStringLen = 3
+ maxDictLen = math.MaxUint16
+)
+
+var internedStringExtID = int8(math.MinInt8)
+
+func init() {
+ extTypes[internedStringExtID] = &extInfo{
+ Type: stringType,
+ Decoder: decodeInternedStringExt,
+ }
+}
+
+func decodeInternedStringExt(d *Decoder, v reflect.Value, extLen int) error {
+ idx, err := d.decodeInternedStringIndex(extLen)
+ if err != nil {
+ return err
+ }
+
+ s, err := d.internedStringAtIndex(idx)
+ if err != nil {
+ return err
+ }
+
+ v.SetString(s)
+ return nil
+}
+
+//------------------------------------------------------------------------------
+
+func encodeInternedInterfaceValue(e *Encoder, v reflect.Value) error {
+ if v.IsNil() {
+ return e.EncodeNil()
+ }
+
+ v = v.Elem()
+ if v.Kind() == reflect.String {
+ return e.encodeInternedString(v.String(), true)
+ }
+ return e.EncodeValue(v)
+}
+
+func encodeInternedStringValue(e *Encoder, v reflect.Value) error {
+ return e.encodeInternedString(v.String(), true)
+}
+
+func (e *Encoder) encodeInternedString(s string, intern bool) error {
+ // Interned string takes at least 3 bytes. Plain string 1 byte + string len.
+ if len(s) >= minInternedStringLen {
+ if idx, ok := e.dict[s]; ok {
+ return e.encodeInternedStringIndex(idx)
+ }
+
+ if intern && len(e.dict) < maxDictLen {
+ if e.dict == nil {
+ e.dict = make(map[string]int)
+ }
+ idx := len(e.dict)
+ e.dict[s] = idx
+ }
+ }
+
+ return e.encodeNormalString(s)
+}
+
+func (e *Encoder) encodeInternedStringIndex(idx int) error {
+ if idx <= math.MaxUint8 {
+ if err := e.writeCode(msgpcode.FixExt1); err != nil {
+ return err
+ }
+ return e.write1(byte(internedStringExtID), uint8(idx))
+ }
+
+ if idx <= math.MaxUint16 {
+ if err := e.writeCode(msgpcode.FixExt2); err != nil {
+ return err
+ }
+ return e.write2(byte(internedStringExtID), uint16(idx))
+ }
+
+ if uint64(idx) <= math.MaxUint32 {
+ if err := e.writeCode(msgpcode.FixExt4); err != nil {
+ return err
+ }
+ return e.write4(byte(internedStringExtID), uint32(idx))
+ }
+
+ return fmt.Errorf("msgpack: interned string index=%d is too large", idx)
+}
+
+//------------------------------------------------------------------------------
+
+func decodeInternedInterfaceValue(d *Decoder, v reflect.Value) error {
+ s, err := d.decodeInternedString(true)
+ if err == nil {
+ v.Set(reflect.ValueOf(s))
+ return nil
+ }
+ if err != nil {
+ if _, ok := err.(unexpectedCodeError); !ok {
+ return err
+ }
+ }
+
+ if err := d.s.UnreadByte(); err != nil {
+ return err
+ }
+ return decodeInterfaceValue(d, v)
+}
+
+func decodeInternedStringValue(d *Decoder, v reflect.Value) error {
+ s, err := d.decodeInternedString(true)
+ if err != nil {
+ return err
+ }
+
+ v.SetString(s)
+ return nil
+}
+
+func (d *Decoder) decodeInternedString(intern bool) (string, error) {
+ c, err := d.readCode()
+ if err != nil {
+ return "", err
+ }
+
+ if msgpcode.IsFixedString(c) {
+ n := int(c & msgpcode.FixedStrMask)
+ return d.decodeInternedStringWithLen(n, intern)
+ }
+
+ switch c {
+ case msgpcode.Nil:
+ return "", nil
+ case msgpcode.FixExt1, msgpcode.FixExt2, msgpcode.FixExt4:
+ typeID, extLen, err := d.extHeader(c)
+ if err != nil {
+ return "", err
+ }
+ if typeID != internedStringExtID {
+ err := fmt.Errorf("msgpack: got ext type=%d, wanted %d",
+ typeID, internedStringExtID)
+ return "", err
+ }
+
+ idx, err := d.decodeInternedStringIndex(extLen)
+ if err != nil {
+ return "", err
+ }
+
+ return d.internedStringAtIndex(idx)
+ case msgpcode.Str8, msgpcode.Bin8:
+ n, err := d.uint8()
+ if err != nil {
+ return "", err
+ }
+ return d.decodeInternedStringWithLen(int(n), intern)
+ case msgpcode.Str16, msgpcode.Bin16:
+ n, err := d.uint16()
+ if err != nil {
+ return "", err
+ }
+ return d.decodeInternedStringWithLen(int(n), intern)
+ case msgpcode.Str32, msgpcode.Bin32:
+ n, err := d.uint32()
+ if err != nil {
+ return "", err
+ }
+ return d.decodeInternedStringWithLen(int(n), intern)
+ }
+
+ return "", unexpectedCodeError{
+ code: c,
+ hint: "interned string",
+ }
+}
+
+func (d *Decoder) decodeInternedStringIndex(extLen int) (int, error) {
+ switch extLen {
+ case 1:
+ n, err := d.uint8()
+ if err != nil {
+ return 0, err
+ }
+ return int(n), nil
+ case 2:
+ n, err := d.uint16()
+ if err != nil {
+ return 0, err
+ }
+ return int(n), nil
+ case 4:
+ n, err := d.uint32()
+ if err != nil {
+ return 0, err
+ }
+ return int(n), nil
+ }
+
+ err := fmt.Errorf("msgpack: unsupported ext len=%d decoding interned string", extLen)
+ return 0, err
+}
+
+func (d *Decoder) internedStringAtIndex(idx int) (string, error) {
+ if idx >= len(d.dict) {
+ err := fmt.Errorf("msgpack: interned string at index=%d does not exist", idx)
+ return "", err
+ }
+ return d.dict[idx], nil
+}
+
+func (d *Decoder) decodeInternedStringWithLen(n int, intern bool) (string, error) {
+ if n <= 0 {
+ return "", nil
+ }
+
+ s, err := d.stringWithLen(n)
+ if err != nil {
+ return "", err
+ }
+
+ if intern && len(s) >= minInternedStringLen && len(d.dict) < maxDictLen {
+ d.dict = append(d.dict, s)
+ }
+
+ return s, nil
+}
diff --git a/intern_test.go b/intern_test.go
new file mode 100644
index 0000000..93c2c10
--- /dev/null
+++ b/intern_test.go
@@ -0,0 +1,146 @@
+package msgpack_test
+
+import (
+ "bytes"
+ "io"
+ "testing"
+
+ "github.com/stretchr/testify/require"
+ "github.com/vmihailenco/msgpack/v5"
+)
+
+type NoIntern struct {
+ A string
+ B string
+ C interface{}
+}
+
+type Intern struct {
+ A string `msgpack:",intern"`
+ B string `msgpack:",intern"`
+ C interface{} `msgpack:",intern"`
+}
+
+func TestInternedString(t *testing.T) {
+ var buf bytes.Buffer
+
+ enc := msgpack.NewEncoder(&buf)
+ enc.UseInternedStrings(true)
+
+ dec := msgpack.NewDecoder(&buf)
+ dec.UseInternedStrings(true)
+
+ for i := 0; i < 3; i++ {
+ err := enc.EncodeString("hello")
+ require.Nil(t, err)
+ }
+
+ for i := 0; i < 3; i++ {
+ s, err := dec.DecodeString()
+ require.Nil(t, err)
+ require.Equal(t, "hello", s)
+ }
+
+ err := enc.Encode("hello")
+ require.Nil(t, err)
+
+ v, err := dec.DecodeInterface()
+ require.Nil(t, err)
+ require.Equal(t, "hello", v)
+
+ _, err = dec.DecodeInterface()
+ require.Equal(t, io.EOF, err)
+}
+
+func TestInternedStringTag(t *testing.T) {
+ var buf bytes.Buffer
+ enc := msgpack.NewEncoder(&buf)
+ dec := msgpack.NewDecoder(&buf)
+
+ in := []Intern{
+ {"f", "f", "f"},
+ {"fo", "fo", "fo"},
+ {"foo", "foo", "foo"},
+ {"f", "fo", "foo"},
+ }
+ err := enc.Encode(in)
+ require.Nil(t, err)
+
+ var out []Intern
+ err = dec.Decode(&out)
+ require.Nil(t, err)
+ require.Equal(t, in, out)
+}
+
+func TestResetDict(t *testing.T) {
+ dict := []string{"hello world", "foo bar"}
+
+ var buf bytes.Buffer
+ enc := msgpack.NewEncoder(&buf)
+ dec := msgpack.NewDecoder(&buf)
+
+ {
+ enc.ResetDict(&buf, dictMap(dict))
+ err := enc.EncodeString("hello world")
+ require.Nil(t, err)
+ require.Equal(t, 3, buf.Len())
+
+ dec.ResetDict(&buf, dict)
+ s, err := dec.DecodeString()
+ require.Nil(t, err)
+ require.Equal(t, "hello world", s)
+ }
+
+ {
+ enc.ResetDict(&buf, dictMap(dict))
+ err := enc.Encode("foo bar")
+ require.Nil(t, err)
+ require.Equal(t, 3, buf.Len())
+
+ dec.ResetDict(&buf, dict)
+ s, err := dec.DecodeInterface()
+ require.Nil(t, err)
+ require.Equal(t, "foo bar", s)
+ }
+
+ dec.ResetDict(&buf, dict)
+ _ = enc.EncodeString("xxxx")
+ require.Equal(t, 5, buf.Len())
+ _ = enc.Encode("xxxx")
+ require.Equal(t, 10, buf.Len())
+}
+
+func TestMapWithInternedString(t *testing.T) {
+ type M map[string]interface{}
+
+ dict := []string{"hello world", "foo bar"}
+
+ var buf bytes.Buffer
+
+ enc := msgpack.NewEncoder(nil)
+ enc.ResetDict(&buf, dictMap(dict))
+
+ dec := msgpack.NewDecoder(nil)
+ dec.ResetDict(&buf, dict)
+
+ for i := 0; i < 100; i++ {
+ in := M{
+ "foo bar": "hello world",
+ "hello world": "foo bar",
+ "foo": "bar",
+ }
+ err := enc.Encode(in)
+ require.Nil(t, err)
+
+ _, err = dec.DecodeInterface()
+ require.Nil(t, err)
+ }
+}
+
+func dictMap(dict []string) map[string]int {
+ m := make(map[string]int, len(dict))
+ for i, s := range dict {
+ m[s] = i
+ }
+ return m
+}
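`TestResetDict` preloads both sides with the same dictionary. The encoder gets the `map[string]int` built by `dictMap` and the decoder gets the `[]string`, so an index written by one side resolves to the same string on the other. That is why `"hello world"` encodes to just 3 bytes, while `"xxxx"`, which is not in the dictionary, takes 5.

The sketch below shows the decoder side. It handles fixstr and fixext1/fixext2 references against a read-only dictionary and bounds-checks the index the way `internedStringAtIndex` does. `decodeString` is a hypothetical helper.

```go
package main

import (
	"errors"
	"fmt"
)

var errShort = errors.New("short buffer")

// decodeString decodes one fixstr or one interned-string reference
// (fixext1/fixext2 with ext type 0x80) and returns the bytes consumed.
func decodeString(b []byte, dict []string) (string, int, error) {
	if len(b) == 0 {
		return "", 0, errShort
	}
	c := b[0]
	switch {
	case c&0xe0 == 0xa0: // fixstr: length in the low 5 bits
		n := int(c & 0x1f)
		if len(b) < 1+n {
			return "", 0, errShort
		}
		return string(b[1 : 1+n]), 1 + n, nil
	case c == 0xd4 || c == 0xd5: // fixext1 / fixext2
		size := 1
		if c == 0xd5 {
			size = 2
		}
		if len(b) < 2+size {
			return "", 0, errShort
		}
		if b[1] != 0x80 {
			return "", 0, fmt.Errorf("got ext type=%d, wanted -128", int8(b[1]))
		}
		idx := 0
		for _, x := range b[2 : 2+size] {
			idx = idx<<8 | int(x) // big-endian index
		}
		if idx >= len(dict) {
			return "", 0, fmt.Errorf("interned string at index=%d does not exist", idx)
		}
		return dict[idx], 2 + size, nil
	}
	return "", 0, fmt.Errorf("unexpected code=%x decoding interned string", c)
}

func main() {
	dict := []string{"hello world", "foo bar"}
	s, n, err := decodeString([]byte{0xd4, 0x80, 0x00}, dict)
	fmt.Println(s, n, err) // hello world 3 <nil>
	s, n, err = decodeString([]byte{0xa4, 'x', 'x', 'x', 'x'}, dict)
	fmt.Println(s, n, err) // xxxx 5 <nil>
}
```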
diff --git a/msgpack.go b/msgpack.go
index 220b43c..4db2fa2 100644
--- a/msgpack.go
+++ b/msgpack.go
@@ -1,5 +1,7 @@
package msgpack
+import "fmt"
+
type Marshaler interface {
MarshalMsgpack() ([]byte, error)
}
@@ -15,3 +17,36 @@ type CustomEncoder interface {
type CustomDecoder interface {
DecodeMsgpack(*Decoder) error
}
+
+//------------------------------------------------------------------------------
+
+type RawMessage []byte
+
+var (
+ _ CustomEncoder = (RawMessage)(nil)
+ _ CustomDecoder = (*RawMessage)(nil)
+)
+
+func (m RawMessage) EncodeMsgpack(enc *Encoder) error {
+ return enc.write(m)
+}
+
+func (m *RawMessage) DecodeMsgpack(dec *Decoder) error {
+ msg, err := dec.DecodeRaw()
+ if err != nil {
+ return err
+ }
+ *m = msg
+ return nil
+}
+
+//------------------------------------------------------------------------------
+
+type unexpectedCodeError struct {
+ code byte
+ hint string
+}
+
+func (err unexpectedCodeError) Error() string {
+ return fmt.Sprintf("msgpack: unexpected code=%x decoding %s", err.code, err.hint)
+}
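The new `unexpectedCodeError` type lets callers tell "this is not the type I expected" apart from real failures. `decodeInternedInterfaceValue` relies on this. It first tries the interned-string path, and only when the error is an `unexpectedCodeError` does it unread the code byte and fall back to `decodeInterfaceValue`. Any other error is returned as is.

The sketch below shows the same fallback pattern. It uses `errors.As`, which, unlike the direct type assertion in `intern.go`, also matches wrapped errors. It also re-reads from the start of a slice instead of calling `UnreadByte`. `decodeStringOnly` and `decodeAny` are hypothetical and understand only fixstr, nil and booleans.

```go
package main

import (
	"errors"
	"fmt"
)

type unexpectedCodeError struct {
	code byte
	hint string
}

func (err unexpectedCodeError) Error() string {
	return fmt.Sprintf("msgpack: unexpected code=%x decoding %s", err.code, err.hint)
}

var errShort = errors.New("short buffer")

// decodeStringOnly is a narrow fast path that understands only fixstr.
func decodeStringOnly(b []byte) (string, error) {
	if len(b) == 0 {
		return "", errShort
	}
	c := b[0]
	if c&0xe0 != 0xa0 {
		return "", unexpectedCodeError{code: c, hint: "string"}
	}
	n := int(c & 0x1f)
	if len(b) < 1+n {
		return "", errShort
	}
	return string(b[1 : 1+n]), nil
}

// decodeAny falls back to a generic path only on unexpectedCodeError; any
// other error (such as truncated input) is returned unchanged.
func decodeAny(b []byte) (interface{}, error) {
	s, err := decodeStringOnly(b)
	if err == nil {
		return s, nil
	}
	var uce unexpectedCodeError
	if !errors.As(err, &uce) {
		return nil, err
	}
	switch uce.code { // generic path, re-reading the same code byte
	case 0xc0:
		return nil, nil // nil
	case 0xc2:
		return false, nil
	case 0xc3:
		return true, nil
	}
	return nil, err
}

func main() {
	for _, in := range [][]byte{{0xa2, 'h', 'i'}, {0xc3}, {0xcc}} {
		v, err := decodeAny(in)
		fmt.Println(v, err)
	}
}
```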
diff --git a/msgpack_test.go b/msgpack_test.go
index b288f1d..1b77027 100644
--- a/msgpack_test.go
+++ b/msgpack_test.go
@@ -3,82 +3,84 @@ package msgpack_test
import (
"bufio"
"bytes"
+ "fmt"
+ "math"
"reflect"
"testing"
"time"
- . "gopkg.in/check.v1"
-
- "github.com/vmihailenco/msgpack/v4"
+ "github.com/stretchr/testify/require"
+ "github.com/stretchr/testify/suite"
+ "github.com/vmihailenco/msgpack/v5"
)
type nameStruct struct {
Name string
}
-func TestGocheck(t *testing.T) { TestingT(t) }
-
type MsgpackTest struct {
+ suite.Suite
+
buf *bytes.Buffer
enc *msgpack.Encoder
dec *msgpack.Decoder
}
-var _ = Suite(&MsgpackTest{})
-
-func (t *MsgpackTest) SetUpTest(c *C) {
+func (t *MsgpackTest) SetUpTest() {
t.buf = &bytes.Buffer{}
t.enc = msgpack.NewEncoder(t.buf)
t.dec = msgpack.NewDecoder(bufio.NewReader(t.buf))
}
-func (t *MsgpackTest) TestDecodeNil(c *C) {
- c.Assert(t.dec.Decode(nil), NotNil)
+func (t *MsgpackTest) TestDecodeNil() {
+ t.NotNil(t.dec.Decode(nil))
}
-func (t *MsgpackTest) TestTime(c *C) {
+func (t *MsgpackTest) TestTime() {
in := time.Now()
var out time.Time
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out.Equal(in), Equals, true)
+
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out))
+ t.True(out.Equal(in))
var zero time.Time
- c.Assert(t.enc.Encode(zero), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out.Equal(zero), Equals, true)
- c.Assert(out.IsZero(), Equals, true)
+ t.Nil(t.enc.Encode(zero))
+ t.Nil(t.dec.Decode(&out))
+ t.True(out.Equal(zero))
+ t.True(out.IsZero())
+
}
-func (t *MsgpackTest) TestLargeBytes(c *C) {
+func (t *MsgpackTest) TestLargeBytes() {
N := int(1e6)
src := bytes.Repeat([]byte{'1'}, N)
- c.Assert(t.enc.Encode(src), IsNil)
+ t.Nil(t.enc.Encode(src))
var dst []byte
- c.Assert(t.dec.Decode(&dst), IsNil)
- c.Assert(dst, DeepEquals, src)
+ t.Nil(t.dec.Decode(&dst))
+ t.Equal(dst, src)
}
-func (t *MsgpackTest) TestLargeString(c *C) {
+func (t *MsgpackTest) TestLargeString() {
N := int(1e6)
src := string(bytes.Repeat([]byte{'1'}, N))
- c.Assert(t.enc.Encode(src), IsNil)
+ t.Nil(t.enc.Encode(src))
var dst string
- c.Assert(t.dec.Decode(&dst), IsNil)
- c.Assert(dst, Equals, src)
+ t.Nil(t.dec.Decode(&dst))
+ t.Equal(dst, src)
}
-func (t *MsgpackTest) TestSliceOfStructs(c *C) {
- in := []*nameStruct{&nameStruct{"hello"}}
+func (t *MsgpackTest) TestSliceOfStructs() {
+ in := []*nameStruct{{"hello"}}
var out []*nameStruct
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out, DeepEquals, in)
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out))
+ t.Equal(out, in)
}
-func (t *MsgpackTest) TestMap(c *C) {
+func (t *MsgpackTest) TestMap() {
for _, i := range []struct {
m map[string]string
b []byte
@@ -86,24 +88,24 @@ func (t *MsgpackTest) TestMap(c *C) {
{map[string]string{}, []byte{0x80}},
{map[string]string{"hello": "world"}, []byte{0x81, 0xa5, 0x68, 0x65, 0x6c, 0x6c, 0x6f, 0xa5, 0x77, 0x6f, 0x72, 0x6c, 0x64}},
} {
- c.Assert(t.enc.Encode(i.m), IsNil)
- c.Assert(t.buf.Bytes(), DeepEquals, i.b, Commentf("err encoding %v", i.m))
+ t.Nil(t.enc.Encode(i.m))
+ t.Equal(t.buf.Bytes(), i.b, fmt.Errorf("err encoding %v", i.m))
var m map[string]string
- c.Assert(t.dec.Decode(&m), IsNil)
- c.Assert(m, DeepEquals, i.m)
+ t.Nil(t.dec.Decode(&m))
+ t.Equal(m, i.m)
}
}
-func (t *MsgpackTest) TestStructNil(c *C) {
+func (t *MsgpackTest) TestStructNil() {
var dst *nameStruct
- c.Assert(t.enc.Encode(nameStruct{Name: "foo"}), IsNil)
- c.Assert(t.dec.Decode(&dst), IsNil)
- c.Assert(dst, Not(IsNil))
- c.Assert(dst.Name, Equals, "foo")
+ t.Nil(t.enc.Encode(nameStruct{Name: "foo"}))
+ t.Nil(t.dec.Decode(&dst))
+ t.NotNil(dst)
+ t.Equal(dst.Name, "foo")
}
-func (t *MsgpackTest) TestStructUnknownField(c *C) {
+func (t *MsgpackTest) TestStructUnknownField() {
in := struct {
Field1 string
Field2 string
@@ -113,13 +115,13 @@ func (t *MsgpackTest) TestStructUnknownField(c *C) {
Field2: "value2",
Field3: "value3",
}
- c.Assert(t.enc.Encode(in), IsNil)
+ t.Nil(t.enc.Encode(in))
out := struct {
Field2 string
}{}
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out.Field2, Equals, "value2")
+ t.Nil(t.dec.Decode(&out))
+ t.Equal(out.Field2, "value2")
}
//------------------------------------------------------------------------------
@@ -149,45 +151,45 @@ func (s *coderStruct) DecodeMsgpack(dec *msgpack.Decoder) error {
return dec.Decode(&s.name)
}
-func (t *MsgpackTest) TestCoder(c *C) {
+func (t *MsgpackTest) TestCoder() {
in := &coderStruct{name: "hello"}
var out coderStruct
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out.Name(), Equals, "hello")
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out))
+ t.Equal(out.Name(), "hello")
}
-func (t *MsgpackTest) TestNilCoder(c *C) {
+func (t *MsgpackTest) TestNilCoder() {
in := &coderStruct{name: "hello"}
var out *coderStruct
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out.Name(), Equals, "hello")
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out))
+ t.Equal(out.Name(), "hello")
}
-func (t *MsgpackTest) TestNilCoderValue(c *C) {
+func (t *MsgpackTest) TestNilCoderValue() {
in := &coderStruct{name: "hello"}
var out *coderStruct
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.DecodeValue(reflect.ValueOf(&out)), IsNil)
- c.Assert(out.Name(), Equals, "hello")
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.DecodeValue(reflect.ValueOf(&out)))
+ t.Equal(out.Name(), "hello")
}
-func (t *MsgpackTest) TestPtrToCoder(c *C) {
+func (t *MsgpackTest) TestPtrToCoder() {
in := &coderStruct{name: "hello"}
var out coderStruct
out2 := &out
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out2), IsNil)
- c.Assert(out.Name(), Equals, "hello")
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out2))
+ t.Equal(out.Name(), "hello")
}
-func (t *MsgpackTest) TestWrappedCoder(c *C) {
+func (t *MsgpackTest) TestWrappedCoder() {
in := &wrapperStruct{coderStruct: coderStruct{name: "hello"}}
var out wrapperStruct
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out.Name(), Equals, "hello")
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out))
+ t.Equal(out.Name(), "hello")
}
//------------------------------------------------------------------------------
@@ -201,13 +203,13 @@ type struct1 struct {
Struct2 struct2
}
-func (t *MsgpackTest) TestNestedStructs(c *C) {
+func (t *MsgpackTest) TestNestedStructs() {
in := &struct1{Name: "hello", Struct2: struct2{Name: "world"}}
var out struct1
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out.Name, Equals, in.Name)
- c.Assert(out.Struct2.Name, Equals, in.Struct2.Name)
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out))
+ t.Equal(out.Name, in.Name)
+ t.Equal(out.Struct2.Name, in.Struct2.Name)
}
type Struct4 struct {
@@ -245,30 +247,231 @@ func TestEmbedding(t *testing.T) {
}
}
-func (t *MsgpackTest) TestSliceNil(c *C) {
+func (t *MsgpackTest) TestSliceNil() {
in := [][]*int{nil}
var out [][]*int
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
- c.Assert(out, DeepEquals, in)
+ t.Nil(t.enc.Encode(in))
+ t.Nil(t.dec.Decode(&out))
+ t.Equal(out, in)
}
//------------------------------------------------------------------------------
-func (t *MsgpackTest) TestMapStringInterface(c *C) {
+//------------------------------------------------------------------------------
+
+func TestNoPanicOnUnsupportedKey(t *testing.T) {
+ data := []byte{0x81, 0x81, 0xa1, 0x78, 0xc3, 0xc3}
+
+ _, err := msgpack.NewDecoder(bytes.NewReader(data)).DecodeTypedMap()
+ require.EqualError(t, err, "msgpack: unsupported map key: map[string]interface {}")
+}
+
+func TestMapDefault(t *testing.T) {
in := map[string]interface{}{
"foo": "bar",
"hello": map[string]interface{}{
"foo": "bar",
},
}
+ b, err := msgpack.Marshal(in)
+ require.Nil(t, err)
+
var out map[string]interface{}
+ err = msgpack.Unmarshal(b, &out)
+ require.Nil(t, err)
+ require.Equal(t, in, out)
+}
- c.Assert(t.enc.Encode(in), IsNil)
- c.Assert(t.dec.Decode(&out), IsNil)
+func TestRawMessage(t *testing.T) {
+ type In struct {
+ Foo map[string]interface{}
+ }
+
+ type Out struct {
+ Foo msgpack.RawMessage
+ }
+
+ type Out2 struct {
+ Foo interface{}
+ }
- c.Assert(out["foo"], Equals, "bar")
- mm := out["hello"].(map[string]interface{})
- c.Assert(mm["foo"], Equals, "bar")
+ b, err := msgpack.Marshal(&In{
+ Foo: map[string]interface{}{
+ "hello": "world",
+ },
+ })
+ require.Nil(t, err)
+
+ var out Out
+ err = msgpack.Unmarshal(b, &out)
+ require.Nil(t, err)
+
+ var m map[string]string
+ err = msgpack.Unmarshal(out.Foo, &m)
+ require.Nil(t, err)
+ require.Equal(t, map[string]string{
+ "hello": "world",
+ }, m)
+
+ msg := new(msgpack.RawMessage)
+ out2 := Out2{
+ Foo: msg,
+ }
+ err = msgpack.Unmarshal(b, &out2)
+ require.Nil(t, err)
+ require.Equal(t, out.Foo, *msg)
+}
+
+func TestInterface(t *testing.T) {
+ type Interface struct {
+ Foo interface{}
+ }
+
+ in := Interface{Foo: "foo"}
+ b, err := msgpack.Marshal(in)
+ require.Nil(t, err)
+
+ var str string
+ out := Interface{Foo: &str}
+ err = msgpack.Unmarshal(b, &out)
+ require.Nil(t, err)
+ require.Equal(t, "foo", str)
+}
+
+func TestNaN(t *testing.T) {
+ in := float64(math.NaN())
+ b, err := msgpack.Marshal(in)
+ require.Nil(t, err)
+
+ var out float64
+ err = msgpack.Unmarshal(b, &out)
+ require.Nil(t, err)
+ require.True(t, math.IsNaN(out))
+}
+
+func TestSetSortMapKeys(t *testing.T) {
+ in := map[string]interface{}{
+ "a": "a",
+ "b": "b",
+ "c": "c",
+ "d": "d",
+ }
+
+ var buf bytes.Buffer
+ enc := msgpack.NewEncoder(&buf)
+ enc.SetSortMapKeys(true)
+ dec := msgpack.NewDecoder(&buf)
+
+ err := enc.Encode(in)
+ require.Nil(t, err)
+
+ wanted := make([]byte, buf.Len())
+ copy(wanted, buf.Bytes())
+ buf.Reset()
+
+ for i := 0; i < 100; i++ {
+ err := enc.Encode(in)
+ require.Nil(t, err)
+ require.Equal(t, wanted, buf.Bytes())
+
+ out, err := dec.DecodeMap()
+ require.Nil(t, err)
+ require.Equal(t, in, out)
+ }
+}
+
+func TestSetOmitEmpty(t *testing.T) {
+ var buf bytes.Buffer
+ enc := msgpack.NewEncoder(&buf)
+ enc.SetOmitEmpty(true)
+ err := enc.Encode(EmbeddingPtrTest{})
+ require.Nil(t, err)
+
+ var t2 *EmbeddingPtrTest
+ dec := msgpack.NewDecoder(&buf)
+ err = dec.Decode(&t2)
+ require.Nil(t, err)
+ require.Nil(t, t2.Exported)
+}
+
+type NullInt struct {
+ Valid bool
+ Int int
+}
+
+func (i *NullInt) Set(j int) {
+ i.Int = j
+ i.Valid = true
+}
+
+func (i NullInt) IsZero() bool {
+ return !i.Valid
+}
+
+func (i NullInt) MarshalMsgpack() ([]byte, error) {
+ return msgpack.Marshal(i.Int)
+}
+
+func (i *NullInt) UnmarshalMsgpack(b []byte) error {
+ if err := msgpack.Unmarshal(b, &i.Int); err != nil {
+ return err
+ }
+ i.Valid = true
+ return nil
+}
+
+type Secretive struct {
+ Visible bool
+ hidden bool
+}
+
+type T struct {
+ I NullInt `msgpack:",omitempty"`
+ J NullInt
+ // Secretive is not a "simple" struct because it has an hidden field.
+ S Secretive `msgpack:",omitempty"`
+}
+
+func ExampleMarshal_ignore_simple_zero_structs_when_tagged_with_omitempty() {
+ var t1 T
+ raw, err := msgpack.Marshal(t1)
+ if err != nil {
+ panic(err)
+ }
+ var t2 T
+ if err = msgpack.Unmarshal(raw, &t2); err != nil {
+ panic(err)
+ }
+ fmt.Printf("%#v\n", t2)
+
+ t2.I.Set(42)
+ t2.S.hidden = true // won't be included because it is a hidden field
+ raw, err = msgpack.Marshal(t2)
+ if err != nil {
+ panic(err)
+ }
+ var t3 T
+ if err = msgpack.Unmarshal(raw, &t3); err != nil {
+ panic(err)
+ }
+ fmt.Printf("%#v\n", t3)
+ // Output: msgpack_test.T{I:msgpack_test.NullInt{Valid:false, Int:0}, J:msgpack_test.NullInt{Valid:true, Int:0}, S:msgpack_test.Secretive{Visible:false, hidden:false}}
+ // msgpack_test.T{I:msgpack_test.NullInt{Valid:true, Int:42}, J:msgpack_test.NullInt{Valid:true, Int:0}, S:msgpack_test.Secretive{Visible:false, hidden:false}}
+}
+
+type Value interface{}
+type Wrapper struct {
+ Value Value `msgpack:"v,omitempty"`
+}
+
+func TestEncodeWrappedValue(t *testing.T) {
+ var v Value
+ v = (*time.Time)(nil)
+ c := &Wrapper{
+ Value: v,
+ }
+ var buf bytes.Buffer
+ require.Nil(t, msgpack.NewEncoder(&buf).Encode(v))
+ require.Nil(t, msgpack.NewEncoder(&buf).Encode(c))
}
diff --git a/msgpcode/msgpcode.go b/msgpcode/msgpcode.go
new file mode 100644
index 0000000..e35389c
--- /dev/null
+++ b/msgpcode/msgpcode.go
@@ -0,0 +1,88 @@
+package msgpcode
+
+var (
+ PosFixedNumHigh byte = 0x7f
+ NegFixedNumLow byte = 0xe0
+
+ Nil byte = 0xc0
+
+ False byte = 0xc2
+ True byte = 0xc3
+
+ Float byte = 0xca
+ Double byte = 0xcb
+
+ Uint8 byte = 0xcc
+ Uint16 byte = 0xcd
+ Uint32 byte = 0xce
+ Uint64 byte = 0xcf
+
+ Int8 byte = 0xd0
+ Int16 byte = 0xd1
+ Int32 byte = 0xd2
+ Int64 byte = 0xd3
+
+ FixedStrLow byte = 0xa0
+ FixedStrHigh byte = 0xbf
+ FixedStrMask byte = 0x1f
+ Str8 byte = 0xd9
+ Str16 byte = 0xda
+ Str32 byte = 0xdb
+
+ Bin8 byte = 0xc4
+ Bin16 byte = 0xc5
+ Bin32 byte = 0xc6
+
+ FixedArrayLow byte = 0x90
+ FixedArrayHigh byte = 0x9f
+ FixedArrayMask byte = 0xf
+ Array16 byte = 0xdc
+ Array32 byte = 0xdd
+
+ FixedMapLow byte = 0x80
+ FixedMapHigh byte = 0x8f
+ FixedMapMask byte = 0xf
+ Map16 byte = 0xde
+ Map32 byte = 0xdf
+
+ FixExt1 byte = 0xd4
+ FixExt2 byte = 0xd5
+ FixExt4 byte = 0xd6
+ FixExt8 byte = 0xd7
+ FixExt16 byte = 0xd8
+ Ext8 byte = 0xc7
+ Ext16 byte = 0xc8
+ Ext32 byte = 0xc9
+)
+
+func IsFixedNum(c byte) bool {
+ return c <= PosFixedNumHigh || c >= NegFixedNumLow
+}
+
+func IsFixedMap(c byte) bool {
+ return c >= FixedMapLow && c <= FixedMapHigh
+}
+
+func IsFixedArray(c byte) bool {
+ return c >= FixedArrayLow && c <= FixedArrayHigh
+}
+
+func IsFixedString(c byte) bool {
+ return c >= FixedStrLow && c <= FixedStrHigh
+}
+
+func IsString(c byte) bool {
+ return IsFixedString(c) || c == Str8 || c == Str16 || c == Str32
+}
+
+func IsBin(c byte) bool {
+ return c == Bin8 || c == Bin16 || c == Bin32
+}
+
+func IsFixedExt(c byte) bool {
+ return c >= FixExt1 && c <= FixExt16
+}
+
+func IsExt(c byte) bool {
+ return IsFixedExt(c) || c == Ext8 || c == Ext16 || c == Ext32
+}
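The new `msgpcode` package replaces v4's `codes` and exposes the MessagePack format bytes together with range helpers such as `IsFixedMap` and `IsFixedString`. To show how those ranges are used, the sketch below copies a subset of the constants (values taken from the diff) and walks the payload from `TestNoPanicOnUnsupportedKey`. The walk reveals that the outer map's key is itself a map, which is consistent with the "unsupported map key" error that test expects. `describe` is a hypothetical decoder that handles only fixint, fixstr, fixmap, nil and bool.

```go
package main

import (
	"fmt"
	"strings"
)

// A subset of the format bytes from msgpcode, copied to keep the sketch self-contained.
const (
	PosFixedNumHigh byte = 0x7f
	NegFixedNumLow  byte = 0xe0
	Nil             byte = 0xc0
	False           byte = 0xc2
	True            byte = 0xc3
	FixedStrLow     byte = 0xa0
	FixedStrHigh    byte = 0xbf
	FixedStrMask    byte = 0x1f
	FixedMapLow     byte = 0x80
	FixedMapHigh    byte = 0x8f
	FixedMapMask    byte = 0xf
)

func IsFixedNum(c byte) bool    { return c <= PosFixedNumHigh || c >= NegFixedNumLow }
func IsFixedMap(c byte) bool    { return c >= FixedMapLow && c <= FixedMapHigh }
func IsFixedString(c byte) bool { return c >= FixedStrLow && c <= FixedStrHigh }

// describe decodes one value starting at b[0] and returns a readable rendering
// plus the unread bytes. Any other format byte is reported as unsupported.
func describe(b []byte) (string, []byte, error) {
	if len(b) == 0 {
		return "", nil, fmt.Errorf("unexpected end of input")
	}
	c, rest := b[0], b[1:]
	switch {
	case IsFixedNum(c):
		// Positive fixint is 0x00..0x7f; negative fixint 0xe0..0xff is -32..-1.
		return fmt.Sprint(int8(c)), rest, nil
	case c == Nil:
		return "nil", rest, nil
	case c == False, c == True:
		return fmt.Sprint(c == True), rest, nil
	case IsFixedString(c):
		n := int(c & FixedStrMask) // length is stored in the low 5 bits
		if len(rest) < n {
			return "", nil, fmt.Errorf("short string")
		}
		return fmt.Sprintf("%q", rest[:n]), rest[n:], nil
	case IsFixedMap(c):
		n := int(c & FixedMapMask) // number of key/value pairs in the low 4 bits
		pairs := make([]string, 0, n)
		for i := 0; i < n; i++ {
			k, r, err := describe(rest)
			if err != nil {
				return "", nil, err
			}
			v, r, err := describe(r)
			if err != nil {
				return "", nil, err
			}
			pairs = append(pairs, k+": "+v)
			rest = r
		}
		return "map{" + strings.Join(pairs, ", ") + "}", rest, nil
	default:
		return "", nil, fmt.Errorf("unsupported code 0x%02x", c)
	}
}

func main() {
	s, rest, err := describe([]byte{0x81, 0x81, 0xa1, 0x78, 0xc3, 0xc3})
	fmt.Println(s, len(rest), err) // map{map{"x": true}: true} 0 <nil>
}
```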
diff --git a/package.json b/package.json
new file mode 100644
index 0000000..298910d
--- /dev/null
+++ b/package.json
@@ -0,0 +1,4 @@
+{
+ "name": "msgpack",
+ "version": "5.3.5"
+}
diff --git a/safe.go b/safe.go
new file mode 100644
index 0000000..8352c9d
--- /dev/null
+++ b/safe.go
@@ -0,0 +1,13 @@
+// +build appengine
+
+package msgpack
+
+// bytesToString converts byte slice to string.
+func bytesToString(b []byte) string {
+ return string(b)
+}
+
+// stringToBytes converts string to byte slice.
+func stringToBytes(s string) []byte {
+ return []byte(s)
+}
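`safe.go` is compiled only under the `appengine` build tag and implements both helpers with ordinary Go conversions, which allocate and copy instead of sharing memory with the input. The sketch below reuses the same function bodies to demonstrate that copy semantics: mutating the source after a conversion leaves the converted value unchanged.

```go
package main

import "fmt"

// bytesToString and stringToBytes have the same bodies as the appengine
// variants in safe.go: each conversion allocates a fresh copy.
func bytesToString(b []byte) string { return string(b) }
func stringToBytes(s string) []byte { return []byte(s) }

func main() {
	b := []byte("hello")
	s := bytesToString(b)
	b[0] = 'j'                // mutate the source slice after converting
	fmt.Println(s, string(b)) // hello jello

	p := stringToBytes(s)
	p[0] = 'c'                // mutate the returned slice
	fmt.Println(s, string(p)) // hello cello
}
```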
diff --git a/time.go b/time.go
index 91374e9..44566ec 100644
--- a/time.go
+++ b/time.go
@@ -6,15 +6,30 @@ import (
"reflect"
"time"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
var timeExtID int8 = -1
-//nolint:gochecknoinits
func init() {
- timeType := reflect.TypeOf((*time.Time)(nil)).Elem()
- registerExt(timeExtID, timeType, encodeTimeValue, decodeTimeValue)
+ RegisterExtEncoder(timeExtID, time.Time{}, timeEncoder)
+ RegisterExtDecoder(timeExtID, time.Time{}, timeDecoder)
+}
+
+func timeEncoder(e *Encoder, v reflect.Value) ([]byte, error) {
+ return e.encodeTime(v.Interface().(time.Time)), nil
+}
+
+func timeDecoder(d *Decoder, v reflect.Value, extLen int) error {
+ tm, err := d.decodeTime(extLen)
+ if err != nil {
+ return err
+ }
+
+ ptr := v.Addr().Interface().(*time.Time)
+ *ptr = tm
+
+ return nil
}
func (e *Encoder) EncodeTime(tm time.Time) error {
@@ -36,11 +51,13 @@ func (e *Encoder) encodeTime(tm time.Time) []byte {
secs := uint64(tm.Unix())
if secs>>34 == 0 {
data := uint64(tm.Nanosecond())<<34 | secs
+
if data&0xffffffff00000000 == 0 {
b := e.timeBuf[:4]
binary.BigEndian.PutUint32(b, uint32(data))
return b
}
+
b := e.timeBuf[:8]
binary.BigEndian.PutUint64(b, data)
return b
@@ -53,62 +70,56 @@ func (e *Encoder) encodeTime(tm time.Time) []byte {
}
func (d *Decoder) DecodeTime() (time.Time, error) {
- tm, err := d.decodeTime()
+ c, err := d.readCode()
if err != nil {
- return tm, err
- }
-
- if tm.IsZero() {
- // Assume that zero time does not have timezone information.
- return tm.UTC(), nil
+ return time.Time{}, err
}
- return tm, nil
-}
-func (d *Decoder) decodeTime() (time.Time, error) {
- extLen := d.extLen
- d.extLen = 0
- if extLen == 0 {
- c, err := d.readCode()
+ // Legacy format.
+ if c == msgpcode.FixedArrayLow|2 {
+ sec, err := d.DecodeInt64()
if err != nil {
return time.Time{}, err
}
- // Legacy format.
- if c == codes.FixedArrayLow|2 {
- sec, err := d.DecodeInt64()
- if err != nil {
- return time.Time{}, err
- }
-
- nsec, err := d.DecodeInt64()
- if err != nil {
- return time.Time{}, err
- }
-
- return time.Unix(sec, nsec), nil
+ nsec, err := d.DecodeInt64()
+ if err != nil {
+ return time.Time{}, err
}
- if codes.IsString(c) {
- s, err := d.string(c)
- if err != nil {
- return time.Time{}, err
- }
- return time.Parse(time.RFC3339Nano, s)
- }
+ return time.Unix(sec, nsec), nil
+ }
- extLen, err = d.parseExtLen(c)
+ if msgpcode.IsString(c) {
+ s, err := d.string(c)
if err != nil {
return time.Time{}, err
}
+ return time.Parse(time.RFC3339Nano, s)
+ }
- // Skip ext id.
- _, err = d.s.ReadByte()
- if err != nil {
- return time.Time{}, nil
- }
+ extID, extLen, err := d.extHeader(c)
+ if err != nil {
+ return time.Time{}, err
+ }
+
+ if extID != timeExtID {
+ return time.Time{}, fmt.Errorf("msgpack: invalid time ext id=%d", extID)
}
+ tm, err := d.decodeTime(extLen)
+ if err != nil {
+ return tm, err
+ }
+
+ if tm.IsZero() {
+ // Zero time does not have timezone information.
+ return tm.UTC(), nil
+ }
+ return tm, nil
+}
+
+func (d *Decoder) decodeTime(extLen int) (time.Time, error) {
b, err := d.readN(extLen)
if err != nil {
return time.Time{}, err
@@ -132,21 +143,3 @@ func (d *Decoder) decodeTime() (time.Time, error) {
return time.Time{}, err
}
}
-
-func encodeTimeValue(e *Encoder, v reflect.Value) error {
- tm := v.Interface().(time.Time)
- b := e.encodeTime(tm)
- return e.write(b)
-}
-
-func decodeTimeValue(d *Decoder, v reflect.Value) error {
- tm, err := d.DecodeTime()
- if err != nil {
- return err
- }
-
- ptr := v.Addr().Interface().(*time.Time)
- *ptr = tm
-
- return nil
-}
diff --git a/types.go b/types.go
index c8b1fd5..69aca61 100644
--- a/types.go
+++ b/types.go
@@ -2,11 +2,12 @@ package msgpack
import (
"encoding"
+ "fmt"
"log"
"reflect"
"sync"
- "github.com/vmihailenco/tagparser"
+ "github.com/vmihailenco/tagparser/v2"
)
var errorType = reflect.TypeOf((*error)(nil)).Elem()
@@ -26,9 +27,14 @@ var (
binaryUnmarshalerType = reflect.TypeOf((*encoding.BinaryUnmarshaler)(nil)).Elem()
)
+var (
+ textMarshalerType = reflect.TypeOf((*encoding.TextMarshaler)(nil)).Elem()
+ textUnmarshalerType = reflect.TypeOf((*encoding.TextUnmarshaler)(nil)).Elem()
+)
+
type (
- encoderFunc = func(*Encoder, reflect.Value) error
- decoderFunc = func(*Decoder, reflect.Value) error
+ encoderFunc func(*Encoder, reflect.Value) error
+ decoderFunc func(*Decoder, reflect.Value) error
)
var (
@@ -38,7 +44,7 @@ var (
// Register registers encoder and decoder functions for a value.
// This is low level API and in most cases you should prefer implementing
-// Marshaler/CustomEncoder and Unmarshaler/CustomDecoder interfaces.
+// CustomEncoder/CustomDecoder or Marshaler/Unmarshaler interfaces.
func Register(value interface{}, enc encoderFunc, dec decoderFunc) {
typ := reflect.TypeOf(value)
if enc != nil {
@@ -51,28 +57,33 @@ func Register(value interface{}, enc encoderFunc, dec decoderFunc) {
//------------------------------------------------------------------------------
-var structs = newStructCache(false)
-var jsonStructs = newStructCache(true)
+const defaultStructTag = "msgpack"
+
+var structs = newStructCache()
type structCache struct {
m sync.Map
+}
- useJSONTag bool
+type structCacheKey struct {
+ tag string
+ typ reflect.Type
}
-func newStructCache(useJSONTag bool) *structCache {
- return &structCache{
- useJSONTag: useJSONTag,
- }
+func newStructCache() *structCache {
+ return new(structCache)
}
-func (m *structCache) Fields(typ reflect.Type) *fields {
- if v, ok := m.m.Load(typ); ok {
+func (m *structCache) Fields(typ reflect.Type, tag string) *fields {
+ key := structCacheKey{tag: tag, typ: typ}
+
+ if v, ok := m.m.Load(key); ok {
return v.(*fields)
}
- fs := getFields(typ, m.useJSONTag)
- m.m.Store(typ, fs)
+ fs := getFields(typ, tag)
+ m.m.Store(key, fs)
+
return fs
}
@@ -86,24 +97,24 @@ type field struct {
decoder decoderFunc
}
-func (f *field) Omit(strct reflect.Value) bool {
- v, isNil := fieldByIndex(strct, f.index)
- if isNil {
+func (f *field) Omit(strct reflect.Value, forced bool) bool {
+ v, ok := fieldByIndex(strct, f.index)
+ if !ok {
return true
}
- return f.omitEmpty && isEmptyValue(v)
+ return (f.omitEmpty || forced) && isEmptyValue(v)
}
func (f *field) EncodeValue(e *Encoder, strct reflect.Value) error {
- v, isNil := fieldByIndex(strct, f.index)
- if isNil {
+ v, ok := fieldByIndex(strct, f.index)
+ if !ok {
return e.EncodeNil()
}
return f.encoder(e, v)
}
func (f *field) DecodeValue(d *Decoder, strct reflect.Value) error {
- v := fieldByIndexNewIfNil(strct, f.index)
+ v := fieldByIndexAlloc(strct, f.index)
return f.decoder(d, v)
}
@@ -127,10 +138,7 @@ func newFields(typ reflect.Type) *fields {
}
func (fs *fields) Add(field *field) {
- if _, ok := fs.Map[field.name]; ok {
- log.Printf("msgpack: %s already has field=%s", fs.Type, field.name)
- }
-
+ fs.warnIfFieldExists(field.name)
fs.Map[field.name] = field
fs.List = append(fs.List, field)
if field.omitEmpty {
@@ -138,15 +146,21 @@ func (fs *fields) Add(field *field) {
}
}
-func (fs *fields) OmitEmpty(strct reflect.Value) []*field {
- if !fs.hasOmitEmpty {
+func (fs *fields) warnIfFieldExists(name string) {
+ if _, ok := fs.Map[name]; ok {
+ log.Printf("msgpack: %s already has field=%s", fs.Type, name)
+ }
+}
+
+func (fs *fields) OmitEmpty(strct reflect.Value, forced bool) []*field {
+ if !fs.hasOmitEmpty && !forced {
return fs.List
}
fields := make([]*field, 0, len(fs.List))
for _, f := range fs.List {
- if !f.Omit(strct) {
+ if !f.Omit(strct, forced) {
fields = append(fields, f)
}
}
@@ -154,16 +168,16 @@ func (fs *fields) OmitEmpty(strct reflect.Value) []*field {
return fields
}
-func getFields(typ reflect.Type, useJSONTag bool) *fields {
+func getFields(typ reflect.Type, fallbackTag string) *fields {
fs := newFields(typ)
var omitEmpty bool
for i := 0; i < typ.NumField(); i++ {
f := typ.Field(i)
- tagStr := f.Tag.Get("msgpack")
- if useJSONTag && tagStr == "" {
- tagStr = f.Tag.Get("json")
+ tagStr := f.Tag.Get(defaultStructTag)
+ if tagStr == "" && fallbackTag != "" {
+ tagStr = f.Tag.Get(fallbackTag)
}
tag := tagparser.Parse(tagStr)
@@ -172,9 +186,7 @@ func getFields(typ reflect.Type, useJSONTag bool) *fields {
}
if f.Name == "_msgpack" {
- if tag.HasOption("asArray") {
- fs.AsArray = true
- }
+ fs.AsArray = tag.HasOption("as_array") || tag.HasOption("asArray")
if tag.HasOption("omitempty") {
omitEmpty = true
}
@@ -188,8 +200,23 @@ func getFields(typ reflect.Type, useJSONTag bool) *fields {
name: tag.Name,
index: f.Index,
omitEmpty: omitEmpty || tag.HasOption("omitempty"),
- encoder: getEncoder(f.Type),
- decoder: getDecoder(f.Type),
+ }
+
+ if tag.HasOption("intern") {
+ switch f.Type.Kind() {
+ case reflect.Interface:
+ field.encoder = encodeInternedInterfaceValue
+ field.decoder = decodeInternedInterfaceValue
+ case reflect.String:
+ field.encoder = encodeInternedStringValue
+ field.decoder = decodeInternedStringValue
+ default:
+ err := fmt.Errorf("msgpack: intern strings are not supported on %s", f.Type)
+ panic(err)
+ }
+ } else {
+ field.encoder = getEncoder(f.Type)
+ field.decoder = getDecoder(f.Type)
}
if field.name == "" {
@@ -199,9 +226,9 @@ func getFields(typ reflect.Type, useJSONTag bool) *fields {
if f.Anonymous && !tag.HasOption("noinline") {
inline := tag.HasOption("inline")
if inline {
- inlineFields(fs, f.Type, field, useJSONTag)
+ inlineFields(fs, f.Type, field, fallbackTag)
} else {
- inline = shouldInline(fs, f.Type, field, useJSONTag)
+ inline = shouldInline(fs, f.Type, field, fallbackTag)
}
if inline {
@@ -214,12 +241,19 @@ func getFields(typ reflect.Type, useJSONTag bool) *fields {
}
fs.Add(field)
+
+ if alias, ok := tag.Options["alias"]; ok {
+ fs.warnIfFieldExists(alias)
+ fs.Map[alias] = field
+ }
}
return fs
}
-var encodeStructValuePtr uintptr
-var decodeStructValuePtr uintptr
+var (
+ encodeStructValuePtr uintptr
+ decodeStructValuePtr uintptr
+)
//nolint:gochecknoinits
func init() {
@@ -227,8 +261,8 @@ func init() {
decodeStructValuePtr = reflect.ValueOf(decodeStructValue).Pointer()
}
-func inlineFields(fs *fields, typ reflect.Type, f *field, useJSONTag bool) {
- inlinedFields := getFields(typ, useJSONTag).List
+func inlineFields(fs *fields, typ reflect.Type, f *field, tag string) {
+ inlinedFields := getFields(typ, tag).List
for _, field := range inlinedFields {
if _, ok := fs.Map[field.name]; ok {
// Don't inline shadowed fields.
@@ -239,7 +273,7 @@ func inlineFields(fs *fields, typ reflect.Type, f *field, useJSONTag bool) {
}
}
-func shouldInline(fs *fields, typ reflect.Type, f *field, useJSONTag bool) bool {
+func shouldInline(fs *fields, typ reflect.Type, f *field, tag string) bool {
var encoder encoderFunc
var decoder decoderFunc
@@ -264,7 +298,7 @@ func shouldInline(fs *fields, typ reflect.Type, f *field, useJSONTag bool) bool
return false
}
- inlinedFields := getFields(typ, useJSONTag).List
+ inlinedFields := getFields(typ, tag).List
for _, field := range inlinedFields {
if _, ok := fs.Map[field.name]; ok {
// Don't auto inline if there are shadowed fields.
@@ -279,8 +313,26 @@ func shouldInline(fs *fields, typ reflect.Type, f *field, useJSONTag bool) bool
return true
}
+type isZeroer interface {
+ IsZero() bool
+}
+
func isEmptyValue(v reflect.Value) bool {
- switch v.Kind() {
+ kind := v.Kind()
+
+ for kind == reflect.Interface {
+ if v.IsNil() {
+ return true
+ }
+ v = v.Elem()
+ kind = v.Kind()
+ }
+
+ if z, ok := v.Interface().(isZeroer); ok {
+ return nilable(kind) && v.IsNil() || z.IsZero()
+ }
+
+ switch kind {
case reflect.Array, reflect.Map, reflect.Slice, reflect.String:
return v.Len() == 0
case reflect.Bool:
@@ -291,22 +343,23 @@ func isEmptyValue(v reflect.Value) bool {
return v.Uint() == 0
case reflect.Float32, reflect.Float64:
return v.Float() == 0
- case reflect.Interface, reflect.Ptr:
+ case reflect.Ptr:
return v.IsNil()
+ default:
+ return false
}
- return false
}
-func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, isNil bool) {
+func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, ok bool) {
if len(index) == 1 {
- return v.Field(index[0]), false
+ return v.Field(index[0]), true
}
for i, idx := range index {
if i > 0 {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
- return v, true
+ return v, false
}
v = v.Elem()
}
@@ -314,10 +367,10 @@ func fieldByIndex(v reflect.Value, index []int) (_ reflect.Value, isNil bool) {
v = v.Field(idx)
}
- return v, false
+ return v, true
}
-func fieldByIndexNewIfNil(v reflect.Value, index []int) reflect.Value {
+func fieldByIndexAlloc(v reflect.Value, index []int) reflect.Value {
if len(index) == 1 {
return v.Field(index[0])
}
@@ -325,7 +378,7 @@ func fieldByIndexNewIfNil(v reflect.Value, index []int) reflect.Value {
for i, idx := range index {
if i > 0 {
var ok bool
- v, ok = indirectNew(v)
+ v, ok = indirectNil(v)
if !ok {
return v
}
@@ -336,7 +389,7 @@ func fieldByIndexNewIfNil(v reflect.Value, index []int) reflect.Value {
return v
}
-func indirectNew(v reflect.Value) (reflect.Value, bool) {
+func indirectNil(v reflect.Value) (reflect.Value, bool) {
if v.Kind() == reflect.Ptr {
if v.IsNil() {
if !v.CanSet() {
diff --git a/types_test.go b/types_test.go
index 40d65c6..fbc1ad3 100644
--- a/types_test.go
+++ b/types_test.go
@@ -5,20 +5,22 @@ import (
"encoding/hex"
"fmt"
"math"
+ "math/big"
"net/url"
"reflect"
"strings"
"testing"
"time"
- "github.com/vmihailenco/msgpack/v4"
- "github.com/vmihailenco/msgpack/v4/codes"
+ "github.com/stretchr/testify/require"
+ "github.com/vmihailenco/msgpack/v5"
+ "github.com/vmihailenco/msgpack/v5/msgpcode"
)
//------------------------------------------------------------------------------
type Object struct {
- n int
+ n int64
}
func (o *Object) MarshalMsgpack() ([]byte, error) {
@@ -51,8 +53,10 @@ func (t *CustomTime) DecodeMsgpack(dec *msgpack.Decoder) error {
type IntSet map[int]struct{}
-var _ msgpack.CustomEncoder = (*IntSet)(nil)
-var _ msgpack.CustomDecoder = (*IntSet)(nil)
+var (
+ _ msgpack.CustomEncoder = (*IntSet)(nil)
+ _ msgpack.CustomDecoder = (*IntSet)(nil)
+)
func (set IntSet) EncodeMsgpack(enc *msgpack.Encoder) error {
slice := make([]int, 0, len(set))
@@ -89,8 +93,10 @@ type CustomEncoder struct {
num int
}
-var _ msgpack.CustomEncoder = (*CustomEncoder)(nil)
-var _ msgpack.CustomDecoder = (*CustomEncoder)(nil)
+var (
+ _ msgpack.CustomEncoder = (*CustomEncoder)(nil)
+ _ msgpack.CustomDecoder = (*CustomEncoder)(nil)
+)
func (s *CustomEncoder) EncodeMsgpack(enc *msgpack.Encoder) error {
if s == nil {
@@ -107,6 +113,17 @@ type CustomEncoderField struct {
Field CustomEncoder
}
+type CustomEncoderEmbeddedPtr struct {
+ *CustomEncoder
+}
+
+func (s *CustomEncoderEmbeddedPtr) DecodeMsgpack(dec *msgpack.Decoder) error {
+ if s.CustomEncoder == nil {
+ s.CustomEncoder = new(CustomEncoder)
+ }
+ return s.CustomEncoder.DecodeMsgpack(dec)
+}
+
//------------------------------------------------------------------------------
type JSONFallbackTest struct {
@@ -117,14 +134,18 @@ type JSONFallbackTest struct {
func TestUseJsonTag(t *testing.T) {
var buf bytes.Buffer
- enc := msgpack.NewEncoder(&buf).UseJSONTag(true)
+ enc := msgpack.NewEncoder(&buf)
+ enc.SetCustomStructTag("json")
+
in := &JSONFallbackTest{Foo: "hello", Bar: "world"}
err := enc.Encode(in)
if err != nil {
t.Fatal(err)
}
- dec := msgpack.NewDecoder(&buf).UseJSONTag(true)
+ dec := msgpack.NewDecoder(&buf)
+ dec.SetCustomStructTag("json")
+
out := new(JSONFallbackTest)
err = dec.Decode(out)
if err != nil {
@@ -141,6 +162,45 @@ func TestUseJsonTag(t *testing.T) {
//------------------------------------------------------------------------------
+type CustomFallbackTest struct {
+ Foo string `custom:"foo,omitempty"`
+ Bar string `custom:",omitempty" msgpack:"bar"`
+}
+
+func TestUseCustomTag(t *testing.T) {
+ var buf bytes.Buffer
+
+ enc := msgpack.NewEncoder(&buf)
+ enc.SetCustomStructTag("custom")
+ in := &CustomFallbackTest{Foo: "hello", Bar: "world"}
+ err := enc.Encode(in)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ dec := msgpack.NewDecoder(&buf)
+ dec.SetCustomStructTag("custom")
+ out := new(CustomFallbackTest)
+ err = dec.Decode(out)
+ if err != nil {
+ t.Fatal(err)
+ }
+
+ if out.Foo != in.Foo || out.Foo != "hello" {
+ t.Fatalf("got %q, wanted %q", out.Foo, in.Foo)
+ }
+ if out.Bar != in.Bar || out.Bar != "world" {
+ t.Fatalf("got %q, wanted %q", out.Foo, in.Foo)
+ }
+}
+
+//------------------------------------------------------------------------------
+
+type OmitTimeTest struct {
+ Foo time.Time `msgpack:",omitempty"`
+ Bar *time.Time `msgpack:",omitempty"`
+}
+
type OmitEmptyTest struct {
Foo string `msgpack:",omitempty"`
Bar string `msgpack:",omitempty"`
@@ -166,7 +226,7 @@ type InlineDupTest struct {
}
type AsArrayTest struct {
- _msgpack struct{} `msgpack:",asArray"`
+ _msgpack struct{} `msgpack:",as_array"`
OmitEmptyTest
}
@@ -224,18 +284,61 @@ var encoderTests = []encoderTest{
{&JSONFallbackTest{Foo: "hello"}, "82a3666f6fa568656c6c6fa3626172a0"},
{&JSONFallbackTest{Bar: "world"}, "81a3626172a5776f726c64"},
{&JSONFallbackTest{Foo: "hello", Bar: "world"}, "82a3666f6fa568656c6c6fa3626172a5776f726c64"},
+
+ {&NoIntern{A: "foo", B: "foo", C: "foo"}, "83a141a3666f6fa142a3666f6fa143a3666f6f"},
+ {&Intern{A: "foo", B: "foo", C: "foo"}, "83a141a3666f6fa142d48000a143d48000"},
}
func TestEncoder(t *testing.T) {
var buf bytes.Buffer
- enc := msgpack.NewEncoder(&buf).
- UseJSONTag(true).
- SortMapKeys(true).
- UseCompactEncoding(true)
+ enc := msgpack.NewEncoder(&buf)
+ enc.SetCustomStructTag("json")
+ enc.SetSortMapKeys(true)
+ enc.UseCompactInts(true)
+
+ for i, test := range encoderTests {
+ buf.Reset()
+
+ err := enc.Encode(test.in)
+ require.Nil(t, err)
+
+ s := hex.EncodeToString(buf.Bytes())
+ require.Equal(t, test.wanted, s, "#%d", i)
+ }
+}
+
+type floatEncoderTest struct {
+ in interface{}
+ wanted string
+ compact bool
+}
+
+var floatEncoderTests = []floatEncoderTest{
+ {float32(3.0), "ca40400000", false},
+ {float32(3.0), "03", true},
+
+ {float64(3.0), "cb4008000000000000", false},
+ {float64(3.0), "03", true},
+
+ {float64(-3.0), "cbc008000000000000", false},
+ {float64(-3.0), "fd", true},
+
+ {math.NaN(), "cb7ff8000000000001", false},
+ {math.NaN(), "cb7ff8000000000001", true},
+ {math.Inf(1), "cb7ff0000000000000", false},
+ {math.Inf(1), "cb7ff0000000000000", true},
+}
+
+func TestFloatEncoding(t *testing.T) {
+ var buf bytes.Buffer
+ enc := msgpack.NewEncoder(&buf)
+ enc.UseCompactInts(true)
- for _, test := range encoderTests {
+ for _, test := range floatEncoderTests {
buf.Reset()
+ enc.UseCompactFloats(test.compact)
+
err := enc.Encode(test.in)
if err != nil {
t.Fatal(err)
@@ -257,10 +360,10 @@ type decoderTest struct {
}
var decoderTests = []decoderTest{
- {b: []byte{byte(codes.Bin32), 0x0f, 0xff, 0xff, 0xff}, out: new([]byte), err: "EOF"},
- {b: []byte{byte(codes.Str32), 0x0f, 0xff, 0xff, 0xff}, out: new([]byte), err: "EOF"},
- {b: []byte{byte(codes.Array32), 0x0f, 0xff, 0xff, 0xff}, out: new([]int), err: "EOF"},
- {b: []byte{byte(codes.Map32), 0x0f, 0xff, 0xff, 0xff}, out: new(map[int]int), err: "EOF"},
+ {b: []byte{byte(msgpcode.Bin32), 0x0f, 0xff, 0xff, 0xff}, out: new([]byte), err: "EOF"},
+ {b: []byte{byte(msgpcode.Str32), 0x0f, 0xff, 0xff, 0xff}, out: new([]byte), err: "EOF"},
+ {b: []byte{byte(msgpcode.Array32), 0x0f, 0xff, 0xff, 0xff}, out: new([]int), err: "EOF"},
+ {b: []byte{byte(msgpcode.Map32), 0x0f, 0xff, 0xff, 0xff}, out: new(map[int]int), err: "EOF"},
}
func TestDecoder(t *testing.T) {
@@ -298,10 +401,6 @@ type EmbeddedTime struct {
time.Time
}
-type Interface struct {
- Foo interface{}
-}
-
type (
interfaceAlias interface{}
byteAlias byte
@@ -334,7 +433,7 @@ func (t typeTest) String() string {
return fmt.Sprintf("in=%#v, out=%#v", t.in, t.out)
}
-func (t *typeTest) assertErr(err error, s string) {
+func (t *typeTest) requireErr(err error, s string) {
if err == nil {
t.Fatalf("got %v error, wanted %q", err, s)
}
@@ -350,8 +449,8 @@ var (
{in: make(chan bool), encErr: "msgpack: Encode(unsupported chan bool)"},
{in: nil, out: nil, decErr: "msgpack: Decode(nil)"},
- {in: nil, out: 0, decErr: "msgpack: Decode(nonsettable int)"},
- {in: nil, out: (*int)(nil), decErr: "msgpack: Decode(nonsettable *int)"},
+ {in: nil, out: 0, decErr: "msgpack: Decode(non-pointer int)"},
+ {in: nil, out: (*int)(nil), decErr: "msgpack: Decode(non-settable *int)"},
{in: nil, out: new(chan bool), decErr: "msgpack: Decode(unsupported chan bool)"},
{in: true, out: new(bool)},
@@ -455,6 +554,7 @@ var (
},
{in: time.Unix(0, 0), out: new(time.Time)},
+ {in: new(time.Time), out: new(time.Time)},
{in: time.Unix(0, 1), out: new(time.Time)},
{in: time.Unix(1, 0), out: new(time.Time)},
{in: time.Unix(1, 1), out: new(time.Time)},
@@ -477,42 +577,51 @@ var (
in: &CustomEncoderField{Field: CustomEncoder{"a", nil, 1}},
out: new(CustomEncoderField),
},
+ {
+ in: &CustomEncoderEmbeddedPtr{&CustomEncoder{"a", nil, 1}},
+ out: new(CustomEncoderEmbeddedPtr),
+ },
{in: repoURL, out: new(url.URL)},
{in: repoURL, out: new(*url.URL)},
+ {in: OmitEmptyTest{}, out: new(OmitEmptyTest)},
+ {in: OmitTimeTest{}, out: new(OmitTimeTest)},
+
{in: nil, out: new(*AsArrayTest), wantnil: true},
{in: nil, out: new(AsArrayTest), wantzero: true},
{in: AsArrayTest{OmitEmptyTest: OmitEmptyTest{"foo", "bar"}}, out: new(AsArrayTest)},
{
in: AsArrayTest{OmitEmptyTest: OmitEmptyTest{"foo", "bar"}},
out: new(unexported),
- wanted: unexported{Foo: "foo"},
+ decErr: "msgpack: number of fields in array-encoded struct has changed",
},
{in: (*EventTime)(nil), out: new(*EventTime)},
- {in: &EventTime{time.Unix(0, 0)}, out: new(EventTime)},
+ {in: &EventTime{time.Unix(0, 0)}, out: new(*EventTime)},
{in: (*ExtTest)(nil), out: new(*ExtTest)},
- {in: &ExtTest{"world"}, out: new(ExtTest), wanted: ExtTest{"hello world"}},
+ {in: &ExtTest{"world"}, out: new(*ExtTest), wanted: ExtTest{"hello world"}},
{
- in: ExtTestField{ExtTest{"world"}},
- out: new(ExtTestField),
+ in: &ExtTestField{ExtTest{"world"}},
+ out: new(*ExtTestField),
wanted: ExtTestField{ExtTest{"hello world"}},
},
- {in: Interface{}, out: &Interface{Foo: "bar"}},
-
{
in: &InlineTest{OmitEmptyTest: OmitEmptyTest{Bar: "world"}},
out: new(InlineTest),
- }, {
+ },
+ {
in: &InlinePtrTest{OmitEmptyTest: &OmitEmptyTest{Bar: "world"}},
out: new(InlinePtrTest),
- }, {
- in: InlineDupTest{FooTest{"foo"}, FooDupTest{"foo dup"}},
+ },
+ {
+ in: InlineDupTest{FooTest{"foo"}, FooDupTest{"foo"}},
out: new(InlineDupTest),
},
+
+ {in: big.NewInt(123), out: new(big.Int)},
}
)
@@ -528,6 +637,8 @@ func indirect(viface interface{}) interface{} {
}
func TestTypes(t *testing.T) {
+ msgpack.RegisterExt(1, (*EventTime)(nil))
+
for _, test := range typeTests {
test.T = t
@@ -536,7 +647,7 @@ func TestTypes(t *testing.T) {
enc := msgpack.NewEncoder(&buf)
err := enc.Encode(test.in)
if test.encErr != "" {
- test.assertErr(err, test.encErr)
+ test.requireErr(err, test.encErr)
continue
}
if err != nil {
@@ -546,7 +657,7 @@ func TestTypes(t *testing.T) {
dec := msgpack.NewDecoder(&buf)
err = dec.Decode(test.out)
if test.decErr != "" {
- test.assertErr(err, test.decErr)
+ test.requireErr(err, test.decErr)
continue
}
if err != nil {
@@ -576,9 +687,7 @@ func TestTypes(t *testing.T) {
if wanted == nil {
wanted = indirect(test.in)
}
- if !reflect.DeepEqual(out, wanted) {
- t.Fatalf("%#v != %#v (%s)", out, wanted, test)
- }
+ require.Equal(t, wanted, out)
}
for _, test := range typeTests {
@@ -592,15 +701,24 @@ func TestTypes(t *testing.T) {
}
var dst interface{}
- err = msgpack.Unmarshal(b, &dst)
+ dec := msgpack.NewDecoder(bytes.NewReader(b))
+ dec.SetMapDecoder(func(dec *msgpack.Decoder) (interface{}, error) {
+ return dec.DecodeUntypedMap()
+ })
+
+ err = dec.Decode(&dst)
if err != nil {
- t.Fatalf("Decode failed: %s (%s)", err, test)
+ t.Fatalf("Unmarshal into interface{} failed: %s (%s)", err, test)
}
- dec := msgpack.NewDecoder(bytes.NewReader(b))
+ dec = msgpack.NewDecoder(bytes.NewReader(b))
+ dec.SetMapDecoder(func(dec *msgpack.Decoder) (interface{}, error) {
+ return dec.DecodeUntypedMap()
+ })
+
_, err = dec.DecodeInterface()
if err != nil {
- t.Fatalf("Decode failed: %s (%s)", err, test)
+ t.Fatalf("DecodeInterface failed: %s (%s)", err, test)
}
}
}
@@ -641,38 +759,29 @@ func TestStringsBin(t *testing.T) {
for _, test := range tests {
b, err := msgpack.Marshal(test.in)
- if err != nil {
- t.Fatal(err)
- }
+ require.Nil(t, err)
s := hex.EncodeToString(b)
- if s != test.wanted {
- t.Fatalf("%.32s != %.32s", s, test.wanted)
- }
+ require.Equal(t, s, test.wanted)
var out string
err = msgpack.Unmarshal(b, &out)
- if err != nil {
- t.Fatal(err)
- }
- if out != test.in {
- t.Fatalf("%s != %s", out, test.in)
- }
+ require.Nil(t, err)
+ require.Equal(t, out, test.in)
+
+ var msg msgpack.RawMessage
+ err = msgpack.Unmarshal(b, &msg)
+ require.Nil(t, err)
+ require.Equal(t, []byte(msg), b)
dec := msgpack.NewDecoder(bytes.NewReader(b))
v, err := dec.DecodeInterface()
- if err != nil {
- t.Fatal(err)
- }
- if v.(string) != test.in {
- t.Fatalf("%s != %s", v, test.in)
- }
+ require.Nil(t, err)
+ require.Equal(t, v.(string), test.in)
var dst interface{}
dst = ""
err = msgpack.Unmarshal(b, &dst)
- if err.Error() != "msgpack: Decode(nonsettable string)" {
- t.Fatal(err)
- }
+ require.EqualError(t, err, "msgpack: Decode(non-pointer string)")
}
}
@@ -740,7 +849,7 @@ func TestBin(t *testing.T) {
var dst interface{}
dst = make([]byte, 0)
err = msgpack.Unmarshal(b, &dst)
- if err.Error() != "msgpack: Decode(nonsettable []uint8)" {
+ if err.Error() != "msgpack: Decode(non-pointer []uint8)" {
t.Fatal(err)
}
}
@@ -770,7 +879,8 @@ func TestUint64(t *testing.T) {
}
var buf bytes.Buffer
- enc := msgpack.NewEncoder(&buf).UseCompactEncoding(true)
+ enc := msgpack.NewEncoder(&buf)
+ enc.UseCompactInts(true)
for _, test := range tests {
err := enc.Encode(test.in)
@@ -803,7 +913,7 @@ func TestUint64(t *testing.T) {
var out3 interface{}
out3 = uint64(0)
err = msgpack.Unmarshal(buf.Bytes(), &out3)
- if err.Error() != "msgpack: Decode(nonsettable uint64)" {
+ if err.Error() != "msgpack: Decode(non-pointer uint64)" {
t.Fatal(err)
}
@@ -857,7 +967,8 @@ func TestInt64(t *testing.T) {
}
var buf bytes.Buffer
- enc := msgpack.NewEncoder(&buf).UseCompactEncoding(true)
+ enc := msgpack.NewEncoder(&buf)
+ enc.UseCompactInts(true)
for _, test := range tests {
err := enc.Encode(test.in)
@@ -890,7 +1001,7 @@ func TestInt64(t *testing.T) {
var out3 interface{}
out3 = int64(0)
err = msgpack.Unmarshal(buf.Bytes(), &out3)
- if err.Error() != "msgpack: Decode(nonsettable int64)" {
+ if err.Error() != "msgpack: Decode(non-pointer int64)" {
t.Fatal(err)
}
@@ -960,7 +1071,7 @@ func TestFloat32(t *testing.T) {
var dst interface{}
dst = float32(0)
err = msgpack.Unmarshal(b, &dst)
- if err.Error() != "msgpack: Decode(nonsettable float32)" {
+ if err.Error() != "msgpack: Decode(non-pointer float32)" {
t.Fatal(err)
}
}
@@ -1026,25 +1137,10 @@ func TestFloat64(t *testing.T) {
var dst interface{}
dst = float64(0)
err = msgpack.Unmarshal(b, &dst)
- if err.Error() != "msgpack: Decode(nonsettable float64)" {
+ if err.Error() != "msgpack: Decode(non-pointer float64)" {
t.Fatal(err)
}
}
-
- in := float64(math.NaN())
- b, err := msgpack.Marshal(in)
- if err != nil {
- t.Fatal(err)
- }
-
- var out float64
- err = msgpack.Unmarshal(b, &out)
- if err != nil {
- t.Fatal(err)
- }
- if !math.IsNaN(out) {
- t.Fatal("not NaN")
- }
}
func mustParseTime(format, s string) time.Time {
diff --git a/unsafe.go b/unsafe.go
new file mode 100644
index 0000000..192ac47
--- /dev/null
+++ b/unsafe.go
@@ -0,0 +1,22 @@
+// +build !appengine
+
+package msgpack
+
+import (
+ "unsafe"
+)
+
+// bytesToString converts byte slice to string.
+func bytesToString(b []byte) string {
+ return *(*string)(unsafe.Pointer(&b))
+}
+
+// stringToBytes converts string to byte slice.
+func stringToBytes(s string) []byte {
+ return *(*[]byte)(unsafe.Pointer(
+ &struct {
+ string
+ Cap int
+ }{s, len(s)},
+ ))
+}
diff --git a/version.go b/version.go
new file mode 100644
index 0000000..1d49337
--- /dev/null
+++ b/version.go
@@ -0,0 +1,6 @@
+package msgpack
+
+// Version is the current release version.
+func Version() string {
+ return "5.3.5"
+}
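The new unsafe.go, which is compiled only when the appengine build tag is absent per its `// +build !appengine` line, converts between []byte and string without copying. It does this by reinterpreting the slice header as a string header and vice versa. On Go 1.20 and later the same zero-copy conversions can be written with the standard unsafe.String, unsafe.StringData, unsafe.SliceData and unsafe.Slice functions instead of relying on header layouts. The sketch below assumes Go 1.20+. It is an illustration, not the code upstream ships: the helper names mirror unsafe.go, but the bodies are our own. As with the originals, the results alias the input's memory, so neither side may be modified after conversion.

```go
package main

import (
	"fmt"
	"unsafe"
)

// bytesToString returns a string that shares b's backing array (no copy).
// Assumes Go 1.20+. The caller must not modify b afterwards, because Go
// code is allowed to assume strings are immutable.
func bytesToString(b []byte) string {
	if len(b) == 0 {
		return ""
	}
	return unsafe.String(unsafe.SliceData(b), len(b))
}

// stringToBytes returns a []byte that aliases s's bytes (no copy), with
// len and cap both equal to len(s), like the struct-based version in unsafe.go.
// The result is read-only: s may live in read-only memory (e.g. a literal),
// so writing through the slice can crash the program.
func stringToBytes(s string) []byte {
	if len(s) == 0 {
		return nil
	}
	return unsafe.Slice(unsafe.StringData(s), len(s))
}

func main() {
	s := bytesToString([]byte("hello"))
	fmt.Println(s) // hello

	out := stringToBytes("world")
	fmt.Println(len(out), cap(out), string(out)) // 5 5 world
}
```

The empty-input guards avoid passing a nil or unspecified pointer along with a zero length. They also make the empty-input behavior explicit rather than depending on what unsafe.StringData returns for "".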
Debdiff
[The following lists of changes regard files as different if they have different names, permissions or owners.]
Files in second set of .debs but not in first
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/extra/msgpappengine/appengine.go
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/intern.go
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/intern_test.go
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/msgpcode/msgpcode.go
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/safe.go
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/unsafe.go
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/version.go
Files in first set of .debs but not in second
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/appengine.go
-rw-r--r-- root/root /usr/share/gocode/src/github.com/vmihailenco/msgpack/codes/codes.go
No differences were encountered in the control files