New Upstream Release - golang-refraction-networking-utls
Ready changes
Summary
Merged new upstream version: 1.3.2 (was: 1.2.1).
Diff
diff --git a/cipher_suites.go b/cipher_suites.go
index b0b6af3..5add594 100644
--- a/cipher_suites.go
+++ b/cipher_suites.go
@@ -10,6 +10,8 @@ import (
"crypto/cipher"
"crypto/des"
"crypto/hmac"
+
+ // "crypto/internal/boring"
"crypto/rc4"
"crypto/sha1"
"crypto/sha256"
diff --git a/common.go b/common.go
index fc0f4c5..ec3b849 100644
--- a/common.go
+++ b/common.go
@@ -42,7 +42,7 @@ const (
maxCiphertextTLS13 = 16384 + 256 // maximum ciphertext length in TLS 1.3
recordHeaderLen = 5 // record header length
maxHandshake = 65536 // maximum handshake we support (protocol max is 16 MB)
- maxUselessRecords = 16 // maximum number of consecutive non-advancing records
+ maxUselessRecords = 32 // maximum number of consecutive non-advancing records
)
// TLS record types.
@@ -101,7 +101,6 @@ const (
extensionCertificateAuthorities uint16 = 47
extensionSignatureAlgorithmsCert uint16 = 50
extensionKeyShare uint16 = 51
- extensionNextProtoNeg uint16 = 13172 // not IANA assigned // Pending discussion on whether or not remove this. crypto/tls removed it on Nov 21, 2019.
extensionRenegotiationInfo uint16 = 0xff01
)
@@ -151,7 +150,8 @@ const (
// TLS CertificateStatusType (RFC 3546)
const (
- statusTypeOCSP uint8 = 1
+ statusTypeOCSP uint8 = 1
+ statusV2TypeOCSP uint8 = 2
)
// Certificate types (for certificateRequestMsg)
@@ -656,6 +656,13 @@ type Config struct {
// testing or in combination with VerifyConnection or VerifyPeerCertificate.
InsecureSkipVerify bool
+ // InsecureSkipTimeVerify controls whether a client verifies the server's
+ // certificate chain against time. If InsecureSkipTimeVerify is true,
+ // crypto/tls accepts the certificate even when it is expired.
+ //
+ // This field is ignored when InsecureSkipVerify is true.
+ InsecureSkipTimeVerify bool // [uTLS]
+
// InsecureServerNameToVerify is used to verify the hostname on the returned
// certificates. It is intended to use with spoofed ServerName.
// If InsecureServerNameToVerify is "*", crypto/tls will do normal
@@ -821,6 +828,7 @@ func (c *Config) Clone() *Config {
ClientAuth: c.ClientAuth,
ClientCAs: c.ClientCAs,
InsecureSkipVerify: c.InsecureSkipVerify,
+ InsecureSkipTimeVerify: c.InsecureSkipTimeVerify,
InsecureServerNameToVerify: c.InsecureServerNameToVerify,
CipherSuites: c.CipherSuites,
PreferServerCipherSuites: c.PreferServerCipherSuites,
@@ -1405,7 +1413,7 @@ func (c *Certificate) leaf() (*x509.Certificate, error) {
}
type handshakeMessage interface {
- marshal() []byte
+ marshal() ([]byte, error)
unmarshal([]byte) bool
}
diff --git a/conn.go b/conn.go
index 13a7963..7f6f1f8 100644
--- a/conn.go
+++ b/conn.go
@@ -731,7 +731,7 @@ func (c *Conn) readRecordOrCCS(expectChangeCipherSpec bool) error {
// 5, a server can send a ChangeCipherSpec before its ServerHello, when
// c.vers is still unset. That's not useful though and suspicious if the
// server then selects a lower protocol version, so don't allow that.
- if c.vers == VersionTLS13 {
+ if c.vers == VersionTLS13 && !handshakeComplete {
return c.retryReadRecord(expectChangeCipherSpec)
}
if !expectChangeCipherSpec {
@@ -1007,18 +1007,37 @@ func (c *Conn) writeRecordLocked(typ recordType, data []byte) (int, error) {
return n, nil
}
-// writeRecord writes a TLS record with the given type and payload to the
-// connection and updates the record layer state.
-func (c *Conn) writeRecord(typ recordType, data []byte) (int, error) {
+// writeHandshakeRecord writes a handshake message to the connection and updates
+// the record layer state. If transcript is non-nil the marshalled message is
+// written to it.
+func (c *Conn) writeHandshakeRecord(msg handshakeMessage, transcript transcriptHash) (int, error) {
c.out.Lock()
defer c.out.Unlock()
- return c.writeRecordLocked(typ, data)
+ data, err := msg.marshal()
+ if err != nil {
+ return 0, err
+ }
+ if transcript != nil {
+ transcript.Write(data)
+ }
+
+ return c.writeRecordLocked(recordTypeHandshake, data)
+}
+
+// writeChangeCipherRecord writes a ChangeCipherSpec message to the connection and
+// updates the record layer state.
+func (c *Conn) writeChangeCipherRecord() error {
+ c.out.Lock()
+ defer c.out.Unlock()
+ _, err := c.writeRecordLocked(recordTypeChangeCipherSpec, []byte{1})
+ return err
}
// readHandshake reads the next handshake message from
-// the record layer.
-func (c *Conn) readHandshake() (any, error) {
+// the record layer. If transcript is non-nil, the message
+// is written to the passed transcriptHash.
+func (c *Conn) readHandshake(transcript transcriptHash) (any, error) {
for c.hand.Len() < 4 {
if err := c.readRecord(); err != nil {
return nil, err
@@ -1106,6 +1125,11 @@ func (c *Conn) readHandshake() (any, error) {
if !m.unmarshal(data) {
return nil, c.in.setErrorLocked(c.sendAlert(alertUnexpectedMessage))
}
+
+ if transcript != nil {
+ transcript.Write(data)
+ }
+
return m, nil
}
@@ -1181,7 +1205,7 @@ func (c *Conn) handleRenegotiation() error {
return errors.New("tls: internal error: unexpected renegotiation")
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -1227,7 +1251,7 @@ func (c *Conn) handlePostHandshakeMessage() error {
return c.handleRenegotiation()
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -1263,7 +1287,11 @@ func (c *Conn) handleKeyUpdate(keyUpdate *keyUpdateMsg) error {
defer c.out.Unlock()
msg := &keyUpdateMsg{}
- _, err := c.writeRecordLocked(recordTypeHandshake, msg.marshal())
+ msgBytes, err := msg.marshal()
+ if err != nil {
+ return err
+ }
+ _, err = c.writeRecordLocked(recordTypeHandshake, msgBytes)
if err != nil {
// Surface the error at the next write.
c.out.setErrorLocked(err)
@@ -1523,7 +1551,9 @@ func (c *Conn) connectionStateLocked() ConnectionState {
} else {
state.ekm = c.ekm
}
-
+ // [UTLS SECTION START]
+ c.utlsConnectionStateLocked(&state)
+ // [UTLS SECTION END]
return state
}
diff --git a/debian/changelog b/debian/changelog
index 597296e..612cf71 100644
--- a/debian/changelog
+++ b/debian/changelog
@@ -1,3 +1,9 @@
+golang-refraction-networking-utls (1.3.2-1) UNRELEASED; urgency=low
+
+ * New upstream release.
+
+ -- Debian Janitor <janitor@jelmer.uk> Mon, 17 Apr 2023 17:15:58 -0000
+
golang-refraction-networking-utls (1.2.1-2) unstable; urgency=medium
* Team upload
diff --git a/debian/patches/0001-Skip-client-server-handshake-tests.patch b/debian/patches/0001-Skip-client-server-handshake-tests.patch
index 3f16df3..b819eb1 100644
--- a/debian/patches/0001-Skip-client-server-handshake-tests.patch
+++ b/debian/patches/0001-Skip-client-server-handshake-tests.patch
@@ -9,11 +9,11 @@ Go1.19 std library, which should be just working.
handshake_server_test.go | 1 +
2 files changed, 2 insertions(+)
-diff --git a/handshake_client_test.go b/handshake_client_test.go
-index 380de9f..52e2a2e 100644
---- a/handshake_client_test.go
-+++ b/handshake_client_test.go
-@@ -784,6 +784,7 @@ func TestHandshakeClientCertRSA(t *testing.T) {
+Index: golang-refraction-networking-utls.git/handshake_client_test.go
+===================================================================
+--- golang-refraction-networking-utls.git.orig/handshake_client_test.go
++++ golang-refraction-networking-utls.git/handshake_client_test.go
+@@ -784,6 +784,7 @@ func TestHandshakeClientCertRSA(t *testi
}
func TestHandshakeClientCertECDSA(t *testing.T) {
@@ -21,11 +21,11 @@ index 380de9f..52e2a2e 100644
config := testConfig.Clone()
cert, _ := X509KeyPair([]byte(clientECDSACertificatePEM), []byte(clientECDSAKeyPEM))
config.Certificates = []Certificate{cert}
-diff --git a/handshake_server_test.go b/handshake_server_test.go
-index 39b65f6..16e9ac7 100644
---- a/handshake_server_test.go
-+++ b/handshake_server_test.go
-@@ -837,6 +837,7 @@ func TestHandshakeServerCHACHA20SHA256(t *testing.T) {
+Index: golang-refraction-networking-utls.git/handshake_server_test.go
+===================================================================
+--- golang-refraction-networking-utls.git.orig/handshake_server_test.go
++++ golang-refraction-networking-utls.git/handshake_server_test.go
+@@ -852,6 +852,7 @@ func TestHandshakeServerCHACHA20SHA256(t
}
func TestHandshakeServerECDHEECDSAAES(t *testing.T) {
diff --git a/go.mod b/go.mod
index d7f2ef0..9703f97 100644
--- a/go.mod
+++ b/go.mod
@@ -4,12 +4,13 @@ go 1.19
require (
github.com/andybalholm/brotli v1.0.4
+ github.com/gaukas/godicttls v0.0.3
github.com/klauspost/compress v1.15.15
golang.org/x/crypto v0.5.0
- golang.org/x/net v0.5.0
+ golang.org/x/net v0.7.0
)
require (
- golang.org/x/sys v0.4.0 // indirect
- golang.org/x/text v0.6.0 // indirect
+ golang.org/x/sys v0.5.0 // indirect
+ golang.org/x/text v0.7.0 // indirect
)
diff --git a/go.sum b/go.sum
index 8b58c2d..85b80ed 100644
--- a/go.sum
+++ b/go.sum
@@ -1,24 +1,14 @@
github.com/andybalholm/brotli v1.0.4 h1:V7DdXeJtZscaqfNuAdSRuRFzuiKlHSC/Zh3zl9qY3JY=
github.com/andybalholm/brotli v1.0.4/go.mod h1:fO7iG3H7G2nSZ7m0zPUDn85XEX2GTukHGRSepvi9Eig=
-github.com/klauspost/compress v1.15.12 h1:YClS/PImqYbn+UILDnqxQCZ3RehC9N318SU3kElDUEM=
-github.com/klauspost/compress v1.15.12/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM=
-github.com/klauspost/compress v1.15.14 h1:i7WCKDToww0wA+9qrUZ1xOjp218vfFo3nTU6UHp+gOc=
-github.com/klauspost/compress v1.15.14/go.mod h1:QPwzmACJjUTFsnSHH934V6woptycfrDDJnH7hvFVbGM=
+github.com/gaukas/godicttls v0.0.3 h1:YNDIf0d9adcxOijiLrEzpfZGAkNwLRzPaG6OjU7EITk=
+github.com/gaukas/godicttls v0.0.3/go.mod h1:l6EenT4TLWgTdwslVb4sEMOCf7Bv0JAK67deKr9/NCI=
github.com/klauspost/compress v1.15.15 h1:EF27CXIuDsYJ6mmvtBRlEuB2UVOqHG1tAXgZ7yIO+lw=
github.com/klauspost/compress v1.15.15/go.mod h1:ZcK2JAFqKOpnBlxcLsJzYfrS9X1akm9fHZNnD9+Vo/4=
-golang.org/x/crypto v0.1.0 h1:MDRAIl0xIo9Io2xV565hzXHw3zVseKrJKodhohM5CjU=
-golang.org/x/crypto v0.1.0/go.mod h1:RecgLatLF4+eUMCP1PoPZQb+cVrJcOPbHkTkbkB9sbw=
golang.org/x/crypto v0.5.0 h1:U/0M97KRkSFvyD/3FSmdP5W5swImpNgle/EHFhOsQPE=
golang.org/x/crypto v0.5.0/go.mod h1:NK/OQwhpMQP3MwtdjgLlYHnH9ebylxKWv3e0fK+mkQU=
-golang.org/x/net v0.1.0 h1:hZ/3BUoy5aId7sCpA/Tc5lt8DkFgdVS2onTpJsZ/fl0=
-golang.org/x/net v0.1.0/go.mod h1:Cx3nUiGt4eDBEyega/BKRp+/AlGL8hYe7U9odMt2Cco=
-golang.org/x/net v0.5.0 h1:GyT4nK/YDHSqa1c4753ouYCDajOYKTja9Xb/OHtgvSw=
-golang.org/x/net v0.5.0/go.mod h1:DivGGAXEgPSlEBzxGzZI+ZLohi+xUj054jfeKui00ws=
-golang.org/x/sys v0.1.0 h1:kunALQeHf1/185U1i0GOB/fy1IPRDDpuoOOqRReG57U=
-golang.org/x/sys v0.1.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/sys v0.4.0 h1:Zr2JFtRQNX3BCZ8YtxRE9hNJYC8J6I1MVbMg6owUp18=
-golang.org/x/sys v0.4.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
-golang.org/x/text v0.4.0 h1:BrVqGRd7+k1DiOgtnFvAkoQEWQvBc25ouMJM6429SFg=
-golang.org/x/text v0.4.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
-golang.org/x/text v0.6.0 h1:3XmdazWV+ubf7QgHSTWeykHOci5oeekaGJBLkrkaw4k=
-golang.org/x/text v0.6.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
+golang.org/x/net v0.7.0 h1:rJrUqqhjsgNp7KqAIc25s9pZnjU7TUcSY7HcVZjdn1g=
+golang.org/x/net v0.7.0/go.mod h1:2Tu9+aMcznHK/AK1HMvgo6xiTLG5rD5rZLDS+rp2Bjs=
+golang.org/x/sys v0.5.0 h1:MUK/U/4lj1t1oPg0HfuXDN/Z1wv31ZJ/YcPiGccS4DU=
+golang.org/x/sys v0.5.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
+golang.org/x/text v0.7.0 h1:4BRB4x83lYWy72KwLD/qYDuTu7q9PjSagHvijDw7cLo=
+golang.org/x/text v0.7.0/go.mod h1:mrYo+phRRbMaCq/xk9113O4dZlRixOauAjOtrjsXDZ8=
diff --git a/handshake_client.go b/handshake_client.go
index 50e1b00..d902148 100644
--- a/handshake_client.go
+++ b/handshake_client.go
@@ -167,7 +167,10 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
}
c.serverName = hello.serverName
- cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
+ cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello)
+ if err != nil {
+ return err
+ }
if cacheKey != "" && session != nil {
defer func() {
// If we got a handshake failure when resuming a session, throw away
@@ -182,11 +185,12 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
}()
}
- if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
+ if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
return err
}
- msg, err := c.readHandshake()
+ // serverHelloMsg is not included in the transcript
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -215,14 +219,15 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
if c.vers == VersionTLS13 {
hs := &clientHandshakeStateTLS13{
- c: c,
- ctx: ctx,
- serverHello: serverHello,
- hello: hello,
- ecdheParams: ecdheParams,
- session: session,
- earlySecret: earlySecret,
- binderKey: binderKey,
+ c: c,
+ ctx: ctx,
+ serverHello: serverHello,
+ hello: hello,
+ ecdheParams: ecdheParams,
+ keySharesEcdheParams: make(KeySharesEcdheParameters, 2), // [uTLS]
+ session: session,
+ earlySecret: earlySecret,
+ binderKey: binderKey,
}
// In TLS 1.3, session tickets are delivered after the handshake.
@@ -251,9 +256,9 @@ func (c *Conn) clientHandshake(ctx context.Context) (err error) {
}
func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
- session *ClientSessionState, earlySecret, binderKey []byte) {
+ session *ClientSessionState, earlySecret, binderKey []byte, err error) {
if c.config.SessionTicketsDisabled || c.config.ClientSessionCache == nil {
- return "", nil, nil, nil
+ return "", nil, nil, nil, nil
}
hello.ticketSupported = true
@@ -268,14 +273,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
// renegotiation is primarily used to allow a client to send a client
// certificate, which would be skipped if session resumption occurred.
if c.handshakes != 0 {
- return "", nil, nil, nil
+ return "", nil, nil, nil, nil
}
// Try to resume a previously negotiated TLS session, if available.
cacheKey = clientSessionCacheKey(c.conn.RemoteAddr(), c.config)
session, ok := c.config.ClientSessionCache.Get(cacheKey)
if !ok || session == nil {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// Check that version used for the previous session is still valid.
@@ -287,7 +292,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
}
}
if !versOk {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// Check that the cached server certificate is not expired, and that it's
@@ -296,24 +301,36 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
if !c.config.InsecureSkipVerify {
if len(session.verifiedChains) == 0 {
// The original connection had InsecureSkipVerify, while this doesn't.
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
serverCert := session.serverCertificates[0]
- if c.config.time().After(serverCert.NotAfter) {
- // Expired certificate, delete the entry.
- c.config.ClientSessionCache.Put(cacheKey, nil)
- return cacheKey, nil, nil, nil
+ // [UTLS SECTION START]
+ if !c.config.InsecureSkipTimeVerify {
+ if c.config.time().After(serverCert.NotAfter) {
+ // Expired certificate, delete the entry.
+ c.config.ClientSessionCache.Put(cacheKey, nil)
+ return cacheKey, nil, nil, nil, nil
+ }
}
- if err := serverCert.VerifyHostname(c.config.ServerName); err != nil {
- return cacheKey, nil, nil, nil
+ var dnsName string
+ if len(c.config.InsecureServerNameToVerify) == 0 {
+ dnsName = c.config.ServerName
+ } else if c.config.InsecureServerNameToVerify != "*" {
+ dnsName = c.config.InsecureServerNameToVerify
+ }
+ if len(dnsName) > 0 {
+ if err := serverCert.VerifyHostname(dnsName); err != nil {
+ return cacheKey, nil, nil, nil, nil
+ }
}
+ // [UTLS SECTION END]
}
if session.vers != VersionTLS13 {
// In TLS 1.2 the cipher suite must match the resumed session. Ensure we
// are still offering it.
if mutualCipherSuite(hello.cipherSuites, session.cipherSuite) == nil {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
hello.sessionTicket = session.sessionTicket
@@ -323,14 +340,14 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
// Check that the session ticket is not expired.
if c.config.time().After(session.useBy) {
c.config.ClientSessionCache.Put(cacheKey, nil)
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// In TLS 1.3 the KDF hash must match the resumed session. Ensure we
// offer at least one cipher suite with that hash.
cipherSuite := cipherSuiteTLS13ByID(session.cipherSuite)
if cipherSuite == nil {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
cipherSuiteOk := false
for _, offeredID := range hello.cipherSuites {
@@ -341,7 +358,7 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
}
}
if !cipherSuiteOk {
- return cacheKey, nil, nil, nil
+ return cacheKey, nil, nil, nil, nil
}
// Set the pre_shared_key extension. See RFC 8446, Section 4.2.11.1.
@@ -359,9 +376,15 @@ func (c *Conn) loadSession(hello *clientHelloMsg) (cacheKey string,
earlySecret = cipherSuite.extract(psk, nil)
binderKey = cipherSuite.deriveSecret(earlySecret, resumptionBinderLabel, nil)
transcript := cipherSuite.hash.New()
- transcript.Write(hello.marshalWithoutBinders())
+ helloBytes, err := hello.marshalWithoutBinders()
+ if err != nil {
+ return "", nil, nil, nil, err
+ }
+ transcript.Write(helloBytes)
pskBinders := [][]byte{cipherSuite.finishedHash(binderKey, transcript)}
- hello.updateBinders(pskBinders)
+ if err := hello.updateBinders(pskBinders); err != nil {
+ return "", nil, nil, nil, err
+ }
return
}
@@ -406,8 +429,12 @@ func (hs *clientHandshakeState) handshake() error {
hs.finishedHash.discardHandshakeBuffer()
}
- hs.finishedHash.Write(hs.hello.marshal())
- hs.finishedHash.Write(hs.serverHello.marshal())
+ if err := transcriptMsg(hs.hello, &hs.finishedHash); err != nil {
+ return err
+ }
+ if err := transcriptMsg(hs.serverHello, &hs.finishedHash); err != nil {
+ return err
+ }
c.buffering = true
c.didResume = isResume
@@ -478,7 +505,7 @@ func (hs *clientHandshakeState) pickCipherSuite() error {
func (hs *clientHandshakeState) doFullHandshake() error {
c := hs.c
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -487,9 +514,8 @@ func (hs *clientHandshakeState) doFullHandshake() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.finishedHash.Write(certMsg.marshal())
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -507,11 +533,10 @@ func (hs *clientHandshakeState) doFullHandshake() error {
c.sendAlert(alertUnexpectedMessage)
return errors.New("tls: received unexpected CertificateStatus message")
}
- hs.finishedHash.Write(cs.marshal())
c.ocspResponse = cs.response
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -540,14 +565,13 @@ func (hs *clientHandshakeState) doFullHandshake() error {
skx, ok := msg.(*serverKeyExchangeMsg)
if ok {
- hs.finishedHash.Write(skx.marshal())
err = keyAgreement.processServerKeyExchange(c.config, hs.hello, hs.serverHello, c.peerCertificates[0], skx)
if err != nil {
c.sendAlert(alertUnexpectedMessage)
return err
}
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -558,7 +582,6 @@ func (hs *clientHandshakeState) doFullHandshake() error {
certReq, ok := msg.(*certificateRequestMsg)
if ok {
certRequested = true
- hs.finishedHash.Write(certReq.marshal())
cri := certificateRequestInfoFromMsg(hs.ctx, c.vers, certReq)
if chainToSend, err = c.getClientCertificate(cri); err != nil {
@@ -566,7 +589,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
return err
}
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -577,7 +600,6 @@ func (hs *clientHandshakeState) doFullHandshake() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(shd, msg)
}
- hs.finishedHash.Write(shd.marshal())
// If the server requested a certificate then we have to send a
// Certificate message, even if it's empty because we don't have a
@@ -585,8 +607,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
if certRequested {
certMsg = new(certificateMsg)
certMsg.certificates = chainToSend.Certificate
- hs.finishedHash.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
}
@@ -597,8 +618,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
return err
}
if ckx != nil {
- hs.finishedHash.Write(ckx.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, ckx.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(ckx, &hs.finishedHash); err != nil {
return err
}
}
@@ -650,8 +670,7 @@ func (hs *clientHandshakeState) doFullHandshake() error {
return err
}
- hs.finishedHash.Write(certVerify.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certVerify.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certVerify, &hs.finishedHash); err != nil {
return err
}
}
@@ -793,7 +812,10 @@ func (hs *clientHandshakeState) readFinished(out []byte) error {
return err
}
- msg, err := c.readHandshake()
+ // finishedMsg is included in the transcript, but not until after we
+ // check the client version, since the state before this message was
+ // sent is used during verification.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -809,7 +831,11 @@ func (hs *clientHandshakeState) readFinished(out []byte) error {
c.sendAlert(alertHandshakeFailure)
return errors.New("tls: server's Finished message was incorrect")
}
- hs.finishedHash.Write(serverFinished.marshal())
+
+ if err := transcriptMsg(serverFinished, &hs.finishedHash); err != nil {
+ return err
+ }
+
copy(out, verify)
return nil
}
@@ -820,7 +846,7 @@ func (hs *clientHandshakeState) readSessionTicket() error {
}
c := hs.c
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -829,7 +855,6 @@ func (hs *clientHandshakeState) readSessionTicket() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(sessionTicketMsg, msg)
}
- hs.finishedHash.Write(sessionTicketMsg.marshal())
hs.session = &ClientSessionState{
sessionTicket: sessionTicketMsg.ticket,
@@ -849,14 +874,13 @@ func (hs *clientHandshakeState) readSessionTicket() error {
func (hs *clientHandshakeState) sendFinished(out []byte) error {
c := hs.c
- if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+ if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.clientSum(hs.masterSecret)
- hs.finishedHash.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
copy(out, finished.verifyData)
@@ -884,6 +908,10 @@ func (c *Conn) verifyServerCertificate(certificates [][]byte) error {
Intermediates: x509.NewCertPool(),
}
+ if c.config.InsecureSkipTimeVerify {
+ opts.CurrentTime = certs[0].NotAfter
+ }
+
if len(c.config.InsecureServerNameToVerify) == 0 {
opts.DNSName = c.config.ServerName
} else if c.config.InsecureServerNameToVerify != "*" {
diff --git a/handshake_client_test.go b/handshake_client_test.go
index 380de9f..749c9fc 100644
--- a/handshake_client_test.go
+++ b/handshake_client_test.go
@@ -1257,7 +1257,7 @@ func TestServerSelectingUnconfiguredApplicationProtocol(t *testing.T) {
cipherSuite: TLS_RSA_WITH_AES_128_GCM_SHA256,
alpnProtocol: "how-about-this",
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
s.Write([]byte{
byte(recordTypeHandshake),
@@ -1500,7 +1500,7 @@ func TestServerSelectingUnconfiguredCipherSuite(t *testing.T) {
random: make([]byte, 32),
cipherSuite: TLS_RSA_WITH_AES_256_GCM_SHA384,
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
s.Write([]byte{
byte(recordTypeHandshake),
diff --git a/handshake_client_tls13.go b/handshake_client_tls13.go
index cddf081..2385c15 100644
--- a/handshake_client_tls13.go
+++ b/handshake_client_tls13.go
@@ -17,12 +17,30 @@ import (
"time"
)
+// [uTLS SECTION START]
+type KeySharesEcdheParameters map[CurveID]ecdheParameters
+
+func (keymap KeySharesEcdheParameters) AddEcdheParams(curveID CurveID, params ecdheParameters) {
+ keymap[curveID] = params
+}
+func (keymap KeySharesEcdheParameters) GetEcdheParams(curveID CurveID) (params ecdheParameters, ok bool) {
+ params, ok = keymap[curveID]
+ return
+}
+func (keymap KeySharesEcdheParameters) GetPublicEcdheParams(curveID CurveID) (params EcdheParameters, ok bool) {
+ params, ok = keymap[curveID]
+ return
+}
+
+// [uTLS SECTION END]
+
type clientHandshakeStateTLS13 struct {
- c *Conn
- ctx context.Context
- serverHello *serverHelloMsg
- hello *clientHelloMsg
- ecdheParams ecdheParameters
+ c *Conn
+ ctx context.Context
+ serverHello *serverHelloMsg
+ hello *clientHelloMsg
+ ecdheParams ecdheParameters
+ keySharesEcdheParams KeySharesEcdheParameters // [uTLS]
session *ClientSessionState
earlySecret []byte
@@ -55,6 +73,14 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
return errors.New("tls: server selected TLS 1.3 in a renegotiation")
}
+ // [uTLS SECTION START]
+
+ // Use the ecdheParams of the key share whose group the server selected (uTLS may offer several).
+ if ecdheParams, ok := hs.keySharesEcdheParams.GetEcdheParams(hs.serverHello.serverShare.group); ok {
+ hs.ecdheParams = ecdheParams
+ }
+ // [uTLS SECTION END]
+
// Consistency check on the presence of a keyShare and its parameters.
if hs.ecdheParams == nil || len(hs.hello.keyShares) < 1 { // [uTLS]
// keyshares "< 1" instead of "!= 1", as uTLS may send multiple
@@ -66,7 +92,10 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
}
hs.transcript = hs.suite.hash.New()
- hs.transcript.Write(hs.hello.marshal())
+
+ if err := transcriptMsg(hs.hello, hs.transcript); err != nil {
+ return err
+ }
if bytes.Equal(hs.serverHello.random, helloRetryRequestRandom) {
if err := hs.sendDummyChangeCipherSpec(); err != nil {
@@ -77,7 +106,9 @@ func (hs *clientHandshakeStateTLS13) handshake() error {
}
}
- hs.transcript.Write(hs.serverHello.marshal())
+ if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
+ return err
+ }
c.buffering = true
if err := hs.processServerHello(); err != nil {
@@ -181,8 +212,7 @@ func (hs *clientHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
}
hs.sentDummyCCS = true
- _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
- return err
+ return hs.c.writeChangeCipherRecord()
}
// processHelloRetryRequest handles the HRR in hs.serverHello, modifies and
@@ -197,7 +227,21 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
hs.transcript.Write(chHash)
- hs.transcript.Write(hs.serverHello.marshal())
+ if err := transcriptMsg(hs.serverHello, hs.transcript); err != nil {
+ return err
+ }
+
+ // The only HelloRetryRequest extensions we support are key_share and
+ // cookie, and clients must abort the handshake if the HRR would not result
+ // in any change in the ClientHello.
+ if hs.serverHello.selectedGroup == 0 && hs.serverHello.cookie == nil {
+ c.sendAlert(alertIllegalParameter)
+ return errors.New("tls: server sent an unnecessary HelloRetryRequest message")
+ }
+
+ if hs.serverHello.cookie != nil {
+ hs.hello.cookie = hs.serverHello.cookie
+ }
// The only HelloRetryRequest extensions we support are key_share and
// cookie, and clients must abort the handshake if the HRR would not result
@@ -262,10 +306,20 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
transcript := hs.suite.hash.New()
transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
transcript.Write(chHash)
- transcript.Write(hs.serverHello.marshal())
- transcript.Write(hs.hello.marshalWithoutBinders())
+ // The binder is computed over the local transcript; hs.transcript already
+ // holds the HelloRetryRequest, so writing it there again would corrupt it.
+ if err := transcriptMsg(hs.serverHello, transcript); err != nil {
+ return err
+ }
+ helloBytes, err := hs.hello.marshalWithoutBinders()
+ if err != nil {
+ return err
+ }
+ transcript.Write(helloBytes)
pskBinders := [][]byte{hs.suite.finishedHash(hs.binderKey, transcript)}
- hs.hello.updateBinders(pskBinders)
+ if err := hs.hello.updateBinders(pskBinders); err != nil {
+ return err
+ }
} else {
// Server selected a cipher suite incompatible with the PSK.
hs.hello.pskIdentities = nil
@@ -335,12 +387,12 @@ func (hs *clientHandshakeStateTLS13) processHelloRetryRequest() error {
}
// [uTLS SECTION ENDS]
- hs.transcript.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
- msg, err := c.readHandshake()
+ // serverHelloMsg is not included in the transcript
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -429,6 +481,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
if !hs.usingPSK {
earlySecret = hs.suite.extract(nil, nil)
}
+
handshakeSecret := hs.suite.extract(sharedKey,
hs.suite.deriveSecret(earlySecret, "derived", nil))
@@ -459,7 +512,7 @@ func (hs *clientHandshakeStateTLS13) establishHandshakeKeys() error {
func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c := hs.c
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
@@ -469,7 +522,6 @@ func (hs *clientHandshakeStateTLS13) readServerParameters() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(encryptedExtensions, msg)
}
- hs.transcript.Write(encryptedExtensions.marshal())
if err := checkALPN(hs.hello.alpnProtocols, encryptedExtensions.alpnProtocol); err != nil {
c.sendAlert(alertUnsupportedExtension)
@@ -507,18 +559,24 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
return nil
}
- msg, err := c.readHandshake()
+ // [UTLS SECTION BEGINS]
+ // msg, err := c.readHandshake(hs.transcript)
+ msg, err := c.readHandshake(nil) // hold writing to transcript until we know it is not compressed cert
+ // [UTLS SECTION ENDS]
if err != nil {
return err
}
certReq, ok := msg.(*certificateRequestMsgTLS13)
if ok {
- hs.transcript.Write(certReq.marshal())
-
hs.certReq = certReq
+ // [UTLS] it is certReq (not compressedCert), so write it to the transcript;
+ // propagate marshal errors like every other transcriptMsg call.
+ if err := transcriptMsg(certReq, hs.transcript); err != nil {
+ return err
+ }
- msg, err = c.readHandshake()
+ // msg, err = c.readHandshake(hs.transcript)
+ msg, err = c.readHandshake(nil) // [UTLS] we don't write to transcript until make sure it is not compressed cert
if err != nil {
return err
}
@@ -548,9 +603,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
return errors.New("tls: received empty certificates message")
}
// [UTLS SECTION BEGINS]
- // Previously, this was simply 'hs.transcript.Write(certMsg.marshal())' (without the if).
- if !skipWritingCertToTranscript {
- hs.transcript.Write(certMsg.marshal())
+ if !skipWritingCertToTranscript { // write to transcript only if it is not compressedCert (i.e. if not processed by extension)
+ if err = transcriptMsg(certMsg, hs.transcript); err != nil {
+ return err
+ }
}
// [UTLS SECTION ENDS]
@@ -561,7 +617,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
return err
}
- msg, err = c.readHandshake()
+ // certificateVerifyMsg is included in the transcript, but not until
+ // after we verify the handshake signature, since the state before
+ // this message was sent is used.
+ msg, err = c.readHandshake(nil)
if err != nil {
return err
}
@@ -575,7 +634,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
// See RFC 8446, Section 4.4.3.
if !isSupportedSignatureAlgorithm(certVerify.signatureAlgorithm, supportedSignatureAlgorithms()) {
c.sendAlert(alertIllegalParameter)
- return errors.New("tls: certificate used with invalid signature algorithm -- not implemented")
+ return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err := typeAndHashFromSignatureScheme(certVerify.signatureAlgorithm)
if err != nil {
@@ -583,7 +642,7 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
}
if sigType == signaturePKCS1v15 || sigHash == crypto.SHA1 {
c.sendAlert(alertIllegalParameter)
- return errors.New("tls: certificate used with invalid signature algorithm -- obsolete")
+ return errors.New("tls: certificate used with invalid signature algorithm")
}
signed := signedMessage(sigHash, serverSignatureContext, hs.transcript)
if err := verifyHandshakeSignature(sigType, c.peerCertificates[0].PublicKey,
@@ -592,7 +651,9 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
return errors.New("tls: invalid signature by the server certificate: " + err.Error())
}
- hs.transcript.Write(certVerify.marshal())
+ if err := transcriptMsg(certVerify, hs.transcript); err != nil {
+ return err
+ }
return nil
}
@@ -600,7 +661,10 @@ func (hs *clientHandshakeStateTLS13) readServerCertificate() error {
func (hs *clientHandshakeStateTLS13) readServerFinished() error {
c := hs.c
- msg, err := c.readHandshake()
+ // finishedMsg is included in the transcript, but not until after we
+ // check the client version, since the state before this message was
+ // sent is used during verification.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -617,7 +681,9 @@ func (hs *clientHandshakeStateTLS13) readServerFinished() error {
return errors.New("tls: invalid server finished hash")
}
- hs.transcript.Write(finished.marshal())
+ if err := transcriptMsg(finished, hs.transcript); err != nil {
+ return err
+ }
// Derive secrets that take context through the server Finished.
@@ -666,8 +732,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
certMsg.scts = hs.certReq.scts && len(cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.certReq.ocspStapling && len(cert.OCSPStaple) > 0
- hs.transcript.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
@@ -704,8 +769,7 @@ func (hs *clientHandshakeStateTLS13) sendClientCertificate() error {
}
certVerifyMsg.signature = sig
- hs.transcript.Write(certVerifyMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
@@ -719,8 +783,7 @@ func (hs *clientHandshakeStateTLS13) sendClientFinished() error {
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
- hs.transcript.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
diff --git a/handshake_messages.go b/handshake_messages.go
index 6d81cc7..296512e 100644
--- a/handshake_messages.go
+++ b/handshake_messages.go
@@ -5,6 +5,7 @@
package tls
import (
+ "errors"
"fmt"
"strings"
@@ -98,9 +99,181 @@ type clientHelloMsg struct {
nextProtoNeg bool
}
-func (m *clientHelloMsg) marshal() []byte {
+func (m *clientHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
+ }
+
+ var exts cryptobyte.Builder
+ if len(m.serverName) > 0 {
+ // RFC 6066, Section 3
+ exts.AddUint16(extensionServerName)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8(0) // name_type = host_name
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes([]byte(m.serverName))
+ })
+ })
+ })
+ }
+ if m.ocspStapling {
+ // RFC 4366, Section 3.6
+ exts.AddUint16(extensionStatusRequest)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8(1) // status_type = ocsp
+ exts.AddUint16(0) // empty responder_id_list
+ exts.AddUint16(0) // empty request_extensions
+ })
+ }
+ if len(m.supportedCurves) > 0 {
+ // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
+ exts.AddUint16(extensionSupportedCurves)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, curve := range m.supportedCurves {
+ exts.AddUint16(uint16(curve))
+ }
+ })
+ })
+ }
+ if len(m.supportedPoints) > 0 {
+ // RFC 4492, Section 5.1.2
+ exts.AddUint16(extensionSupportedPoints)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.supportedPoints)
+ })
+ })
+ }
+ if m.ticketSupported {
+ // RFC 5077, Section 3.2
+ exts.AddUint16(extensionSessionTicket)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.sessionTicket)
+ })
+ }
+ if len(m.supportedSignatureAlgorithms) > 0 {
+ // RFC 5246, Section 7.4.1.4.1
+ exts.AddUint16(extensionSignatureAlgorithms)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, sigAlgo := range m.supportedSignatureAlgorithms {
+ exts.AddUint16(uint16(sigAlgo))
+ }
+ })
+ })
+ }
+ if len(m.supportedSignatureAlgorithmsCert) > 0 {
+ // RFC 8446, Section 4.2.3
+ exts.AddUint16(extensionSignatureAlgorithmsCert)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
+ exts.AddUint16(uint16(sigAlgo))
+ }
+ })
+ })
+ }
+ if m.secureRenegotiationSupported {
+ // RFC 5746, Section 3.2
+ exts.AddUint16(extensionRenegotiationInfo)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.secureRenegotiation)
+ })
+ })
+ }
+ if len(m.alpnProtocols) > 0 {
+ // RFC 7301, Section 3.1
+ exts.AddUint16(extensionALPN)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, proto := range m.alpnProtocols {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes([]byte(proto))
+ })
+ }
+ })
+ })
+ }
+ if m.scts {
+ // RFC 6962, Section 3.3.1
+ exts.AddUint16(extensionSCT)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if len(m.supportedVersions) > 0 {
+ // RFC 8446, Section 4.2.1
+ exts.AddUint16(extensionSupportedVersions)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, vers := range m.supportedVersions {
+ exts.AddUint16(vers)
+ }
+ })
+ })
+ }
+ if len(m.cookie) > 0 {
+ // RFC 8446, Section 4.2.2
+ exts.AddUint16(extensionCookie)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.cookie)
+ })
+ })
+ }
+ if len(m.keyShares) > 0 {
+ // RFC 8446, Section 4.2.8
+ exts.AddUint16(extensionKeyShare)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, ks := range m.keyShares {
+ exts.AddUint16(uint16(ks.group))
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(ks.data)
+ })
+ }
+ })
+ })
+ }
+ if m.earlyData {
+ // RFC 8446, Section 4.2.10
+ exts.AddUint16(extensionEarlyData)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if len(m.pskModes) > 0 {
+ // RFC 8446, Section 4.2.9
+ exts.AddUint16(extensionPSKModes)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.pskModes)
+ })
+ })
+ }
+ if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
+ // RFC 8446, Section 4.2.11
+ exts.AddUint16(extensionPreSharedKey)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, psk := range m.pskIdentities {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(psk.label)
+ })
+ exts.AddUint32(psk.obfuscatedTicketAge)
+ }
+ })
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, binder := range m.pskBinders {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(binder)
+ })
+ }
+ })
+ })
+ }
+ extBytes, err := exts.Bytes()
+ if err != nil {
+ return nil, err
}
var b cryptobyte.Builder
@@ -120,219 +293,53 @@ func (m *clientHelloMsg) marshal() []byte {
b.AddBytes(m.compressionMethods)
})
- // If extensions aren't present, omit them.
- var extensionsPresent bool
- bWithoutExtensions := *b
-
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- if len(m.serverName) > 0 {
- // RFC 6066, Section 3
- b.AddUint16(extensionServerName)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8(0) // name_type = host_name
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes([]byte(m.serverName))
- })
- })
- })
- }
- if m.ocspStapling {
- // RFC 4366, Section 3.6
- b.AddUint16(extensionStatusRequest)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8(1) // status_type = ocsp
- b.AddUint16(0) // empty responder_id_list
- b.AddUint16(0) // empty request_extensions
- })
- }
- if len(m.supportedCurves) > 0 {
- // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
- b.AddUint16(extensionSupportedCurves)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, curve := range m.supportedCurves {
- b.AddUint16(uint16(curve))
- }
- })
- })
- }
- if len(m.supportedPoints) > 0 {
- // RFC 4492, Section 5.1.2
- b.AddUint16(extensionSupportedPoints)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.supportedPoints)
- })
- })
- }
- if m.ticketSupported {
- // RFC 5077, Section 3.2
- b.AddUint16(extensionSessionTicket)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.sessionTicket)
- })
- }
- if len(m.supportedSignatureAlgorithms) > 0 {
- // RFC 5246, Section 7.4.1.4.1
- b.AddUint16(extensionSignatureAlgorithms)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, sigAlgo := range m.supportedSignatureAlgorithms {
- b.AddUint16(uint16(sigAlgo))
- }
- })
- })
- }
- if len(m.supportedSignatureAlgorithmsCert) > 0 {
- // RFC 8446, Section 4.2.3
- b.AddUint16(extensionSignatureAlgorithmsCert)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, sigAlgo := range m.supportedSignatureAlgorithmsCert {
- b.AddUint16(uint16(sigAlgo))
- }
- })
- })
- }
- if m.secureRenegotiationSupported {
- // RFC 5746, Section 3.2
- b.AddUint16(extensionRenegotiationInfo)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.secureRenegotiation)
- })
- })
- }
- if len(m.alpnProtocols) > 0 {
- // RFC 7301, Section 3.1
- b.AddUint16(extensionALPN)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, proto := range m.alpnProtocols {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes([]byte(proto))
- })
- }
- })
- })
- }
- if m.scts {
- // RFC 6962, Section 3.3.1
- b.AddUint16(extensionSCT)
- b.AddUint16(0) // empty extension_data
- }
- if len(m.supportedVersions) > 0 {
- // RFC 8446, Section 4.2.1
- b.AddUint16(extensionSupportedVersions)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, vers := range m.supportedVersions {
- b.AddUint16(vers)
- }
- })
- })
- }
- if len(m.cookie) > 0 {
- // RFC 8446, Section 4.2.2
- b.AddUint16(extensionCookie)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.cookie)
- })
- })
- }
- if len(m.keyShares) > 0 {
- // RFC 8446, Section 4.2.8
- b.AddUint16(extensionKeyShare)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, ks := range m.keyShares {
- b.AddUint16(uint16(ks.group))
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(ks.data)
- })
- }
- })
- })
- }
- if m.earlyData {
- // RFC 8446, Section 4.2.10
- b.AddUint16(extensionEarlyData)
- b.AddUint16(0) // empty extension_data
- }
- if len(m.pskModes) > 0 {
- // RFC 8446, Section 4.2.9
- b.AddUint16(extensionPSKModes)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.pskModes)
- })
- })
- }
- if len(m.pskIdentities) > 0 { // pre_shared_key must be the last extension
- // RFC 8446, Section 4.2.11
- b.AddUint16(extensionPreSharedKey)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, psk := range m.pskIdentities {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(psk.label)
- })
- b.AddUint32(psk.obfuscatedTicketAge)
- }
- })
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, binder := range m.pskBinders {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(binder)
- })
- }
- })
- })
- }
-
- extensionsPresent = len(b.BytesOrPanic()) > 2
- })
-
- if !extensionsPresent {
- *b = bWithoutExtensions
+ if len(extBytes) > 0 {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(extBytes)
+ })
}
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
// marshalWithoutBinders returns the ClientHello through the
-// PreSharedKeyExtension.identities field, according to RFC 8446, Section
+// FakePreSharedKeyExtension.identities field, according to RFC 8446, Section
// 4.2.11.2. Note that m.pskBinders must be set to slices of the correct length.
-func (m *clientHelloMsg) marshalWithoutBinders() []byte {
+func (m *clientHelloMsg) marshalWithoutBinders() ([]byte, error) {
bindersLen := 2 // uint16 length prefix
for _, binder := range m.pskBinders {
bindersLen += 1 // uint8 length prefix
bindersLen += len(binder)
}
- fullMessage := m.marshal()
- return fullMessage[:len(fullMessage)-bindersLen]
+ fullMessage, err := m.marshal()
+ if err != nil {
+ return nil, err
+ }
+ return fullMessage[:len(fullMessage)-bindersLen], nil
}
// updateBinders updates the m.pskBinders field, if necessary updating the
// cached marshaled representation. The supplied binders must have the same
// length as the current m.pskBinders.
-func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
+func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) error {
if len(pskBinders) != len(m.pskBinders) {
- panic("tls: internal error: pskBinders length mismatch")
+ return errors.New("tls: internal error: pskBinders length mismatch")
}
for i := range m.pskBinders {
if len(pskBinders[i]) != len(m.pskBinders[i]) {
- panic("tls: internal error: pskBinders length mismatch")
+ return errors.New("tls: internal error: pskBinders length mismatch")
}
}
m.pskBinders = pskBinders
if m.raw != nil {
- lenWithoutBinders := len(m.marshalWithoutBinders())
+ helloBytes, err := m.marshalWithoutBinders()
+ if err != nil {
+ return err
+ }
+ lenWithoutBinders := len(helloBytes)
b := cryptobyte.NewFixedBuilder(m.raw[:lenWithoutBinders])
b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
for _, binder := range m.pskBinders {
@@ -342,9 +349,11 @@ func (m *clientHelloMsg) updateBinders(pskBinders [][]byte) {
}
})
if out, err := b.Bytes(); err != nil || len(out) != len(m.raw) {
- panic("tls: internal error: failed to update binders")
+ return errors.New("tls: internal error: failed to update binders")
}
}
+
+ return nil
}
func (m *clientHelloMsg) unmarshal(data []byte) bool {
@@ -627,9 +636,98 @@ type serverHelloMsg struct {
nextProtos []string
}
-func (m *serverHelloMsg) marshal() []byte {
+func (m *serverHelloMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
+ }
+
+ var exts cryptobyte.Builder
+ if m.ocspStapling {
+ exts.AddUint16(extensionStatusRequest)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if m.ticketSupported {
+ exts.AddUint16(extensionSessionTicket)
+ exts.AddUint16(0) // empty extension_data
+ }
+ if m.secureRenegotiationSupported {
+ exts.AddUint16(extensionRenegotiationInfo)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.secureRenegotiation)
+ })
+ })
+ }
+ if len(m.alpnProtocol) > 0 {
+ exts.AddUint16(extensionALPN)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes([]byte(m.alpnProtocol))
+ })
+ })
+ })
+ }
+ if len(m.scts) > 0 {
+ exts.AddUint16(extensionSCT)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ for _, sct := range m.scts {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(sct)
+ })
+ }
+ })
+ })
+ }
+ if m.supportedVersion != 0 {
+ exts.AddUint16(extensionSupportedVersions)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(m.supportedVersion)
+ })
+ }
+ if m.serverShare.group != 0 {
+ exts.AddUint16(extensionKeyShare)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(uint16(m.serverShare.group))
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.serverShare.data)
+ })
+ })
+ }
+ if m.selectedIdentityPresent {
+ exts.AddUint16(extensionPreSharedKey)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(m.selectedIdentity)
+ })
+ }
+
+ if len(m.cookie) > 0 {
+ exts.AddUint16(extensionCookie)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.cookie)
+ })
+ })
+ }
+ if m.selectedGroup != 0 {
+ exts.AddUint16(extensionKeyShare)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint16(uint16(m.selectedGroup))
+ })
+ }
+ if len(m.supportedPoints) > 0 {
+ exts.AddUint16(extensionSupportedPoints)
+ exts.AddUint16LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddUint8LengthPrefixed(func(exts *cryptobyte.Builder) {
+ exts.AddBytes(m.supportedPoints)
+ })
+ })
+ }
+
+ extBytes, err := exts.Bytes()
+ if err != nil {
+ return nil, err
}
var b cryptobyte.Builder
@@ -643,104 +741,15 @@ func (m *serverHelloMsg) marshal() []byte {
b.AddUint16(m.cipherSuite)
b.AddUint8(m.compressionMethod)
- // If extensions aren't present, omit them.
- var extensionsPresent bool
- bWithoutExtensions := *b
-
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- if m.ocspStapling {
- b.AddUint16(extensionStatusRequest)
- b.AddUint16(0) // empty extension_data
- }
- if m.ticketSupported {
- b.AddUint16(extensionSessionTicket)
- b.AddUint16(0) // empty extension_data
- }
- if m.secureRenegotiationSupported {
- b.AddUint16(extensionRenegotiationInfo)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.secureRenegotiation)
- })
- })
- }
- if len(m.alpnProtocol) > 0 {
- b.AddUint16(extensionALPN)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes([]byte(m.alpnProtocol))
- })
- })
- })
- }
- if len(m.scts) > 0 {
- b.AddUint16(extensionSCT)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- for _, sct := range m.scts {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(sct)
- })
- }
- })
- })
- }
- if m.supportedVersion != 0 {
- b.AddUint16(extensionSupportedVersions)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(m.supportedVersion)
- })
- }
- if m.serverShare.group != 0 {
- b.AddUint16(extensionKeyShare)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(uint16(m.serverShare.group))
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.serverShare.data)
- })
- })
- }
- if m.selectedIdentityPresent {
- b.AddUint16(extensionPreSharedKey)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(m.selectedIdentity)
- })
- }
-
- if len(m.cookie) > 0 {
- b.AddUint16(extensionCookie)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.cookie)
- })
- })
- }
- if m.selectedGroup != 0 {
- b.AddUint16(extensionKeyShare)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint16(uint16(m.selectedGroup))
- })
- }
- if len(m.supportedPoints) > 0 {
- b.AddUint16(extensionSupportedPoints)
- b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
- b.AddBytes(m.supportedPoints)
- })
- })
- }
-
- extensionsPresent = len(b.BytesOrPanic()) > 2
- })
-
- if !extensionsPresent {
- *b = bWithoutExtensions
+ if len(extBytes) > 0 {
+ b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+ b.AddBytes(extBytes)
+ })
}
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *serverHelloMsg) unmarshal(data []byte) bool {
@@ -872,9 +881,9 @@ type encryptedExtensionsMsg struct {
utls utlsEncryptedExtensionsMsgExtraFields // [uTLS]
}
-func (m *encryptedExtensionsMsg) marshal() []byte {
+func (m *encryptedExtensionsMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -894,8 +903,9 @@ func (m *encryptedExtensionsMsg) marshal() []byte {
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
@@ -948,10 +958,10 @@ func (m *encryptedExtensionsMsg) unmarshal(data []byte) bool {
type endOfEarlyDataMsg struct{}
-func (m *endOfEarlyDataMsg) marshal() []byte {
+func (m *endOfEarlyDataMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeEndOfEarlyData
- return x
+ return x, nil
}
func (m *endOfEarlyDataMsg) unmarshal(data []byte) bool {
@@ -963,9 +973,9 @@ type keyUpdateMsg struct {
updateRequested bool
}
-func (m *keyUpdateMsg) marshal() []byte {
+func (m *keyUpdateMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -978,8 +988,9 @@ func (m *keyUpdateMsg) marshal() []byte {
}
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *keyUpdateMsg) unmarshal(data []byte) bool {
@@ -1011,9 +1022,9 @@ type newSessionTicketMsgTLS13 struct {
maxEarlyData uint32
}
-func (m *newSessionTicketMsgTLS13) marshal() []byte {
+func (m *newSessionTicketMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -1038,8 +1049,9 @@ func (m *newSessionTicketMsgTLS13) marshal() []byte {
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *newSessionTicketMsgTLS13) unmarshal(data []byte) bool {
@@ -1092,9 +1104,9 @@ type certificateRequestMsgTLS13 struct {
certificateAuthorities [][]byte
}
-func (m *certificateRequestMsgTLS13) marshal() []byte {
+func (m *certificateRequestMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -1153,8 +1165,9 @@ func (m *certificateRequestMsgTLS13) marshal() []byte {
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *certificateRequestMsgTLS13) unmarshal(data []byte) bool {
@@ -1238,9 +1251,9 @@ type certificateMsg struct {
certificates [][]byte
}
-func (m *certificateMsg) marshal() (x []byte) {
+func (m *certificateMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var i int
@@ -1249,7 +1262,7 @@ func (m *certificateMsg) marshal() (x []byte) {
}
length := 3 + 3*len(m.certificates) + i
- x = make([]byte, 4+length)
+ x := make([]byte, 4+length)
x[0] = typeCertificate
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
@@ -1270,7 +1283,7 @@ func (m *certificateMsg) marshal() (x []byte) {
}
m.raw = x
- return
+ return m.raw, nil
}
func (m *certificateMsg) unmarshal(data []byte) bool {
@@ -1317,9 +1330,9 @@ type certificateMsgTLS13 struct {
scts bool
}
-func (m *certificateMsgTLS13) marshal() []byte {
+func (m *certificateMsgTLS13) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -1337,8 +1350,9 @@ func (m *certificateMsgTLS13) marshal() []byte {
marshalCertificate(b, certificate)
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func marshalCertificate(b *cryptobyte.Builder, certificate Certificate) {
@@ -1461,9 +1475,9 @@ type serverKeyExchangeMsg struct {
key []byte
}
-func (m *serverKeyExchangeMsg) marshal() []byte {
+func (m *serverKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
length := len(m.key)
x := make([]byte, length+4)
@@ -1474,7 +1488,7 @@ func (m *serverKeyExchangeMsg) marshal() []byte {
copy(x[4:], m.key)
m.raw = x
- return x
+ return x, nil
}
func (m *serverKeyExchangeMsg) unmarshal(data []byte) bool {
@@ -1491,9 +1505,9 @@ type certificateStatusMsg struct {
response []byte
}
-func (m *certificateStatusMsg) marshal() []byte {
+func (m *certificateStatusMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -1505,8 +1519,9 @@ func (m *certificateStatusMsg) marshal() []byte {
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *certificateStatusMsg) unmarshal(data []byte) bool {
@@ -1525,10 +1540,10 @@ func (m *certificateStatusMsg) unmarshal(data []byte) bool {
type serverHelloDoneMsg struct{}
-func (m *serverHelloDoneMsg) marshal() []byte {
+func (m *serverHelloDoneMsg) marshal() ([]byte, error) {
x := make([]byte, 4)
x[0] = typeServerHelloDone
- return x
+ return x, nil
}
func (m *serverHelloDoneMsg) unmarshal(data []byte) bool {
@@ -1540,9 +1555,9 @@ type clientKeyExchangeMsg struct {
ciphertext []byte
}
-func (m *clientKeyExchangeMsg) marshal() []byte {
+func (m *clientKeyExchangeMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
length := len(m.ciphertext)
x := make([]byte, length+4)
@@ -1553,7 +1568,7 @@ func (m *clientKeyExchangeMsg) marshal() []byte {
copy(x[4:], m.ciphertext)
m.raw = x
- return x
+ return x, nil
}
func (m *clientKeyExchangeMsg) unmarshal(data []byte) bool {
@@ -1574,9 +1589,9 @@ type finishedMsg struct {
verifyData []byte
}
-func (m *finishedMsg) marshal() []byte {
+func (m *finishedMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -1585,8 +1600,9 @@ func (m *finishedMsg) marshal() []byte {
b.AddBytes(m.verifyData)
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *finishedMsg) unmarshal(data []byte) bool {
@@ -1608,9 +1624,9 @@ type certificateRequestMsg struct {
certificateAuthorities [][]byte
}
-func (m *certificateRequestMsg) marshal() (x []byte) {
+func (m *certificateRequestMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
// See RFC 4346, Section 7.4.4.
@@ -1625,7 +1641,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
length += 2 + 2*len(m.supportedSignatureAlgorithms)
}
- x = make([]byte, 4+length)
+ x := make([]byte, 4+length)
x[0] = typeCertificateRequest
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
@@ -1660,7 +1676,7 @@ func (m *certificateRequestMsg) marshal() (x []byte) {
}
m.raw = x
- return
+ return m.raw, nil
}
func (m *certificateRequestMsg) unmarshal(data []byte) bool {
@@ -1746,9 +1762,9 @@ type certificateVerifyMsg struct {
signature []byte
}
-func (m *certificateVerifyMsg) marshal() (x []byte) {
+func (m *certificateVerifyMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -1762,8 +1778,9 @@ func (m *certificateVerifyMsg) marshal() (x []byte) {
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *certificateVerifyMsg) unmarshal(data []byte) bool {
@@ -1786,15 +1803,15 @@ type newSessionTicketMsg struct {
ticket []byte
}
-func (m *newSessionTicketMsg) marshal() (x []byte) {
+func (m *newSessionTicketMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
// See RFC 5077, Section 3.3.
ticketLen := len(m.ticket)
length := 2 + 4 + ticketLen
- x = make([]byte, 4+length)
+ x := make([]byte, 4+length)
x[0] = typeNewSessionTicket
x[1] = uint8(length >> 16)
x[2] = uint8(length >> 8)
@@ -1805,7 +1822,7 @@ func (m *newSessionTicketMsg) marshal() (x []byte) {
m.raw = x
- return
+ return m.raw, nil
}
func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
@@ -1833,10 +1850,25 @@ func (m *newSessionTicketMsg) unmarshal(data []byte) bool {
type helloRequestMsg struct {
}
-func (*helloRequestMsg) marshal() []byte {
- return []byte{typeHelloRequest, 0, 0, 0}
+func (*helloRequestMsg) marshal() ([]byte, error) {
+ return []byte{typeHelloRequest, 0, 0, 0}, nil
}
func (*helloRequestMsg) unmarshal(data []byte) bool {
return len(data) == 4
}
+
+type transcriptHash interface {
+ Write([]byte) (int, error)
+}
+
+// transcriptMsg is a helper used to marshal and hash messages which typically
+// are not written to the wire, and as such aren't hashed during Conn.writeRecord.
+func transcriptMsg(msg handshakeMessage, h transcriptHash) error {
+ data, err := msg.marshal()
+ if err != nil {
+ return err
+ }
+ h.Write(data)
+ return nil
+}
diff --git a/handshake_messages_test.go b/handshake_messages_test.go
index 5597a63..618c852 100644
--- a/handshake_messages_test.go
+++ b/handshake_messages_test.go
@@ -39,6 +39,15 @@ var tests = []any{
&utlsCompressedCertificateMsg{}, // [UTLS]
}
+func mustMarshal(t *testing.T, msg handshakeMessage) []byte {
+ t.Helper()
+ b, err := msg.marshal()
+ if err != nil {
+ t.Fatal(err)
+ }
+ return b
+}
+
func TestMarshalUnmarshal(t *testing.T) {
rand := rand.New(rand.NewSource(time.Now().UnixNano()))
@@ -57,7 +66,7 @@ func TestMarshalUnmarshal(t *testing.T) {
}
m1 := v.Interface().(handshakeMessage)
- marshaled := m1.marshal()
+ marshaled := mustMarshal(t, m1)
m2 := iface.(handshakeMessage)
if !m2.unmarshal(marshaled) {
t.Errorf("#%d failed to unmarshal %#v %x", i, m1, marshaled)
@@ -419,12 +428,12 @@ func TestRejectEmptySCTList(t *testing.T) {
var random [32]byte
sct := []byte{0x42, 0x42, 0x42, 0x42}
- serverHello := serverHelloMsg{
+ serverHello := &serverHelloMsg{
vers: VersionTLS12,
random: random[:],
scts: [][]byte{sct},
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
var serverHelloCopy serverHelloMsg
if !serverHelloCopy.unmarshal(serverHelloBytes) {
@@ -462,12 +471,12 @@ func TestRejectEmptySCT(t *testing.T) {
// not be zero length.
var random [32]byte
- serverHello := serverHelloMsg{
+ serverHello := &serverHelloMsg{
vers: VersionTLS12,
random: random[:],
scts: [][]byte{nil},
}
- serverHelloBytes := serverHello.marshal()
+ serverHelloBytes := mustMarshal(t, serverHello)
var serverHelloCopy serverHelloMsg
if serverHelloCopy.unmarshal(serverHelloBytes) {
diff --git a/handshake_server.go b/handshake_server.go
index 92b38cb..8cb9acf 100644
--- a/handshake_server.go
+++ b/handshake_server.go
@@ -129,7 +129,9 @@ func (hs *serverHandshakeState) handshake() error {
// readClientHello reads a ClientHello message and selects the protocol version.
func (c *Conn) readClientHello(ctx context.Context) (*clientHelloMsg, error) {
- msg, err := c.readHandshake()
+ // clientHelloMsg is included in the transcript, but we haven't initialized
+ // it yet. The respective handshake functions will record it themselves.
+ msg, err := c.readHandshake(nil)
if err != nil {
return nil, err
}
@@ -463,9 +465,10 @@ func (hs *serverHandshakeState) doResumeHandshake() error {
hs.hello.ticketSupported = hs.sessionState.usedOldKey
hs.finishedHash = newFinishedHash(c.vers, hs.suite)
hs.finishedHash.discardHandshakeBuffer()
- hs.finishedHash.Write(hs.clientHello.marshal())
- hs.finishedHash.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
+ return err
+ }
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
@@ -503,24 +506,23 @@ func (hs *serverHandshakeState) doFullHandshake() error {
// certificates won't be used.
hs.finishedHash.discardHandshakeBuffer()
}
- hs.finishedHash.Write(hs.clientHello.marshal())
- hs.finishedHash.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if err := transcriptMsg(hs.clientHello, &hs.finishedHash); err != nil {
+ return err
+ }
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, &hs.finishedHash); err != nil {
return err
}
certMsg := new(certificateMsg)
certMsg.certificates = hs.cert.Certificate
- hs.finishedHash.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, &hs.finishedHash); err != nil {
return err
}
if hs.hello.ocspStapling {
certStatus := new(certificateStatusMsg)
certStatus.response = hs.cert.OCSPStaple
- hs.finishedHash.Write(certStatus.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certStatus.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certStatus, &hs.finishedHash); err != nil {
return err
}
}
@@ -532,8 +534,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
return err
}
if skx != nil {
- hs.finishedHash.Write(skx.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, skx.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(skx, &hs.finishedHash); err != nil {
return err
}
}
@@ -559,15 +560,13 @@ func (hs *serverHandshakeState) doFullHandshake() error {
if c.config.ClientCAs != nil {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
- hs.finishedHash.Write(certReq.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certReq, &hs.finishedHash); err != nil {
return err
}
}
helloDone := new(serverHelloDoneMsg)
- hs.finishedHash.Write(helloDone.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, helloDone.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(helloDone, &hs.finishedHash); err != nil {
return err
}
@@ -577,7 +576,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
var pub crypto.PublicKey // public key for client auth, if any
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -590,7 +589,6 @@ func (hs *serverHandshakeState) doFullHandshake() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.finishedHash.Write(certMsg.marshal())
if err := c.processCertsFromClient(Certificate{
Certificate: certMsg.certificates,
@@ -601,7 +599,7 @@ func (hs *serverHandshakeState) doFullHandshake() error {
pub = c.peerCertificates[0].PublicKey
}
- msg, err = c.readHandshake()
+ msg, err = c.readHandshake(&hs.finishedHash)
if err != nil {
return err
}
@@ -619,7 +617,6 @@ func (hs *serverHandshakeState) doFullHandshake() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(ckx, msg)
}
- hs.finishedHash.Write(ckx.marshal())
preMasterSecret, err := keyAgreement.processClientKeyExchange(c.config, hs.cert, ckx, c.vers)
if err != nil {
@@ -639,7 +636,10 @@ func (hs *serverHandshakeState) doFullHandshake() error {
// to the client's certificate. This allows us to verify that the client is in
// possession of the private key of the certificate.
if len(c.peerCertificates) > 0 {
- msg, err = c.readHandshake()
+ // certificateVerifyMsg is included in the transcript, but not until
+ // after we verify the handshake signature, since the state before
+ // this message was sent is used.
+ msg, err = c.readHandshake(nil)
if err != nil {
return err
}
@@ -674,7 +674,9 @@ func (hs *serverHandshakeState) doFullHandshake() error {
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
- hs.finishedHash.Write(certVerify.marshal())
+ if err := transcriptMsg(certVerify, &hs.finishedHash); err != nil {
+ return err
+ }
}
hs.finishedHash.discardHandshakeBuffer()
@@ -714,7 +716,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error {
return err
}
- msg, err := c.readHandshake()
+ // finishedMsg is included in the transcript, but not until after we
+ // check the client version, since the state before this message was
+ // sent is used during verification.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -731,7 +736,10 @@ func (hs *serverHandshakeState) readFinished(out []byte) error {
return errors.New("tls: client's Finished message is incorrect")
}
- hs.finishedHash.Write(clientFinished.marshal())
+ if err := transcriptMsg(clientFinished, &hs.finishedHash); err != nil {
+ return err
+ }
+
copy(out, verify)
return nil
}
@@ -765,14 +773,16 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
masterSecret: hs.masterSecret,
certificates: certsFromClient,
}
- var err error
- m.ticket, err = c.encryptTicket(state.marshal())
+ stateBytes, err := state.marshal()
+ if err != nil {
+ return err
+ }
+ m.ticket, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
- hs.finishedHash.Write(m.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(m, &hs.finishedHash); err != nil {
return err
}
@@ -782,14 +792,13 @@ func (hs *serverHandshakeState) sendSessionTicket() error {
func (hs *serverHandshakeState) sendFinished(out []byte) error {
c := hs.c
- if _, err := c.writeRecord(recordTypeChangeCipherSpec, []byte{1}); err != nil {
+ if err := c.writeChangeCipherRecord(); err != nil {
return err
}
finished := new(finishedMsg)
finished.verifyData = hs.finishedHash.serverSum(hs.masterSecret)
- hs.finishedHash.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, &hs.finishedHash); err != nil {
return err
}
diff --git a/handshake_server_test.go b/handshake_server_test.go
index 39b65f6..b2e8107 100644
--- a/handshake_server_test.go
+++ b/handshake_server_test.go
@@ -30,6 +30,13 @@ func testClientHello(t *testing.T, serverConfig *Config, m handshakeMessage) {
testClientHelloFailure(t, serverConfig, m, "")
}
+// testFatal is a hack to prevent the compiler from complaining that there is a
+// call to t.Fatal from a non-test goroutine
+func testFatal(t *testing.T, err error) {
+ t.Helper()
+ t.Fatal(err)
+}
+
func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessage, expectedSubStr string) {
c, s := localPipe(t)
go func() {
@@ -37,7 +44,9 @@ func testClientHelloFailure(t *testing.T, serverConfig *Config, m handshakeMessa
if ch, ok := m.(*clientHelloMsg); ok {
cli.vers = ch.vers
}
- cli.writeRecord(recordTypeHandshake, m.marshal())
+ if _, err := cli.writeHandshakeRecord(m, nil); err != nil {
+ testFatal(t, err)
+ }
c.Close()
}()
ctx := context.Background()
@@ -194,7 +203,9 @@ func TestRenegotiationExtension(t *testing.T) {
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
buf := make([]byte, 1024)
n, err := c.Read(buf)
@@ -253,8 +264,10 @@ func TestTLS12OnlyCipherSuites(t *testing.T) {
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
- reply, err := cli.readHandshake()
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
+ reply, err := cli.readHandshake(nil)
c.Close()
if err != nil {
replyChan <- err
@@ -311,8 +324,10 @@ func TestTLSPointFormats(t *testing.T) {
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
- reply, err := cli.readHandshake()
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
+ reply, err := cli.readHandshake(nil)
c.Close()
if err != nil {
replyChan <- err
@@ -1426,7 +1441,9 @@ func TestSNIGivenOnFailure(t *testing.T) {
go func() {
cli := Client(c, testConfig)
cli.vers = clientHello.vers
- cli.writeRecord(recordTypeHandshake, clientHello.marshal())
+ if _, err := cli.writeHandshakeRecord(clientHello, nil); err != nil {
+ testFatal(t, err)
+ }
c.Close()
}()
conn := Server(s, serverConfig)
diff --git a/handshake_server_tls13.go b/handshake_server_tls13.go
index 03a477f..0043e16 100644
--- a/handshake_server_tls13.go
+++ b/handshake_server_tls13.go
@@ -302,7 +302,12 @@ func (hs *serverHandshakeStateTLS13) checkForResumption() error {
c.sendAlert(alertInternalError)
return errors.New("tls: internal error: failed to clone hash")
}
- transcript.Write(hs.clientHello.marshalWithoutBinders())
+ clientHelloBytes, err := hs.clientHello.marshalWithoutBinders()
+ if err != nil {
+ c.sendAlert(alertInternalError)
+ return err
+ }
+ transcript.Write(clientHelloBytes)
pskBinder := hs.suite.finishedHash(binderKey, transcript)
if !hmac.Equal(hs.clientHello.pskBinders[i], pskBinder) {
c.sendAlert(alertDecryptError)
@@ -393,8 +398,7 @@ func (hs *serverHandshakeStateTLS13) sendDummyChangeCipherSpec() error {
}
hs.sentDummyCCS = true
- _, err := hs.c.writeRecord(recordTypeChangeCipherSpec, []byte{1})
- return err
+ return hs.c.writeChangeCipherRecord()
}
func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID) error {
@@ -402,7 +406,9 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
// The first ClientHello gets double-hashed into the transcript upon a
// HelloRetryRequest. See RFC 8446, Section 4.4.1.
- hs.transcript.Write(hs.clientHello.marshal())
+ if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
+ return err
+ }
chHash := hs.transcript.Sum(nil)
hs.transcript.Reset()
hs.transcript.Write([]byte{typeMessageHash, 0, 0, uint8(len(chHash))})
@@ -418,8 +424,7 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
selectedGroup: selectedGroup,
}
- hs.transcript.Write(helloRetryRequest.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, helloRetryRequest.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(helloRetryRequest, hs.transcript); err != nil {
return err
}
@@ -427,7 +432,8 @@ func (hs *serverHandshakeStateTLS13) doHelloRetryRequest(selectedGroup CurveID)
return err
}
- msg, err := c.readHandshake()
+ // clientHelloMsg is not included in the transcript.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -518,9 +524,10 @@ func illegalClientHelloChange(ch, ch1 *clientHelloMsg) bool {
func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
c := hs.c
- hs.transcript.Write(hs.clientHello.marshal())
- hs.transcript.Write(hs.hello.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, hs.hello.marshal()); err != nil {
+ if err := transcriptMsg(hs.clientHello, hs.transcript); err != nil {
+ return err
+ }
+ if _, err := hs.c.writeHandshakeRecord(hs.hello, hs.transcript); err != nil {
return err
}
@@ -563,8 +570,7 @@ func (hs *serverHandshakeStateTLS13) sendServerParameters() error {
encryptedExtensions.alpnProtocol = selectedProto
c.clientProtocol = selectedProto
- hs.transcript.Write(encryptedExtensions.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, encryptedExtensions.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(encryptedExtensions, hs.transcript); err != nil {
return err
}
@@ -593,8 +599,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
certReq.certificateAuthorities = c.config.ClientCAs.Subjects()
}
- hs.transcript.Write(certReq.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certReq.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certReq, hs.transcript); err != nil {
return err
}
}
@@ -605,8 +610,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
certMsg.scts = hs.clientHello.scts && len(hs.cert.SignedCertificateTimestamps) > 0
certMsg.ocspStapling = hs.clientHello.ocspStapling && len(hs.cert.OCSPStaple) > 0
- hs.transcript.Write(certMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certMsg, hs.transcript); err != nil {
return err
}
@@ -637,8 +641,7 @@ func (hs *serverHandshakeStateTLS13) sendServerCertificate() error {
}
certVerifyMsg.signature = sig
- hs.transcript.Write(certVerifyMsg.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, certVerifyMsg.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(certVerifyMsg, hs.transcript); err != nil {
return err
}
@@ -652,8 +655,7 @@ func (hs *serverHandshakeStateTLS13) sendServerFinished() error {
verifyData: hs.suite.finishedHash(c.out.trafficSecret, hs.transcript),
}
- hs.transcript.Write(finished.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, finished.marshal()); err != nil {
+ if _, err := hs.c.writeHandshakeRecord(finished, hs.transcript); err != nil {
return err
}
@@ -714,7 +716,9 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
finishedMsg := &finishedMsg{
verifyData: hs.clientFinished,
}
- hs.transcript.Write(finishedMsg.marshal())
+ if err := transcriptMsg(finishedMsg, hs.transcript); err != nil {
+ return err
+ }
if !hs.shouldSendSessionTickets() {
return nil
@@ -739,8 +743,12 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
SignedCertificateTimestamps: c.scts,
},
}
- var err error
- m.label, err = c.encryptTicket(state.marshal())
+ stateBytes, err := state.marshal()
+ if err != nil {
+ c.sendAlert(alertInternalError)
+ return err
+ }
+ m.label, err = c.encryptTicket(stateBytes)
if err != nil {
return err
}
@@ -759,7 +767,7 @@ func (hs *serverHandshakeStateTLS13) sendSessionTickets() error {
// ticket_nonce, which must be unique per connection, is always left at
// zero because we only ever send one ticket per connection.
- if _, err := c.writeRecord(recordTypeHandshake, m.marshal()); err != nil {
+ if _, err := c.writeHandshakeRecord(m, nil); err != nil {
return err
}
@@ -784,7 +792,7 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
// If we requested a client certificate, then the client must send a
// certificate message. If it's empty, no CertificateVerify is sent.
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(hs.transcript)
if err != nil {
return err
}
@@ -794,7 +802,6 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
c.sendAlert(alertUnexpectedMessage)
return unexpectedMessageError(certMsg, msg)
}
- hs.transcript.Write(certMsg.marshal())
if err := c.processCertsFromClient(certMsg.certificate); err != nil {
return err
@@ -808,7 +815,10 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
}
if len(certMsg.certificate.Certificate) != 0 {
- msg, err = c.readHandshake()
+ // certificateVerifyMsg is included in the transcript, but not until
+ // after we verify the handshake signature, since the state before
+ // this message was sent is used.
+ msg, err = c.readHandshake(nil)
if err != nil {
return err
}
@@ -839,7 +849,9 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
return errors.New("tls: invalid signature by the client certificate: " + err.Error())
}
- hs.transcript.Write(certVerify.marshal())
+ if err := transcriptMsg(certVerify, hs.transcript); err != nil {
+ return err
+ }
}
// If we waited until the client certificates to send session tickets, we
@@ -854,7 +866,8 @@ func (hs *serverHandshakeStateTLS13) readClientCertificate() error {
func (hs *serverHandshakeStateTLS13) readClientFinished() error {
c := hs.c
- msg, err := c.readHandshake()
+ // finishedMsg is not included in the transcript.
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
diff --git a/internal/helper/typeconv.go b/internal/helper/typeconv.go
new file mode 100644
index 0000000..73ec058
--- /dev/null
+++ b/internal/helper/typeconv.go
@@ -0,0 +1,23 @@
+package helper
+
+import (
+ "errors"
+
+ "golang.org/x/crypto/cryptobyte"
+)
+
+// Uint8to16 converts a slice of uint8 to a slice of uint16.
+// e.g. []uint8{0x00, 0x01, 0x00, 0x02} -> []uint16{0x0001, 0x0002}
+func Uint8to16(in []uint8) ([]uint16, error) {
+ s := cryptobyte.String(in)
+ var out []uint16
+ for !s.Empty() {
+ var v uint16
+ if s.ReadUint16(&v) {
+ out = append(out, v)
+ } else {
+ return nil, errors.New("ReadUint16 failed")
+ }
+ }
+ return out, nil
+}
diff --git a/key_agreement.go b/key_agreement.go
index 75deeb0..c28a64f 100644
--- a/key_agreement.go
+++ b/key_agreement.go
@@ -319,7 +319,7 @@ func (ka *ecdheKeyAgreement) processServerKeyExchange(config *Config, clientHell
}
if !isSupportedSignatureAlgorithm(signatureAlgorithm, clientHello.supportedSignatureAlgorithms) {
- return fmt.Errorf("tls: certificate used with invalid signature algorithm -- ClientHello not advertising %04x", uint16(signatureAlgorithm))
+ return errors.New("tls: certificate used with invalid signature algorithm")
}
sigType, sigHash, err = typeAndHashFromSignatureScheme(signatureAlgorithm)
if err != nil {
diff --git a/key_schedule.go b/key_schedule.go
index 3140169..185137b 100644
--- a/key_schedule.go
+++ b/key_schedule.go
@@ -8,6 +8,7 @@ import (
"crypto/elliptic"
"crypto/hmac"
"errors"
+ "fmt"
"hash"
"io"
"math/big"
@@ -42,8 +43,24 @@ func (c *cipherSuiteTLS13) expandLabel(secret []byte, label string, context []by
hkdfLabel.AddUint8LengthPrefixed(func(b *cryptobyte.Builder) {
b.AddBytes(context)
})
+ hkdfLabelBytes, err := hkdfLabel.Bytes()
+ if err != nil {
+ // Rather than calling BytesOrPanic, we explicitly handle this error, in
+ // order to provide a reasonable error message. It should be basically
+ // impossible for this to panic, and routing errors back through the
+ // tree rooted in this function is quite painful. The labels are fixed
+ // size, and the context is either a fixed-length computed hash, or
+ // parsed from a field which has the same length limitation. As such, an
+ // error here is likely to only be caused during development.
+ //
+ // NOTE: another reasonable approach here might be to return a
+ // randomized slice if we encounter an error, which would break the
+ // connection, but avoid panicking. This would perhaps be safer but
+ // significantly more confusing to users.
+ panic(fmt.Errorf("failed to construct HKDF label: %s", err))
+ }
out := make([]byte, length)
- n, err := hkdf.Expand(c.hash.New, secret, hkdfLabel.BytesOrPanic()).Read(out)
+ n, err := hkdf.Expand(c.hash.New, secret, hkdfLabelBytes).Read(out)
if err != nil || n != length {
panic("tls: HKDF-Expand-Label invocation failed unexpectedly")
}
diff --git a/notboring.go b/notboring.go
index a8fcdf5..4384069 100644
--- a/notboring.go
+++ b/notboring.go
@@ -36,4 +36,4 @@ func (*Boring) Unreachable() {
// do nothing
}
-var boring Boring
\ No newline at end of file
+var boring Boring
diff --git a/testdata/ClientHello-JSON-Chrome102.json b/testdata/ClientHello-JSON-Chrome102.json
new file mode 100644
index 0000000..2463906
--- /dev/null
+++ b/testdata/ClientHello-JSON-Chrome102.json
@@ -0,0 +1,75 @@
+{
+ "cipher_suites": [
+ "GREASE",
+ "TLS_AES_128_GCM_SHA256",
+ "TLS_AES_256_GCM_SHA384",
+ "TLS_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
+ "TLS_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_RSA_WITH_AES_256_CBC_SHA"
+ ],
+ "compression_methods": [
+ "NULL"
+ ],
+ "extensions": [
+ {"name": "GREASE"},
+ {"name": "server_name"},
+ {"name": "extended_master_secret"},
+ {"name": "renegotiation_info"},
+ {"name": "supported_groups", "named_group_list": [
+ "GREASE",
+ "x25519",
+ "secp256r1",
+ "secp384r1"
+ ]},
+ {"name": "ec_point_formats", "ec_point_format_list": [
+ "uncompressed"
+ ]},
+ {"name": "session_ticket"},
+ {"name": "application_layer_protocol_negotiation", "protocol_name_list": [
+ "h2",
+ "http/1.1"
+ ]},
+ {"name": "status_request"},
+ {"name": "signature_algorithms", "supported_signature_algorithms": [
+ "ecdsa_secp256r1_sha256",
+ "rsa_pss_rsae_sha256",
+ "rsa_pkcs1_sha256",
+ "ecdsa_secp384r1_sha384",
+ "rsa_pss_rsae_sha384",
+ "rsa_pkcs1_sha384",
+ "rsa_pss_rsae_sha512",
+ "rsa_pkcs1_sha512"
+ ]},
+ {"name": "signed_certificate_timestamp"},
+ {"name": "key_share", "client_shares": [
+ {"group": "GREASE", "key_exchange": [0]},
+ {"group": "x25519"}
+ ]},
+ {"name": "psk_key_exchange_modes", "ke_modes": [
+ "psk_dhe_ke"
+ ]},
+ {"name": "supported_versions", "versions": [
+ "GREASE",
+ "TLS 1.3",
+ "TLS 1.2"
+ ]},
+ {"name": "compress_certificate", "algorithms": [
+ "brotli"
+ ]},
+ {"name": "application_settings", "supported_protocols": [
+ "h2"
+ ]},
+ {"name": "GREASE"},
+ {"name": "padding", "len": 0}
+ ]
+}
\ No newline at end of file
diff --git a/testdata/ClientHello-JSON-Edge106.json b/testdata/ClientHello-JSON-Edge106.json
new file mode 100644
index 0000000..4f506ca
--- /dev/null
+++ b/testdata/ClientHello-JSON-Edge106.json
@@ -0,0 +1,76 @@
+{
+ "cipher_suites": [
+ "GREASE",
+ "TLS_AES_128_GCM_SHA256",
+ "TLS_AES_256_GCM_SHA384",
+ "TLS_AES_256_GCM_SHA384",
+ "TLS_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
+ "TLS_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_RSA_WITH_AES_256_CBC_SHA"
+ ],
+ "compression_methods": [
+ "NULL"
+ ],
+ "extensions": [
+ {"name": "GREASE"},
+ {"name": "server_name"},
+ {"name": "extended_master_secret"},
+ {"name": "renegotiation_info"},
+ {"name": "supported_groups", "named_group_list": [
+ "GREASE",
+ "x25519",
+ "secp256r1",
+ "secp384r1"
+ ]},
+ {"name": "ec_point_formats", "ec_point_format_list": [
+ "uncompressed"
+ ]},
+ {"name": "session_ticket"},
+ {"name": "application_layer_protocol_negotiation", "protocol_name_list": [
+ "h2",
+ "http/1.1"
+ ]},
+ {"name": "status_request"},
+ {"name": "signature_algorithms", "supported_signature_algorithms": [
+ "ecdsa_secp256r1_sha256",
+ "rsa_pss_rsae_sha256",
+ "rsa_pkcs1_sha256",
+ "ecdsa_secp384r1_sha384",
+ "rsa_pss_rsae_sha384",
+ "rsa_pkcs1_sha384",
+ "rsa_pss_rsae_sha512",
+ "rsa_pkcs1_sha512"
+ ]},
+ {"name": "signed_certificate_timestamp"},
+ {"name": "key_share", "client_shares": [
+ {"group": "GREASE", "key_exchange": [0]},
+ {"group": "x25519"}
+ ]},
+ {"name": "psk_key_exchange_modes", "ke_modes": [
+ "psk_dhe_ke"
+ ]},
+ {"name": "supported_versions", "versions": [
+ "GREASE",
+ "TLS 1.3",
+ "TLS 1.2"
+ ]},
+ {"name": "compress_certificate", "algorithms": [
+ "brotli"
+ ]},
+ {"name": "application_settings", "supported_protocols": [
+ "h2"
+ ]},
+ {"name": "GREASE"},
+ {"name": "padding", "len": 0}
+ ]
+}
\ No newline at end of file
diff --git a/testdata/ClientHello-JSON-Firefox105.json b/testdata/ClientHello-JSON-Firefox105.json
new file mode 100644
index 0000000..fa0aac7
--- /dev/null
+++ b/testdata/ClientHello-JSON-Firefox105.json
@@ -0,0 +1,78 @@
+{
+ "cipher_suites": [
+ "TLS_AES_128_GCM_SHA256",
+ "TLS_CHACHA20_POLY1305_SHA256",
+ "TLS_AES_256_GCM_SHA384",
+ "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
+ "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
+ "TLS_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_RSA_WITH_AES_256_CBC_SHA"
+ ],
+ "compression_methods": [
+ "NULL"
+ ],
+ "extensions": [
+ {"name": "server_name"},
+ {"name": "extended_master_secret"},
+ {"name": "renegotiation_info"},
+ {"name": "supported_groups", "named_group_list": [
+ "x25519",
+ "secp256r1",
+ "secp384r1",
+ "secp521r1",
+ "ffdhe2048",
+ "ffdhe3072"
+ ]},
+ {"name": "ec_point_formats", "ec_point_format_list": [
+ "uncompressed"
+ ]},
+ {"name": "session_ticket"},
+ {"name": "application_layer_protocol_negotiation", "protocol_name_list": [
+ "h2",
+ "http/1.1"
+ ]},
+ {"name": "status_request"},
+ {"name": "delegated_credentials", "supported_signature_algorithms": [
+ "ecdsa_secp256r1_sha256",
+ "ecdsa_secp384r1_sha384",
+ "ecdsa_secp521r1_sha512",
+ "ecdsa_sha1"
+ ]},
+ {"name": "key_share", "client_shares": [
+ {"group": "x25519"},
+ {"group": "secp256r1"}
+ ]},
+ {"name": "supported_versions", "versions": [
+ "TLS 1.3",
+ "TLS 1.2"
+ ]},
+ {"name": "signature_algorithms", "supported_signature_algorithms": [
+ "ecdsa_secp256r1_sha256",
+ "ecdsa_secp384r1_sha384",
+ "ecdsa_secp521r1_sha512",
+ "rsa_pss_rsae_sha256",
+ "rsa_pss_rsae_sha384",
+ "rsa_pss_rsae_sha512",
+ "rsa_pkcs1_sha256",
+ "rsa_pkcs1_sha384",
+ "rsa_pkcs1_sha512",
+ "ecdsa_sha1",
+ "rsa_pkcs1_sha1"
+ ]},
+ {"name": "psk_key_exchange_modes", "ke_modes": [
+ "psk_dhe_ke"
+ ]},
+ {"name": "record_size_limit", "record_size_limit": 16385},
+ {"name": "padding", "len": 0}
+ ]
+}
\ No newline at end of file
diff --git a/testdata/ClientHello-JSON-iOS14.json b/testdata/ClientHello-JSON-iOS14.json
new file mode 100644
index 0000000..f16725e
--- /dev/null
+++ b/testdata/ClientHello-JSON-iOS14.json
@@ -0,0 +1,85 @@
+{
+ "cipher_suites": [
+ "GREASE",
+ "TLS_AES_128_GCM_SHA256",
+ "TLS_AES_256_GCM_SHA384",
+ "TLS_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA384",
+ "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA256",
+ "TLS_ECDHE_ECDSA_WITH_AES_256_CBC_SHA",
+ "TLS_ECDHE_ECDSA_WITH_AES_128_CBC_SHA",
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA384",
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA256",
+ "TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA",
+ "TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_RSA_WITH_AES_256_GCM_SHA384",
+ "TLS_RSA_WITH_AES_128_GCM_SHA256",
+ "TLS_RSA_WITH_AES_256_CBC_SHA256",
+ "TLS_RSA_WITH_AES_128_CBC_SHA256",
+ "TLS_RSA_WITH_AES_256_CBC_SHA",
+ "TLS_RSA_WITH_AES_128_CBC_SHA",
+ "TLS_ECDHE_ECDSA_WITH_3DES_EDE_CBC_SHA",
+ "TLS_ECDHE_RSA_WITH_3DES_EDE_CBC_SHA",
+ "TLS_RSA_WITH_3DES_EDE_CBC_SHA"
+ ],
+ "compression_methods": [
+ "NULL"
+ ],
+ "extensions": [
+ {"name": "GREASE"},
+ {"name": "server_name"},
+ {"name": "extended_master_secret"},
+ {"name": "renegotiation_info"},
+ {"name": "supported_groups", "named_group_list": [
+ "GREASE",
+ "x25519",
+ "secp256r1",
+ "secp384r1",
+ "secp521r1"
+ ]},
+ {"name": "ec_point_formats", "ec_point_format_list": [
+ "uncompressed"
+ ]},
+ {"name": "application_layer_protocol_negotiation", "protocol_name_list": [
+ "h2",
+ "http/1.1"
+ ]},
+ {"name": "status_request"},
+ {"name": "signature_algorithms", "supported_signature_algorithms": [
+ "ecdsa_secp256r1_sha256",
+ "rsa_pss_rsae_sha256",
+ "rsa_pkcs1_sha256",
+ "ecdsa_secp384r1_sha384",
+ "ecdsa_sha1",
+ "rsa_pss_rsae_sha384",
+ "rsa_pss_rsae_sha384",
+ "rsa_pkcs1_sha384",
+ "rsa_pss_rsae_sha512",
+ "rsa_pkcs1_sha512",
+ "rsa_pkcs1_sha1"
+ ]},
+ {"name": "signed_certificate_timestamp"},
+ {"name": "key_share", "client_shares": [
+ {"group": "GREASE", "key_exchange": [0]},
+ {"group": "x25519"}
+ ]},
+ {"name": "psk_key_exchange_modes", "ke_modes": [
+ "psk_dhe_ke"
+ ]},
+ {"name": "supported_versions", "versions": [
+ "GREASE",
+ "TLS 1.3",
+ "TLS 1.2",
+ "TLS 1.1",
+ "TLS 1.0"
+ ]},
+ {"name": "GREASE"},
+ {"name": "padding"}
+ ]
+}
\ No newline at end of file
diff --git a/ticket.go b/ticket.go
index 9ce1454..c861e92 100644
--- a/ticket.go
+++ b/ticket.go
@@ -32,7 +32,7 @@ type sessionState struct {
usedOldKey bool
}
-func (m *sessionState) marshal() []byte {
+func (m *sessionState) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(m.vers)
b.AddUint16(m.cipherSuite)
@@ -47,7 +47,7 @@ func (m *sessionState) marshal() []byte {
})
}
})
- return b.BytesOrPanic()
+ return b.Bytes()
}
func (m *sessionState) unmarshal(data []byte) bool {
@@ -86,7 +86,7 @@ type sessionStateTLS13 struct {
certificate Certificate // CertificateEntry certificate_list<0..2^24-1>;
}
-func (m *sessionStateTLS13) marshal() []byte {
+func (m *sessionStateTLS13) marshal() ([]byte, error) {
var b cryptobyte.Builder
b.AddUint16(VersionTLS13)
b.AddUint8(0) // revision
@@ -96,7 +96,7 @@ func (m *sessionStateTLS13) marshal() []byte {
b.AddBytes(m.resumptionSecret)
})
marshalCertificate(&b, m.certificate)
- return b.BytesOrPanic()
+ return b.Bytes()
}
func (m *sessionStateTLS13) unmarshal(data []byte) bool {
diff --git a/tls_test.go b/tls_test.go
index 98de1df..dd02f9b 100644
--- a/tls_test.go
+++ b/tls_test.go
@@ -814,7 +814,7 @@ func TestCloneNonFuncFields(t *testing.T) {
f.Set(reflect.ValueOf("b"))
case "ClientAuth":
f.Set(reflect.ValueOf(VerifyClientCertIfGiven))
- case "InsecureSkipVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
+ case "InsecureSkipVerify", "InsecureSkipTimeVerify", "SessionTicketsDisabled", "DynamicRecordSizingDisabled", "PreferServerCipherSuites":
f.Set(reflect.ValueOf(true))
case "InsecureServerNameToVerify":
f.Set(reflect.ValueOf("c"))
diff --git a/u_clienthello_json.go b/u_clienthello_json.go
new file mode 100644
index 0000000..2529bf7
--- /dev/null
+++ b/u_clienthello_json.go
@@ -0,0 +1,168 @@
+package tls
+
+import (
+ "encoding/json"
+ "errors"
+ "fmt"
+ "os"
+
+ "github.com/gaukas/godicttls"
+)
+
+var (
+	// ErrUnknownExtension is returned (wrapped with the offending name) when a
+	// JSON ClientHello lists an extension name that is missing from the
+	// godicttls extension dictionary.
+	ErrUnknownExtension = errors.New("extension name is unknown to the dictionary")
+)
+
+// ClientHelloSpecJSONUnmarshaler decodes a JSON description of a ClientHello
+// (cipher suite, compression method and extension names, plus optional
+// version bounds) and converts it into a ClientHelloSpec.
+type ClientHelloSpecJSONUnmarshaler struct {
+	CipherSuites       *CipherSuitesJSONUnmarshaler       `json:"cipher_suites"`
+	CompressionMethods *CompressionMethodsJSONUnmarshaler `json:"compression_methods"`
+	Extensions         *TLSExtensionsJSONUnmarshaler      `json:"extensions"`
+	TLSVersMin         uint16                             `json:"min_vers,omitempty"` // optional
+	TLSVersMax         uint16                             `json:"max_vers,omitempty"` // optional
+}
+
+// ClientHelloSpec builds a ClientHelloSpec from the decoded JSON.
+//
+// When cipher_suites, compression_methods or extensions is absent from the
+// JSON input, encoding/json leaves the matching pointer nil. That section is
+// then left empty in the result instead of triggering a nil-pointer panic.
+func (chsju *ClientHelloSpecJSONUnmarshaler) ClientHelloSpec() ClientHelloSpec {
+	spec := ClientHelloSpec{
+		TLSVersMin: chsju.TLSVersMin,
+		TLSVersMax: chsju.TLSVersMax,
+	}
+	if chsju.CipherSuites != nil {
+		spec.CipherSuites = chsju.CipherSuites.CipherSuites()
+	}
+	if chsju.CompressionMethods != nil {
+		spec.CompressionMethods = chsju.CompressionMethods.CompressionMethods()
+	}
+	if chsju.Extensions != nil {
+		spec.Extensions = chsju.Extensions.Extensions()
+	}
+	return spec
+}
+
+// CipherSuitesJSONUnmarshaler decodes a JSON array of cipher suite names into
+// their uint16 IDs. The special name "GREASE" maps to GREASE_PLACEHOLDER.
+type CipherSuitesJSONUnmarshaler struct {
+	cipherSuites []uint16
+}
+
+// UnmarshalJSON implements json.Unmarshaler. Names are resolved through
+// godicttls.DictCipherSuiteNameIndexed, and an unknown name is an error.
+// A successful call replaces (rather than appends to) any earlier result.
+// On error the previous result is left untouched.
+func (c *CipherSuitesJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
+	var cipherSuiteNames []string
+	if err := json.Unmarshal(jsonStr, &cipherSuiteNames); err != nil {
+		return err
+	}
+
+	var cipherSuites []uint16
+	for _, name := range cipherSuiteNames {
+		if name == "GREASE" {
+			cipherSuites = append(cipherSuites, GREASE_PLACEHOLDER)
+			continue
+		}
+
+		id, ok := godicttls.DictCipherSuiteNameIndexed[name]
+		if !ok {
+			return fmt.Errorf("unknown cipher suite name: %s", name)
+		}
+		cipherSuites = append(cipherSuites, id)
+	}
+
+	c.cipherSuites = cipherSuites
+	return nil
+}
+
+// CipherSuites returns the decoded cipher suite IDs. It is safe to call on a
+// nil receiver, which yields nil.
+func (c *CipherSuitesJSONUnmarshaler) CipherSuites() []uint16 {
+	if c == nil {
+		return nil
+	}
+	return c.cipherSuites
+}
+
+// CompressionMethodsJSONUnmarshaler decodes a JSON array of compression method
+// names into their uint8 IDs via godicttls.DictCompMethNameIndexed.
+type CompressionMethodsJSONUnmarshaler struct {
+	compressionMethods []uint8
+}
+
+// UnmarshalJSON implements json.Unmarshaler, and an unknown name is an error.
+// A successful call replaces (rather than appends to) any earlier result.
+// On error the previous result is left untouched.
+func (c *CompressionMethodsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
+	var compressionMethodNames []string
+	if err := json.Unmarshal(jsonStr, &compressionMethodNames); err != nil {
+		return err
+	}
+
+	var compressionMethods []uint8
+	for _, name := range compressionMethodNames {
+		id, ok := godicttls.DictCompMethNameIndexed[name]
+		if !ok {
+			return fmt.Errorf("unknown compression method name: %s", name)
+		}
+		compressionMethods = append(compressionMethods, id)
+	}
+
+	c.compressionMethods = compressionMethods
+	return nil
+}
+
+// CompressionMethods returns the decoded compression method IDs. It is safe
+// to call on a nil receiver, which yields nil.
+func (c *CompressionMethodsJSONUnmarshaler) CompressionMethods() []uint8 {
+	if c == nil {
+		return nil
+	}
+	return c.compressionMethods
+}
+
+// TLSExtensionsJSONUnmarshaler decodes a JSON array of extension objects, each
+// of which carries at least a "name" field, into TLS extensions.
+type TLSExtensionsJSONUnmarshaler struct {
+	extensions []TLSExtensionJSON
+}
+
+// UnmarshalJSON implements json.Unmarshaler in two passes.
+//
+// First, every element's "name" is resolved to an extension instance. That is
+// a UtlsGREASEExtension for "GREASE", the result of ExtensionFromID for names
+// in godicttls.DictExtTypeNameIndexed, or a GenericExtension fallback when
+// ExtensionFromID returns nil. Second, each instance decodes its own original
+// JSON object.
+//
+// Unknown names yield an error wrapping ErrUnknownExtension. On any error the
+// previous result is left untouched.
+func (e *TLSExtensionsJSONUnmarshaler) UnmarshalJSON(jsonStr []byte) error {
+	var accepters []tlsExtensionJSONAccepter
+	if err := json.Unmarshal(jsonStr, &accepters); err != nil {
+		return err
+	}
+
+	exts := make([]TLSExtensionJSON, 0, len(accepters))
+	for _, accepter := range accepters {
+		name := accepter.extNameOnly.Name
+		if name == "GREASE" {
+			exts = append(exts, &UtlsGREASEExtension{})
+			continue
+		}
+
+		extID, ok := godicttls.DictExtTypeNameIndexed[name]
+		if !ok {
+			return fmt.Errorf("%w: %s", ErrUnknownExtension, name)
+		}
+
+		// get extension type from ID
+		var ext TLSExtension = ExtensionFromID(extID)
+		if ext == nil {
+			// fallback to generic extension
+			ext = genericExtension(extID, name)
+		}
+
+		extJSON, ok := ext.(TLSExtensionJSON)
+		if !ok {
+			return fmt.Errorf("extension %d (%s) is not JSON compatible", extID, name)
+		}
+		exts = append(exts, extJSON)
+	}
+
+	// Unmarshal extensions: json.Unmarshal calls each extension's own
+	// UnmarshalJSON. exts[i] pairs with accepters[i] because every iteration
+	// above either appends exactly one element or returns.
+	for idx, ext := range exts {
+		if err := json.Unmarshal(accepters[idx].origJsonInput, ext); err != nil {
+			return err
+		}
+	}
+
+	e.extensions = exts
+	return nil
+}
+
+// Extensions returns the decoded extensions as []TLSExtension. It is safe to
+// call on a nil receiver, which yields nil.
+func (e *TLSExtensionsJSONUnmarshaler) Extensions() []TLSExtension {
+	if e == nil {
+		return nil
+	}
+	exts := make([]TLSExtension, 0, len(e.extensions))
+	for _, ext := range e.extensions {
+		exts = append(exts, ext)
+	}
+	return exts
+}
+
+// genericExtension returns a GenericExtension that carries only the extension
+// ID, for IDs without a dedicated implementation. It also prints a one-line
+// warning to stderr; name is optional and only appears in that warning.
+func genericExtension(id uint16, name string) TLSExtension {
+	label := fmt.Sprintf("%d", id)
+	if name != "" {
+		label = fmt.Sprintf("%s (%s)", label, name)
+	}
+	fmt.Fprintf(os.Stderr, "WARNING: extension %s is falling back to generic extension\n", label)
+
+	return &GenericExtension{Id: id}
+}
+
+// tlsExtensionJSONAccepter captures one element of the "extensions" JSON
+// array. Its "name" field is decoded immediately, and a private copy of the
+// raw element is kept so the concrete extension can decode it later.
+type tlsExtensionJSONAccepter struct {
+	extNameOnly struct {
+		Name string `json:"name"`
+	}
+	origJsonInput []byte
+}
+
+// UnmarshalJSON implements json.Unmarshaler. encoding/json requires
+// implementations to copy the input if they retain it, hence the copy below.
+func (t *tlsExtensionJSONAccepter) UnmarshalJSON(jsonStr []byte) error {
+	raw := append(make([]byte, 0, len(jsonStr)), jsonStr...)
+	t.origJsonInput = raw
+	return json.Unmarshal(jsonStr, &t.extNameOnly)
+}
diff --git a/u_clienthello_json_test.go b/u_clienthello_json_test.go
new file mode 100644
index 0000000..9ab86c7
--- /dev/null
+++ b/u_clienthello_json_test.go
@@ -0,0 +1,123 @@
+package tls
+
+import (
+ "encoding/json"
+ "os"
+ "reflect"
+ "testing"
+)
+
+// TestClientHelloSpecJSONUnmarshaler checks ClientHelloSpecJSONUnmarshaler
+// against the built-in spec of each JSON fixture's ClientHelloID.
+func TestClientHelloSpecJSONUnmarshaler(t *testing.T) {
+	fixtures := []struct {
+		path string
+		id   ClientHelloID
+	}{
+		{"testdata/ClientHello-JSON-Chrome102.json", HelloChrome_102},
+		{"testdata/ClientHello-JSON-Firefox105.json", HelloFirefox_105},
+		{"testdata/ClientHello-JSON-iOS14.json", HelloIOS_14},
+		{"testdata/ClientHello-JSON-Edge106.json", HelloEdge_106},
+	}
+	for _, f := range fixtures {
+		testClientHelloSpecJSONUnmarshaler(t, f.path, f.id)
+	}
+}
+
+// testClientHelloSpecJSONUnmarshaler decodes jsonFilepath with
+// ClientHelloSpecJSONUnmarshaler and compares the resulting spec with the
+// built-in spec for truthClientHelloID.
+func testClientHelloSpecJSONUnmarshaler(
+	t *testing.T,
+	jsonFilepath string,
+	truthClientHelloID ClientHelloID,
+) {
+	t.Helper()
+	id := clientHelloSpecJSONTestIdentifier(truthClientHelloID)
+
+	jsonCH, err := os.ReadFile(jsonFilepath)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var chsju ClientHelloSpecJSONUnmarshaler
+	if err := json.Unmarshal(jsonCH, &chsju); err != nil {
+		t.Fatal(err)
+	}
+
+	truthSpec, err := utlsIdToSpec(truthClientHelloID)
+	if err != nil {
+		t.Fatalf("JSONUnmarshaler %s: utlsIdToSpec: %v", id, err)
+	}
+	jsonSpec := chsju.ClientHelloSpec()
+
+	// Compare CipherSuites
+	if !reflect.DeepEqual(jsonSpec.CipherSuites, truthSpec.CipherSuites) {
+		t.Errorf("JSONUnmarshaler %s: got %#v, want %#v", id, jsonSpec.CipherSuites, truthSpec.CipherSuites)
+	}
+
+	// Compare CompressionMethods
+	if !reflect.DeepEqual(jsonSpec.CompressionMethods, truthSpec.CompressionMethods) {
+		t.Errorf("JSONUnmarshaler %s: got %#v, want %#v", id, jsonSpec.CompressionMethods, truthSpec.CompressionMethods)
+	}
+
+	// Compare Extensions. A length mismatch is fatal because the loop below
+	// indexes both slices with the same i.
+	if len(jsonSpec.Extensions) != len(truthSpec.Extensions) {
+		t.Fatalf("JSONUnmarshaler %s: len(jsonExtensions) = %d != %d = len(truthExtensions)", id, len(jsonSpec.Extensions), len(truthSpec.Extensions))
+	}
+
+	for i := range jsonSpec.Extensions {
+		got, want := jsonSpec.Extensions[i], truthSpec.Extensions[i]
+		if reflect.DeepEqual(got, want) {
+			continue
+		}
+		// UtlsPaddingExtension has a non-nil function member, and non-nil funcs
+		// are never DeepEqual, so compare only its data fields. Both sides
+		// are type-checked so a type mismatch is reported, not a panic.
+		gotPad, gotIsPad := got.(*UtlsPaddingExtension)
+		wantPad, wantIsPad := want.(*UtlsPaddingExtension)
+		if gotIsPad && wantIsPad {
+			if gotPad.PaddingLen != wantPad.PaddingLen || gotPad.WillPad != wantPad.WillPad {
+				t.Errorf("JSONUnmarshaler %s: got %#v, want %#v", id, gotPad, wantPad)
+			}
+			continue
+		}
+		t.Errorf("JSONUnmarshaler %s: got %#v, want %#v", id, got, want)
+	}
+}
+
+// TestClientHelloSpecUnmarshalJSON checks ClientHelloSpec.UnmarshalJSON
+// against the built-in spec of each JSON fixture's ClientHelloID.
+func TestClientHelloSpecUnmarshalJSON(t *testing.T) {
+	fixtures := []struct {
+		path string
+		id   ClientHelloID
+	}{
+		{"testdata/ClientHello-JSON-Chrome102.json", HelloChrome_102},
+		{"testdata/ClientHello-JSON-Firefox105.json", HelloFirefox_105},
+		{"testdata/ClientHello-JSON-iOS14.json", HelloIOS_14},
+		{"testdata/ClientHello-JSON-Edge106.json", HelloEdge_106},
+	}
+	for _, f := range fixtures {
+		testClientHelloSpecUnmarshalJSON(t, f.path, f.id)
+	}
+}
+
+// testClientHelloSpecUnmarshalJSON decodes jsonFilepath directly into a
+// ClientHelloSpec (through its UnmarshalJSON method) and compares it with the
+// built-in spec for truthClientHelloID.
+func testClientHelloSpecUnmarshalJSON(
+	t *testing.T,
+	jsonFilepath string,
+	truthClientHelloID ClientHelloID,
+) {
+	t.Helper()
+	id := clientHelloSpecJSONTestIdentifier(truthClientHelloID)
+
+	jsonCH, err := os.ReadFile(jsonFilepath)
+	if err != nil {
+		t.Fatal(err)
+	}
+
+	var jsonSpec ClientHelloSpec
+	if err := json.Unmarshal(jsonCH, &jsonSpec); err != nil {
+		t.Fatal(err)
+	}
+
+	truthSpec, err := utlsIdToSpec(truthClientHelloID)
+	if err != nil {
+		t.Fatalf("UnmarshalJSON %s: utlsIdToSpec: %v", id, err)
+	}
+
+	// Compare CipherSuites
+	if !reflect.DeepEqual(jsonSpec.CipherSuites, truthSpec.CipherSuites) {
+		t.Errorf("UnmarshalJSON %s: got %#v, want %#v", id, jsonSpec.CipherSuites, truthSpec.CipherSuites)
+	}
+
+	// Compare CompressionMethods
+	if !reflect.DeepEqual(jsonSpec.CompressionMethods, truthSpec.CompressionMethods) {
+		t.Errorf("UnmarshalJSON %s: got %#v, want %#v", id, jsonSpec.CompressionMethods, truthSpec.CompressionMethods)
+	}
+
+	// Compare Extensions. A length mismatch is fatal because the loop below
+	// indexes both slices with the same i.
+	if len(jsonSpec.Extensions) != len(truthSpec.Extensions) {
+		t.Fatalf("UnmarshalJSON %s (%s): len(jsonExtensions) = %d != %d = len(truthExtensions)", id, jsonFilepath, len(jsonSpec.Extensions), len(truthSpec.Extensions))
+	}
+
+	for i := range jsonSpec.Extensions {
+		got, want := jsonSpec.Extensions[i], truthSpec.Extensions[i]
+		if reflect.DeepEqual(got, want) {
+			continue
+		}
+		// UtlsPaddingExtension has a non-nil function member, and non-nil funcs
+		// are never DeepEqual, so compare only its data fields. Both sides
+		// are type-checked so a type mismatch is reported, not a panic.
+		gotPad, gotIsPad := got.(*UtlsPaddingExtension)
+		wantPad, wantIsPad := want.(*UtlsPaddingExtension)
+		if gotIsPad && wantIsPad {
+			if gotPad.PaddingLen != wantPad.PaddingLen || gotPad.WillPad != wantPad.WillPad {
+				t.Errorf("UnmarshalJSON %s: got %#v, want %#v", id, gotPad, wantPad)
+			}
+			continue
+		}
+		t.Errorf("UnmarshalJSON %s: got %#v, want %#v", id, got, want)
+	}
+}
+
+// clientHelloSpecJSONTestIdentifier labels test failures with the
+// ClientHelloID's Client immediately followed by its Version.
+func clientHelloSpecJSONTestIdentifier(id ClientHelloID) string {
+	label := id.Client
+	label += id.Version
+	return label
+}
diff --git a/u_common.go b/u_common.go
index 196d392..33e3a28 100644
--- a/u_common.go
+++ b/u_common.go
@@ -7,8 +7,14 @@ package tls
import (
"crypto/hmac"
"crypto/sha512"
+ "encoding/json"
+ "errors"
"fmt"
"hash"
+ "log"
+
+ "github.com/refraction-networking/utls/internal/helper"
+ "golang.org/x/crypto/cryptobyte"
)
// Naming convention:
@@ -26,6 +32,8 @@ const (
// TLS
const (
+ extensionNextProtoNeg uint16 = 13172 // not IANA assigned. Removed by crypto/tls since Nov 2019
+
utlsExtensionPadding uint16 = 21
utlsExtensionExtendedMasterSecret uint16 = 23 // https://tools.ietf.org/html/rfc7627
utlsExtensionCompressCertificate uint16 = 27 // https://datatracker.ietf.org/doc/html/rfc8879#section-7.1
@@ -33,10 +41,12 @@ const (
utlsFakeExtensionCustom uint16 = 1234 // not IANA assigned, for ALPS
// extensions with 'fake' prefix break connection, if server echoes them back
+ fakeExtensionEncryptThenMAC uint16 = 22
fakeExtensionTokenBinding uint16 = 24
+ fakeExtensionDelegatedCredentials uint16 = 34
+ fakeExtensionPreSharedKey uint16 = 41
fakeOldExtensionChannelID uint16 = 30031 // not IANA assigned
fakeExtensionChannelID uint16 = 30032 // not IANA assigned
- fakeExtensionDelegatedCredentials uint16 = 34
)
const (
@@ -110,6 +120,10 @@ type ClientHelloID struct {
// Seed is only used for randomized fingerprints to seed PRNG.
// Must not be modified once set.
Seed *PRNGSeed
+
+ // Weights are only used for randomized fingerprints in func
+ // generateRandomizedSpec(). Must not be modified once set.
+ Weights *Weights
}
func (p *ClientHelloID) Str() string {
@@ -155,67 +169,474 @@ type ClientHelloSpec struct {
// TLSFingerprintLink string // ?? link to tlsfingerprint.io for informational purposes
}
+// ReadCipherSuites replaces chs.CipherSuites with the big-endian uint16 values
+// decoded from b. Each value is passed through unGREASEUint16 (presumably
+// mapping GREASE values to a placeholder; confirm).
+//
+// example: []byte{0x13, 0x01, 0x13, 0x02, 0x13, 0x03} => []uint16{0x1301, 0x1302, 0x1303}
+//
+// An odd-length b yields an error and leaves chs unchanged.
+func (chs *ClientHelloSpec) ReadCipherSuites(b []byte) error {
+	input := cryptobyte.String(b)
+	suites := make([]uint16, 0, len(b)/2)
+	for len(input) > 0 {
+		var id uint16
+		if ok := input.ReadUint16(&id); !ok {
+			return errors.New("unable to read ciphersuite")
+		}
+		suites = append(suites, unGREASEUint16(id))
+	}
+	chs.CipherSuites = suites
+	return nil
+}
+
+// ReadCompressionMethods sets chs.CompressionMethods to a copy of
+// compressionMethods, one byte per method. It currently never returns an
+// error.
+//
+// The copy stops the spec from aliasing the caller's buffer. FromRaw, for
+// example, passes a sub-slice of the raw ClientHello here.
+func (chs *ClientHelloSpec) ReadCompressionMethods(compressionMethods []byte) error {
+	var methods []uint8
+	if compressionMethods != nil {
+		methods = make([]uint8, len(compressionMethods))
+		copy(methods, compressionMethods)
+	}
+	chs.CompressionMethods = methods
+	return nil
+}
+
+// ReadTLSExtensions parses b as a sequence of TLS extensions and appends them
+// to chs.Extensions. Each extension is a uint16 type followed by
+// uint16-length-prefixed data.
+//
+// Extensions returned by ExtensionFromID that implement TLSExtensionWriter
+// are populated from their wire data. Any other extension is kept as a
+// GenericExtension holding a copy of its raw data when allowBluntMimicry is
+// set; otherwise it is an error. If a supported_versions extension is
+// present, chs.TLSVersMin and chs.TLSVersMax are reset to 0.
+//
+// chs is modified only if all of b parses successfully.
+func (chs *ClientHelloSpec) ReadTLSExtensions(b []byte, allowBluntMimicry bool) error {
+	var parsed []TLSExtension
+	sawSupportedVersions := false
+
+	extensions := cryptobyte.String(b)
+	for !extensions.Empty() {
+		var extension uint16
+		var extData cryptobyte.String
+		if !extensions.ReadUint16(&extension) {
+			return fmt.Errorf("unable to read extension ID")
+		}
+		if !extensions.ReadUint16LengthPrefixed(&extData) {
+			return fmt.Errorf("unable to read data for extension %x", extension)
+		}
+
+		// A nil result from ExtensionFromID also fails this assertion, so ok
+		// alone tells whether the extension is known and writable.
+		extWriter, ok := ExtensionFromID(extension).(TLSExtensionWriter)
+		if !ok {
+			if !allowBluntMimicry {
+				return fmt.Errorf("unsupported extension %d", extension)
+			}
+			// Copy so the spec does not alias the caller's buffer.
+			data := make([]byte, len(extData))
+			copy(data, extData)
+			parsed = append(parsed, &GenericExtension{extension, data})
+			continue
+		}
+
+		if extension == extensionSupportedVersions {
+			sawSupportedVersions = true
+		}
+		if _, err := extWriter.Write(extData); err != nil {
+			return err
+		}
+		parsed = append(parsed, extWriter)
+	}
+
+	if sawSupportedVersions {
+		chs.TLSVersMin = 0
+		chs.TLSVersMax = 0
+	}
+	chs.Extensions = append(chs.Extensions, parsed...)
+	return nil
+}
+
+// AlwaysAddPadding appends a UtlsPaddingExtension that uses
+// BoringPaddingStyle. It does nothing if chs.Extensions already contains a
+// padding extension or a FakePreSharedKeyExtension; the PSK extension must
+// stay last, so nothing is appended after it.
+func (chs *ClientHelloSpec) AlwaysAddPadding() {
+	for _, ext := range chs.Extensions {
+		switch ext.(type) {
+		case *UtlsPaddingExtension, *FakePreSharedKeyExtension:
+			return
+		}
+	}
+	chs.Extensions = append(chs.Extensions, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle})
+}
+
+// ImportTLSClientHello populates chs from TLS ClientHello data in the format
+// published by client.tlsfingerprint.io:8443.
+//
+// data is a map of []byte with following keys:
+// - cipher_suites: [10, 10, 19, 1, 19, 2, 19, 3, 192, 43, 192, 47, 192, 44, 192, 48, 204, 169, 204, 168, 192, 19, 192, 20, 0, 156, 0, 157, 0, 47, 0, 53]
+// - compression_methods: [0] => null
+// - extensions: [10, 10, 255, 1, 0, 45, 0, 35, 0, 16, 68, 105, 0, 11, 0, 43, 0, 18, 0, 13, 0, 0, 0, 10, 0, 27, 0, 5, 0, 51, 0, 23, 10, 10, 0, 21]
+// - pt_fmts (ec_point_formats): [1, 0] => len: 1, content: 0x00
+// - sig_algs (signature_algorithms): [0, 16, 4, 3, 8, 4, 4, 1, 5, 3, 8, 5, 5, 1, 8, 6, 6, 1] => len: 16, content: 0x0403, 0x0804, 0x0401, 0x0503, 0x0805, 0x0501, 0x0806, 0x0601
+// - supported_versions: [10, 10, 3, 4, 3, 3] => 0x0a0a, 0x0304, 0x0303 (GREASE, TLS 1.3, TLS 1.2)
+// - curves (named_groups, supported_groups): [0, 8, 10, 10, 0, 29, 0, 23, 0, 24] => len: 8, content: GREASE, 0x001d, 0x0017, 0x0018
+// - alpn: [0, 12, 2, 104, 50, 8, 104, 116, 116, 112, 47, 49, 46, 49] => len: 12, content: h2, http/1.1
+// - key_share: [10, 10, 0, 1, 0, 29, 0, 32] => {group: 0x0a0a, len:1}, {group: 0x001d, len:32}
+// - psk_key_exchange_modes: [1] => psk_dhe_ke(0x01)
+// - cert_compression_algs: [2, 0, 2] => brotli (0x0002)
+// - record_size_limit: [0, 255] => 255
+//
+// cipher_suites, compression_methods and extensions are always required.
+// Each other key is required only when its extension appears in extensions.
+// Every listed extension ID must be returned by ExtensionFromID as a
+// TLSExtensionWriter, otherwise an error is returned. Fields already written
+// to chs are not rolled back when an error is returned.
+//
+// TLSVersMin/TLSVersMax are reset to 0 if supported_versions is present.
+// Values set before calling this function are overwritten in that case.
+func (chs *ClientHelloSpec) ImportTLSClientHello(data map[string][]byte) error {
+	var tlsExtensionTypes []uint16
+	var err error
+
+	if data["cipher_suites"] == nil {
+		return errors.New("cipher_suites is required")
+	}
+	chs.CipherSuites, err = helper.Uint8to16(data["cipher_suites"])
+	if err != nil {
+		return err
+	}
+
+	if data["compression_methods"] == nil {
+		return errors.New("compression_methods is required")
+	}
+	chs.CompressionMethods = data["compression_methods"]
+
+	if data["extensions"] == nil {
+		return errors.New("extensions is required")
+	}
+	tlsExtensionTypes, err = helper.Uint8to16(data["extensions"])
+	if err != nil {
+		return err
+	}
+
+	for _, extType := range tlsExtensionTypes {
+		// A nil result from ExtensionFromID also fails this assertion, so
+		// unknown IDs and non-writable extensions are both rejected here.
+		extWriter, ok := ExtensionFromID(extType).(TLSExtensionWriter)
+		if !ok {
+			return fmt.Errorf("unsupported extension %d", extType)
+		}
+
+		var werr error
+		switch extType {
+		case extensionSupportedPoints:
+			werr = importWriteExtField(extWriter, data, "pt_fmts", false)
+		case extensionSignatureAlgorithms:
+			werr = importWriteExtField(extWriter, data, "sig_algs", false)
+		case extensionSupportedVersions:
+			chs.TLSVersMin = 0
+			chs.TLSVersMax = 0
+			werr = importWriteExtField(extWriter, data, "supported_versions", true)
+		case extensionSupportedCurves:
+			werr = importWriteExtField(extWriter, data, "curves", false)
+		case extensionALPN:
+			werr = importWriteExtField(extWriter, data, "alpn", false)
+		case extensionKeyShare:
+			werr = importWriteKeyShare(extWriter, data)
+		case extensionPSKModes:
+			werr = importWriteExtField(extWriter, data, "psk_key_exchange_modes", true)
+		case utlsExtensionCompressCertificate:
+			werr = importWriteExtField(extWriter, data, "cert_compression_algs", true)
+		case fakeRecordSizeLimit:
+			werr = importWriteExtField(extWriter, data, "record_size_limit", false) // uint16 as []byte
+		case utlsExtensionApplicationSettings:
+			// TODO: tlsfingerprint.io should record/provide application settings data
+			if alps, isALPS := extWriter.(*ApplicationSettingsExtension); isALPS {
+				alps.SupportedProtocols = []string{"h2"}
+			} else {
+				werr = fmt.Errorf("extension %d: expected *ApplicationSettingsExtension, got %T", extType, extWriter)
+			}
+		case fakeExtensionPreSharedKey:
+			log.Printf("[Warning] PSK extension added without data")
+		default:
+			// GREASE extensions are added silently. NOTE(review): presumably
+			// they are re-GREASEd later (e.g. on ApplyPreset); confirm.
+			if !isGREASEUint16(extType) {
+				log.Printf("[Warning] extension %d added without data", extType)
+			}
+		}
+		if werr != nil {
+			return werr
+		}
+		chs.Extensions = append(chs.Extensions, extWriter)
+	}
+	return nil
+}
+
+// importWriteExtField feeds data[key] to w, or reports "<key> is required" if
+// the key is missing. When uint8Prefix is set, a one-byte length prefix is
+// prepended first, because the import format omits it for that field. A
+// field longer than 255 bytes cannot take that prefix and is rejected rather
+// than silently truncated.
+func importWriteExtField(w TLSExtensionWriter, data map[string][]byte, key string, uint8Prefix bool) error {
+	field := data[key]
+	if field == nil {
+		return fmt.Errorf("%s is required", key)
+	}
+	if uint8Prefix {
+		if len(field) > 0xff {
+			return fmt.Errorf("%s is too long for a uint8 length prefix: %d bytes", key, len(field))
+		}
+		field = append([]byte{uint8(len(field))}, field...)
+	}
+	_, err := w.Write(field)
+	return err
+}
+
+// importWriteKeyShare expands data["key_share"] into a key_share extension
+// body and feeds it to w.
+//
+// The import format lists each key share as a uint16 group followed by a
+// uint16 key length, without the key bytes. For example,
+// [10, 10, 0, 1, 0, 29, 0, 32] => {group: 0x0a0a, len: 1}, {group: 0x001d, len: 32}.
+// Each entry is followed by key-length zero bytes, and the whole list gets a
+// uint16 length prefix.
+func importWriteKeyShare(w TLSExtensionWriter, data map[string][]byte) error {
+	keyShares := data["key_share"]
+	if keyShares == nil {
+		return errors.New("key_share is required")
+	}
+	if len(keyShares)%4 != 0 {
+		return fmt.Errorf("key_share length %d is not a multiple of 4", len(keyShares))
+	}
+
+	var b cryptobyte.Builder
+	b.AddUint16LengthPrefixed(func(b *cryptobyte.Builder) {
+		for i := 0; i < len(keyShares); i += 4 {
+			// The key length is the full big-endian uint16, not only its low byte.
+			keyLen := int(keyShares[i+2])<<8 | int(keyShares[i+3])
+			b.AddBytes(keyShares[i : i+4])
+			b.AddBytes(make([]byte, keyLen))
+		}
+	})
+	body, err := b.Bytes()
+	if err != nil {
+		return err
+	}
+	_, err = w.Write(body)
+	return err
+}
+
+// ImportTLSClientHelloFromJSON imports a ClientHelloSpec from JSON in the
+// client.tlsfingerprint.io format.
+//
+// jsonB is decoded into a map[string][]byte, which is then passed to
+// ImportTLSClientHello.
+func (chs *ClientHelloSpec) ImportTLSClientHelloFromJSON(jsonB []byte) error {
+	var data map[string][]byte
+	if err := json.Unmarshal(jsonB, &data); err != nil {
+		return err
+	}
+	return chs.ImportTLSClientHello(data)
+}
+
+// FromRaw converts a ClientHello message in the form of raw bytes into a
+// ClientHelloSpec.
+//
+// raw must start with the 5-byte TLS record header, followed by the handshake
+// header and the ClientHello body. The record version becomes TLSVersMin and
+// the ClientHello version becomes TLSVersMax; both are reset to 0 when a
+// supported_versions extension is parsed. Extensions are optional.
+//
+// allowBluntMimicry is optional. If its first value is true, extensions
+// unknown to ExtensionFromID are kept as GenericExtensions instead of causing
+// an error.
+//
+// *chs is replaced only when parsing succeeds; on error it is left unchanged.
+func (chs *ClientHelloSpec) FromRaw(raw []byte, allowBluntMimicry ...bool) error {
+	if chs == nil {
+		return errors.New("cannot unmarshal into nil ClientHelloSpec")
+	}
+
+	bluntMimicry := len(allowBluntMimicry) > 0 && allowBluntMimicry[0]
+
+	// Parse into a local value so malformed input never leaves *chs half
+	// populated.
+	var spec ClientHelloSpec
+	s := cryptobyte.String(raw)
+
+	var contentType uint8
+	var recordVersion uint16
+	if !s.ReadUint8(&contentType) || // record type
+		!s.ReadUint16(&recordVersion) || !s.Skip(2) { // record version and length
+		return errors.New("unable to read record type, version, and length")
+	}
+
+	if recordType(contentType) != recordTypeHandshake {
+		return errors.New("record is not a handshake")
+	}
+
+	var handshakeVersion uint16
+	var handshakeType uint8
+	if !s.ReadUint8(&handshakeType) || !s.Skip(3) || // message type and 3 byte length
+		!s.ReadUint16(&handshakeVersion) || !s.Skip(32) { // 32 byte random
+		return errors.New("unable to read handshake message type, length, and random")
+	}
+
+	if handshakeType != typeClientHello {
+		return errors.New("handshake message is not a ClientHello")
+	}
+
+	spec.TLSVersMin = recordVersion
+	spec.TLSVersMax = handshakeVersion
+
+	var ignoredSessionID cryptobyte.String
+	if !s.ReadUint8LengthPrefixed(&ignoredSessionID) {
+		return errors.New("unable to read session id")
+	}
+
+	// CipherSuites
+	var cipherSuitesBytes cryptobyte.String
+	if !s.ReadUint16LengthPrefixed(&cipherSuitesBytes) {
+		return errors.New("unable to read ciphersuites")
+	}
+	if err := spec.ReadCipherSuites(cipherSuitesBytes); err != nil {
+		return err
+	}
+
+	// CompressionMethods
+	var compressionMethods cryptobyte.String
+	if !s.ReadUint8LengthPrefixed(&compressionMethods) {
+		return errors.New("unable to read compression methods")
+	}
+	if err := spec.ReadCompressionMethods(compressionMethods); err != nil {
+		return err
+	}
+
+	// Extensions are optional.
+	if !s.Empty() {
+		var extensions cryptobyte.String
+		if !s.ReadUint16LengthPrefixed(&extensions) {
+			return errors.New("unable to read extensions data")
+		}
+		if err := spec.ReadTLSExtensions(extensions, bluntMimicry); err != nil {
+			return err
+		}
+	}
+
+	*chs = spec
+	return nil
+}
+
+// UnmarshalJSON implements json.Unmarshaler for ClientHelloSpec. jsonB is
+// decoded through ClientHelloSpecJSONUnmarshaler, and the resulting spec
+// replaces *chs. On a decode error *chs is left unchanged.
+func (chs *ClientHelloSpec) UnmarshalJSON(jsonB []byte) error {
+	var decoded ClientHelloSpecJSONUnmarshaler
+	err := json.Unmarshal(jsonB, &decoded)
+	if err == nil {
+		*chs = decoded.ClientHelloSpec()
+	}
+	return err
+}
+
var (
+ // [uTLS] Presets below are positional ClientHelloID literals; the two trailing
+ // nil values are the Seed and Weights fields, both left unset for these presets.
// HelloGolang will use default "crypto/tls" handshake marshaling codepath, which WILL
// overwrite your changes to Hello(Config, Session are fine).
// You might want to call BuildHandshakeState() before applying any changes.
// UConn.Extensions will be completely ignored.
- HelloGolang = ClientHelloID{helloGolang, helloAutoVers, nil}
+ HelloGolang = ClientHelloID{helloGolang, helloAutoVers, nil, nil}
// HelloCustom will prepare ClientHello with empty uconn.Extensions so you can fill it with
// TLSExtensions manually or use ApplyPreset function
- HelloCustom = ClientHelloID{helloCustom, helloAutoVers, nil}
+ HelloCustom = ClientHelloID{helloCustom, helloAutoVers, nil, nil}
// HelloRandomized* randomly adds/reorders extensions, ciphersuites, etc.
- HelloRandomized = ClientHelloID{helloRandomized, helloAutoVers, nil}
- HelloRandomizedALPN = ClientHelloID{helloRandomizedALPN, helloAutoVers, nil}
- HelloRandomizedNoALPN = ClientHelloID{helloRandomizedNoALPN, helloAutoVers, nil}
+ HelloRandomized = ClientHelloID{helloRandomized, helloAutoVers, nil, nil}
+ HelloRandomizedALPN = ClientHelloID{helloRandomizedALPN, helloAutoVers, nil, nil}
+ HelloRandomizedNoALPN = ClientHelloID{helloRandomizedNoALPN, helloAutoVers, nil, nil}
// The rest will will parrot given browser.
HelloFirefox_Auto = HelloFirefox_105
- HelloFirefox_55 = ClientHelloID{helloFirefox, "55", nil}
- HelloFirefox_56 = ClientHelloID{helloFirefox, "56", nil}
- HelloFirefox_63 = ClientHelloID{helloFirefox, "63", nil}
- HelloFirefox_65 = ClientHelloID{helloFirefox, "65", nil}
- HelloFirefox_99 = ClientHelloID{helloFirefox, "99", nil}
- HelloFirefox_102 = ClientHelloID{helloFirefox, "102", nil}
- HelloFirefox_105 = ClientHelloID{helloFirefox, "105", nil}
+ HelloFirefox_55 = ClientHelloID{helloFirefox, "55", nil, nil}
+ HelloFirefox_56 = ClientHelloID{helloFirefox, "56", nil, nil}
+ HelloFirefox_63 = ClientHelloID{helloFirefox, "63", nil, nil}
+ HelloFirefox_65 = ClientHelloID{helloFirefox, "65", nil, nil}
+ HelloFirefox_99 = ClientHelloID{helloFirefox, "99", nil, nil}
+ HelloFirefox_102 = ClientHelloID{helloFirefox, "102", nil, nil}
+ HelloFirefox_105 = ClientHelloID{helloFirefox, "105", nil, nil}
HelloChrome_Auto = HelloChrome_106_Shuffle
- HelloChrome_58 = ClientHelloID{helloChrome, "58", nil}
- HelloChrome_62 = ClientHelloID{helloChrome, "62", nil}
- HelloChrome_70 = ClientHelloID{helloChrome, "70", nil}
- HelloChrome_72 = ClientHelloID{helloChrome, "72", nil}
- HelloChrome_83 = ClientHelloID{helloChrome, "83", nil}
- HelloChrome_87 = ClientHelloID{helloChrome, "87", nil}
- HelloChrome_96 = ClientHelloID{helloChrome, "96", nil}
- HelloChrome_100 = ClientHelloID{helloChrome, "100", nil}
- HelloChrome_102 = ClientHelloID{helloChrome, "102", nil}
- HelloChrome_106_Shuffle = ClientHelloID{helloChrome, "106", nil} // beta: shuffler enabled starting from 106
+ HelloChrome_58 = ClientHelloID{helloChrome, "58", nil, nil}
+ HelloChrome_62 = ClientHelloID{helloChrome, "62", nil, nil}
+ HelloChrome_70 = ClientHelloID{helloChrome, "70", nil, nil}
+ HelloChrome_72 = ClientHelloID{helloChrome, "72", nil, nil}
+ HelloChrome_83 = ClientHelloID{helloChrome, "83", nil, nil}
+ HelloChrome_87 = ClientHelloID{helloChrome, "87", nil, nil}
+ HelloChrome_96 = ClientHelloID{helloChrome, "96", nil, nil}
+ HelloChrome_100 = ClientHelloID{helloChrome, "100", nil, nil}
+ HelloChrome_102 = ClientHelloID{helloChrome, "102", nil, nil}
+ HelloChrome_106_Shuffle = ClientHelloID{helloChrome, "106", nil, nil} // beta: shuffler enabled starting from 106
+
+ // Chrome with PSK: Chrome start sending this ClientHello after doing TLS 1.3 handshake with the same server.
+ HelloChrome_100_PSK = ClientHelloID{helloChrome, "100_PSK", nil, nil} // beta: PSK extension added. uTLS doesn't fully support PSK. Use at your own risk.
+ // NOTE(review): the Version string here is "112_PSK" while the identifier says
+ // 112_PSK_Shuf; confirm this matches the key the spec lookup expects.
+ HelloChrome_112_PSK_Shuf = ClientHelloID{helloChrome, "112_PSK", nil, nil} // beta: PSK extension added. uTLS doesn't fully support PSK. Use at your own risk.
HelloIOS_Auto = HelloIOS_14
- HelloIOS_11_1 = ClientHelloID{helloIOS, "111", nil} // legacy "111" means 11.1
- HelloIOS_12_1 = ClientHelloID{helloIOS, "12.1", nil}
- HelloIOS_13 = ClientHelloID{helloIOS, "13", nil}
- HelloIOS_14 = ClientHelloID{helloIOS, "14", nil}
+ HelloIOS_11_1 = ClientHelloID{helloIOS, "111", nil, nil} // legacy "111" means 11.1
+ HelloIOS_12_1 = ClientHelloID{helloIOS, "12.1", nil, nil}
+ HelloIOS_13 = ClientHelloID{helloIOS, "13", nil, nil}
+ HelloIOS_14 = ClientHelloID{helloIOS, "14", nil, nil}
- HelloAndroid_11_OkHttp = ClientHelloID{helloAndroid, "11", nil}
+ HelloAndroid_11_OkHttp = ClientHelloID{helloAndroid, "11", nil, nil}
HelloEdge_Auto = HelloEdge_85 // HelloEdge_106 seems to be incompatible with this library
- HelloEdge_85 = ClientHelloID{helloEdge, "85", nil}
- HelloEdge_106 = ClientHelloID{helloEdge, "106", nil}
+ HelloEdge_85 = ClientHelloID{helloEdge, "85", nil, nil}
+ HelloEdge_106 = ClientHelloID{helloEdge, "106", nil, nil}
HelloSafari_Auto = HelloSafari_16_0
- HelloSafari_16_0 = ClientHelloID{helloSafari, "16.0", nil}
+ HelloSafari_16_0 = ClientHelloID{helloSafari, "16.0", nil, nil}
Hello360_Auto = Hello360_7_5 // Hello360_11_0 seems to be incompatible with this library
- Hello360_7_5 = ClientHelloID{hello360, "7.5", nil}
- Hello360_11_0 = ClientHelloID{hello360, "11.0", nil}
+ Hello360_7_5 = ClientHelloID{hello360, "7.5", nil, nil}
+ Hello360_11_0 = ClientHelloID{hello360, "11.0", nil, nil}
HelloQQ_Auto = HelloQQ_11_1
- HelloQQ_11_1 = ClientHelloID{helloQQ, "11.1", nil}
+ HelloQQ_11_1 = ClientHelloID{helloQQ, "11.1", nil, nil}
)
+// Weights holds one float64 per randomization decision made when building a
+// randomized fingerprint. According to the ClientHelloID.Weights field
+// comment, they are only used by generateRandomizedSpec(). Each field name
+// describes the step it weights, e.g. Extensions_Append_ALPN.
+// NOTE(review): presumably each value is the probability of taking that step
+// (DefaultWeights uses values in [0, 1]); confirm in generateRandomizedSpec.
+type Weights struct {
+ Extensions_Append_ALPN float64
+ TLSVersMax_Set_VersionTLS13 float64
+ CipherSuites_Remove_RandomCiphers float64
+ SigAndHashAlgos_Append_ECDSAWithSHA1 float64
+ SigAndHashAlgos_Append_ECDSAWithP521AndSHA512 float64
+ SigAndHashAlgos_Append_PSSWithSHA256 float64
+ SigAndHashAlgos_Append_PSSWithSHA384_PSSWithSHA512 float64
+ CurveIDs_Append_X25519 float64
+ CurveIDs_Append_CurveP521 float64
+ Extensions_Append_Padding float64
+ Extensions_Append_Status float64
+ Extensions_Append_SCT float64
+ Extensions_Append_Reneg float64
+ Extensions_Append_EMS float64
+ FirstKeyShare_Set_CurveP256 float64
+ Extensions_Append_ALPS float64
+}
+
+// DefaultWeights is the stock set of Weights. Do not modify it directly,
+// because it may be in use; to use custom weights, make a copy first.
+// NOTE(review): presumably the fallback when ClientHelloID.Weights is nil;
+// confirm in generateRandomizedSpec.
+var DefaultWeights = Weights{
+ Extensions_Append_ALPN: 0.7,
+ TLSVersMax_Set_VersionTLS13: 0.4,
+ CipherSuites_Remove_RandomCiphers: 0.4,
+ SigAndHashAlgos_Append_ECDSAWithSHA1: 0.63,
+ SigAndHashAlgos_Append_ECDSAWithP521AndSHA512: 0.59,
+ SigAndHashAlgos_Append_PSSWithSHA256: 0.51,
+ SigAndHashAlgos_Append_PSSWithSHA384_PSSWithSHA512: 0.9,
+ CurveIDs_Append_X25519: 0.71,
+ CurveIDs_Append_CurveP521: 0.46,
+ Extensions_Append_Padding: 0.62,
+ Extensions_Append_Status: 0.74,
+ Extensions_Append_SCT: 0.46,
+ Extensions_Append_Reneg: 0.75,
+ Extensions_Append_EMS: 0.77,
+ FirstKeyShare_Set_CurveP256: 0.25,
+ Extensions_Append_ALPS: 0.33,
+}
+
// based on spec's GreaseStyle, GREASE_PLACEHOLDER may be replaced by another GREASE value
// https://tools.ietf.org/html/draft-ietf-tls-grease-01
const GREASE_PLACEHOLDER = 0x0a0a
diff --git a/u_conn.go b/u_conn.go
index 78c9bba..caff710 100644
--- a/u_conn.go
+++ b/u_conn.go
@@ -399,7 +399,10 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
}
// [uTLS section ends]
- cacheKey, session, earlySecret, binderKey := c.loadSession(hello)
+ cacheKey, session, earlySecret, binderKey, err := c.loadSession(hello)
+ if err != nil {
+ return err
+ }
if cacheKey != "" && session != nil {
defer func() {
// If we got a handshake failure when resuming a session, throw away
@@ -421,11 +424,11 @@ func (c *UConn) clientHandshake(ctx context.Context) (err error) {
}
}
- if _, err := c.writeRecord(recordTypeHandshake, hello.marshal()); err != nil {
+ if _, err := c.writeHandshakeRecord(hello, nil); err != nil {
return err
}
- msg, err := c.readHandshake()
+ msg, err := c.readHandshake(nil)
if err != nil {
return err
}
@@ -596,7 +599,7 @@ func (uconn *UConn) SetTLSVers(minTLSVers, maxTLSVers uint16, specExtensions []T
minVers := uint16(0)
maxVers := uint16(0)
for _, vers := range versions {
- if vers == GREASE_PLACEHOLDER {
+ if isGREASEUint16(vers) {
continue
}
if maxVers < vers || maxVers == 0 {
diff --git a/u_conn_test.go b/u_conn_test.go
index 2c520b1..68c8411 100644
--- a/u_conn_test.go
+++ b/u_conn_test.go
@@ -47,7 +47,10 @@ func TestUTLSMarshalNoOp(t *testing.T) {
t.Errorf("Got error: %s; expected to succeed", err)
}
msg.raw = []byte(str)
- marshalledHello := msg.marshal()
+ marshalledHello, err := msg.marshal()
+ if err != nil {
+ t.Errorf("clientHelloMsg.marshal() returned error: %s", err.Error())
+ }
if strings.Compare(string(marshalledHello), str) != 0 {
t.Errorf("clientHelloMsg.marshal() is not NOOP! Expected to get: %s, got: %s", str, string(marshalledHello))
}
diff --git a/u_fingerprinter.go b/u_fingerprinter.go
index f4d4998..1e1e1c8 100644
--- a/u_fingerprinter.go
+++ b/u_fingerprinter.go
@@ -4,18 +4,8 @@
package tls
-import (
- "errors"
- "fmt"
- "strings"
-
- "golang.org/x/crypto/cryptobyte"
-)
-
// Fingerprinter is a struct largely for holding options for the FingerprintClientHello func
type Fingerprinter struct {
- // KeepPSK will ensure that the PreSharedKey extension is passed along into the resulting ClientHelloSpec as-is
- KeepPSK bool
// AllowBluntMimicry will ensure that unknown extensions are
// passed along into the resulting ClientHelloSpec as-is
// It will not ensure that the PSK is passed along, if you require that, use KeepPSK
@@ -40,390 +30,42 @@ type Fingerprinter struct {
// as well as the handshake type/length/version header
// https://tools.ietf.org/html/rfc5246#section-6.2
// https://tools.ietf.org/html/rfc5246#section-7.4
-func (f *Fingerprinter) FingerprintClientHello(data []byte) (*ClientHelloSpec, error) {
- clientHelloSpec := &ClientHelloSpec{}
- s := cryptobyte.String(data)
-
- var contentType uint8
- var recordVersion uint16
- if !s.ReadUint8(&contentType) || // record type
- !s.ReadUint16(&recordVersion) || !s.Skip(2) { // record version and length
- return nil, errors.New("unable to read record type, version, and length")
- }
-
- if recordType(contentType) != recordTypeHandshake {
- return nil, errors.New("record is not a handshake")
- }
-
- var handshakeVersion uint16
- var handshakeType uint8
-
- if !s.ReadUint8(&handshakeType) || !s.Skip(3) || // message type and 3 byte length
- !s.ReadUint16(&handshakeVersion) || !s.Skip(32) { // 32 byte random
- return nil, errors.New("unable to read handshake message type, length, and random")
- }
-
- if handshakeType != typeClientHello {
- return nil, errors.New("handshake message is not a ClientHello")
- }
-
- clientHelloSpec.TLSVersMin = recordVersion
- clientHelloSpec.TLSVersMax = handshakeVersion
-
- var ignoredSessionID cryptobyte.String
- if !s.ReadUint8LengthPrefixed(&ignoredSessionID) {
- return nil, errors.New("unable to read session id")
- }
-
- var cipherSuitesBytes cryptobyte.String
- if !s.ReadUint16LengthPrefixed(&cipherSuitesBytes) {
- return nil, errors.New("unable to read ciphersuites")
- }
- cipherSuites := []uint16{}
- for !cipherSuitesBytes.Empty() {
- var suite uint16
- if !cipherSuitesBytes.ReadUint16(&suite) {
- return nil, errors.New("unable to read ciphersuite")
- }
- cipherSuites = append(cipherSuites, unGREASEUint16(suite))
- }
- clientHelloSpec.CipherSuites = cipherSuites
-
- if !readUint8LengthPrefixed(&s, &clientHelloSpec.CompressionMethods) {
- return nil, errors.New("unable to read compression methods")
- }
+//
+// It calls UnmarshalClientHello internally, and is kept for backwards compatibility
+func (f *Fingerprinter) FingerprintClientHello(data []byte) (clientHelloSpec *ClientHelloSpec, err error) {
+ return f.RawClientHello(data)
+}
- if s.Empty() {
- // ClientHello is optionally followed by extension data
- return clientHelloSpec, nil
+// RawClientHello returns a ClientHelloSpec which is based on the
+// ClientHello raw bytes that is passed in as the raw argument.
+//
+// It was renamed from FingerprintClientHello in v1.3.1 and earlier versions
+// as a more precise name for the function
+func (f *Fingerprinter) RawClientHello(raw []byte) (clientHelloSpec *ClientHelloSpec, err error) {
+ clientHelloSpec = &ClientHelloSpec{}
+ err = clientHelloSpec.FromRaw(raw, f.AllowBluntMimicry)
+ if err != nil {
+ return nil, err
}
- var extensions cryptobyte.String
- if !s.ReadUint16LengthPrefixed(&extensions) {
- return nil, errors.New("unable to read extensions data")
+ if f.AlwaysAddPadding {
+ clientHelloSpec.AlwaysAddPadding()
}
- for !extensions.Empty() {
- var extension uint16
- var extData cryptobyte.String
- if !extensions.ReadUint16(&extension) ||
- !extensions.ReadUint16LengthPrefixed(&extData) {
- return nil, errors.New("unable to read extension data")
- }
-
- switch extension {
- case extensionServerName:
- // RFC 6066, Section 3
- var nameList cryptobyte.String
- if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
- return nil, errors.New("unable to read server name extension data")
- }
- var serverName string
- for !nameList.Empty() {
- var nameType uint8
- var serverNameBytes cryptobyte.String
- if !nameList.ReadUint8(&nameType) ||
- !nameList.ReadUint16LengthPrefixed(&serverNameBytes) ||
- serverNameBytes.Empty() {
- return nil, errors.New("unable to read server name extension data")
- }
- if nameType != 0 {
- continue
- }
- if len(serverName) != 0 {
- return nil, errors.New("multiple names of the same name_type in server name extension are prohibited")
- }
- serverName = string(serverNameBytes)
- if strings.HasSuffix(serverName, ".") {
- return nil, errors.New("SNI value may not include a trailing dot")
- }
-
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SNIExtension{})
-
- }
- case extensionNextProtoNeg:
- // draft-agl-tls-nextprotoneg-04
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &NPNExtension{})
-
- case extensionStatusRequest:
- // RFC 4366, Section 3.6
- var statusType uint8
- var ignored cryptobyte.String
- if !extData.ReadUint8(&statusType) ||
- !extData.ReadUint16LengthPrefixed(&ignored) ||
- !extData.ReadUint16LengthPrefixed(&ignored) {
- return nil, errors.New("unable to read status request extension data")
- }
-
- if statusType == statusTypeOCSP {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &StatusRequestExtension{})
- } else {
- return nil, errors.New("status request extension statusType is not statusTypeOCSP")
- }
-
- case extensionSupportedCurves:
- // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
- var curvesBytes cryptobyte.String
- if !extData.ReadUint16LengthPrefixed(&curvesBytes) || curvesBytes.Empty() {
- return nil, errors.New("unable to read supported curves extension data")
- }
- curves := []CurveID{}
- for !curvesBytes.Empty() {
- var curve uint16
- if !curvesBytes.ReadUint16(&curve) {
- return nil, errors.New("unable to read supported curves extension data")
- }
- curves = append(curves, CurveID(unGREASEUint16(curve)))
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedCurvesExtension{curves})
-
- case extensionSupportedPoints:
- // RFC 4492, Section 5.1.2
- supportedPoints := []uint8{}
- if !readUint8LengthPrefixed(&extData, &supportedPoints) ||
- len(supportedPoints) == 0 {
- return nil, errors.New("unable to read supported points extension data")
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedPointsExtension{supportedPoints})
-
- case extensionSessionTicket:
- // RFC 5077, Section 3.2
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SessionTicketExtension{})
-
- case extensionSignatureAlgorithms:
- // RFC 5246, Section 7.4.1.4.1
- var sigAndAlgs cryptobyte.String
- if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
- return nil, errors.New("unable to read signature algorithms extension data")
- }
- supportedSignatureAlgorithms := []SignatureScheme{}
- for !sigAndAlgs.Empty() {
- var sigAndAlg uint16
- if !sigAndAlgs.ReadUint16(&sigAndAlg) {
- return nil, errors.New("unable to read signature algorithms extension data")
- }
- supportedSignatureAlgorithms = append(
- supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SignatureAlgorithmsExtension{supportedSignatureAlgorithms})
-
- case extensionSignatureAlgorithmsCert:
- // RFC 8446, Section 4.2.3
- if f.AllowBluntMimicry {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
- } else {
- return nil, errors.New("unsupported extension SignatureAlgorithmsCert")
- }
-
- case extensionRenegotiationInfo:
- // RFC 5746, Section 3.2
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &RenegotiationInfoExtension{RenegotiateOnceAsClient})
-
- case extensionALPN:
- // RFC 7301, Section 3.1
- var protoList cryptobyte.String
- if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
- return nil, errors.New("unable to read ALPN extension data")
- }
- alpnProtocols := []string{}
- for !protoList.Empty() {
- var proto cryptobyte.String
- if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
- return nil, errors.New("unable to read ALPN extension data")
- }
- alpnProtocols = append(alpnProtocols, string(proto))
-
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &ALPNExtension{alpnProtocols})
-
- case extensionSCT:
- // RFC 6962, Section 3.3.1
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SCTExtension{})
-
- case extensionSupportedVersions:
- // RFC 8446, Section 4.2.1
- var versList cryptobyte.String
- if !extData.ReadUint8LengthPrefixed(&versList) || versList.Empty() {
- return nil, errors.New("unable to read supported versions extension data")
- }
- supportedVersions := []uint16{}
- for !versList.Empty() {
- var vers uint16
- if !versList.ReadUint16(&vers) {
- return nil, errors.New("unable to read supported versions extension data")
- }
- supportedVersions = append(supportedVersions, unGREASEUint16(vers))
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SupportedVersionsExtension{supportedVersions})
- // If SupportedVersionsExtension is present, use that instead of record+handshake versions
- clientHelloSpec.TLSVersMin = 0
- clientHelloSpec.TLSVersMax = 0
-
- case extensionKeyShare:
- // RFC 8446, Section 4.2.8
- var clientShares cryptobyte.String
- if !extData.ReadUint16LengthPrefixed(&clientShares) {
- return nil, errors.New("unable to read key share extension data")
- }
- keyShares := []KeyShare{}
- for !clientShares.Empty() {
- var ks KeyShare
- var group uint16
- if !clientShares.ReadUint16(&group) ||
- !readUint16LengthPrefixed(&clientShares, &ks.Data) ||
- len(ks.Data) == 0 {
- return nil, errors.New("unable to read key share extension data")
- }
- ks.Group = CurveID(unGREASEUint16(group))
- // if not GREASE, key share data will be discarded as it should
- // be generated per connection
- if ks.Group != GREASE_PLACEHOLDER {
- ks.Data = nil
- }
- keyShares = append(keyShares, ks)
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &KeyShareExtension{keyShares})
-
- case extensionPSKModes:
- // RFC 8446, Section 4.2.9
- // TODO: PSK Modes have their own form of GREASE-ing which is not currently implemented
- // the current functionality will NOT re-GREASE/re-randomize these values when using a fingerprinted spec
- // https://github.com/refraction-networking/utls/pull/58#discussion_r522354105
- // https://tools.ietf.org/html/draft-ietf-tls-grease-01#section-2
- pskModes := []uint8{}
- if !readUint8LengthPrefixed(&extData, &pskModes) {
- return nil, errors.New("unable to read PSK extension data")
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &PSKKeyExchangeModesExtension{pskModes})
-
- case utlsExtensionExtendedMasterSecret:
- // https://tools.ietf.org/html/rfc7627
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsExtendedMasterSecretExtension{})
-
- case utlsExtensionPadding:
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle})
-
- case utlsExtensionCompressCertificate:
- methods := []CertCompressionAlgo{}
- methodsRaw := new(cryptobyte.String)
- if !extData.ReadUint8LengthPrefixed(methodsRaw) {
- return nil, errors.New("unable to read cert compression algorithms extension data")
- }
- for !methodsRaw.Empty() {
- var method uint16
- if !methodsRaw.ReadUint16(&method) {
- return nil, errors.New("unable to read cert compression algorithms extension data")
- }
- methods = append(methods, CertCompressionAlgo(method))
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsCompressCertExtension{methods})
-
- case fakeExtensionChannelID:
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &FakeChannelIDExtension{})
-
- case fakeOldExtensionChannelID:
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &FakeChannelIDExtension{true})
-
- case fakeExtensionTokenBinding:
- var tokenBindingExt FakeTokenBindingExtension
- var keyParameters cryptobyte.String
- if !extData.ReadUint8(&tokenBindingExt.MajorVersion) ||
- !extData.ReadUint8(&tokenBindingExt.MinorVersion) ||
- !extData.ReadUint8LengthPrefixed(&keyParameters) {
- return nil, errors.New("unable to read token binding extension data")
- }
- tokenBindingExt.KeyParameters = keyParameters
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &tokenBindingExt)
-
- case utlsExtensionApplicationSettings:
- // Similar to ALPN (RFC 7301, Section 3.1):
- // https://datatracker.ietf.org/doc/html/draft-vvv-tls-alps#section-3
- var protoList cryptobyte.String
- if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
- return nil, errors.New("unable to read ALPS extension data")
- }
- supportedProtocols := []string{}
- for !protoList.Empty() {
- var proto cryptobyte.String
- if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
- return nil, errors.New("unable to read ALPS extension data")
- }
- supportedProtocols = append(supportedProtocols, string(proto))
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &ApplicationSettingsExtension{supportedProtocols})
-
- case fakeRecordSizeLimit:
- recordSizeExt := new(FakeRecordSizeLimitExtension)
- if !extData.ReadUint16(&recordSizeExt.Limit) {
- return nil, errors.New("unable to read record size limit extension data")
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, recordSizeExt)
-
- case fakeExtensionDelegatedCredentials:
- //https://datatracker.ietf.org/doc/html/draft-ietf-tls-subcerts-15#section-4.1.1
- var supportedAlgs cryptobyte.String
- if !extData.ReadUint16LengthPrefixed(&supportedAlgs) || supportedAlgs.Empty() {
- return nil, errors.New("unable to read signature algorithms extension data")
- }
- supportedSignatureAlgorithms := []SignatureScheme{}
- for !supportedAlgs.Empty() {
- var sigAndAlg uint16
- if !supportedAlgs.ReadUint16(&sigAndAlg) {
- return nil, errors.New("unable to read signature algorithms extension data")
- }
- supportedSignatureAlgorithms = append(
- supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
- }
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &FakeDelegatedCredentialsExtension{supportedSignatureAlgorithms})
-
- case extensionPreSharedKey:
- // RFC 8446, Section 4.2.11
- if f.KeepPSK {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
- } else {
- return nil, errors.New("unsupported extension PreSharedKey")
- }
-
- case extensionCookie:
- // RFC 8446, Section 4.2.2
- if f.AllowBluntMimicry {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
- } else {
- return nil, errors.New("unsupported extension Cookie")
- }
-
- case extensionEarlyData:
- // RFC 8446, Section 4.2.10
- if f.AllowBluntMimicry {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
- } else {
- return nil, errors.New("unsupported extension EarlyData")
- }
-
- default:
- if isGREASEUint16(extension) {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsGREASEExtension{unGREASEUint16(extension), extData})
- } else if f.AllowBluntMimicry {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &GenericExtension{extension, extData})
- } else {
- return nil, fmt.Errorf("unsupported extension %d", extension)
- }
+ return clientHelloSpec, nil
+}
- continue
- }
+// UnmarshalJSONClientHello returns a ClientHelloSpec which is based on the
+// ClientHello JSON bytes that is passed in as the json argument.
+func (f *Fingerprinter) UnmarshalJSONClientHello(json []byte) (clientHelloSpec *ClientHelloSpec, err error) {
+ clientHelloSpec = &ClientHelloSpec{}
+ err = clientHelloSpec.UnmarshalJSON(json)
+ if err != nil {
+ return nil, err
}
if f.AlwaysAddPadding {
- alreadyHasPadding := false
- for _, ext := range clientHelloSpec.Extensions {
- if _, ok := ext.(*UtlsPaddingExtension); ok {
- alreadyHasPadding = true
- break
- }
- }
- if !alreadyHasPadding {
- clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle})
- }
+ clientHelloSpec.AlwaysAddPadding()
}
return clientHelloSpec, nil
diff --git a/u_fingerprinter_test.go b/u_fingerprinter_test.go
index 0f761aa..ad140bf 100644
--- a/u_fingerprinter_test.go
+++ b/u_fingerprinter_test.go
@@ -501,12 +501,6 @@ func TestUTLSFingerprintClientHelloKeepPSK(t *testing.T) {
}
f := &Fingerprinter{}
- _, err = f.FingerprintClientHello(helloBytes)
- if err == nil {
- t.Errorf("expected error generating spec from client hello with PSK")
- }
-
- f = &Fingerprinter{KeepPSK: true}
generatedSpec, err := f.FingerprintClientHello(helloBytes)
if err != nil {
t.Errorf("got error: %v; expected to succeed", err)
@@ -514,10 +508,8 @@ func TestUTLSFingerprintClientHelloKeepPSK(t *testing.T) {
}
for _, ext := range generatedSpec.Extensions {
- if genericExtension, ok := (ext).(*GenericExtension); ok {
- if genericExtension.Id == extensionPreSharedKey {
- return
- }
+ if _, ok := (ext).(*FakePreSharedKeyExtension); ok {
+ return
}
}
t.Errorf("generated ClientHelloSpec with KeepPSK does not include preshared key extension")
diff --git a/u_handshake_client.go b/u_handshake_client.go
index 6166f6e..f59ffde 100644
--- a/u_handshake_client.go
+++ b/u_handshake_client.go
@@ -25,7 +25,9 @@ func (hs *clientHandshakeStateTLS13) utlsReadServerCertificate(msg any) (process
if len(hs.uconn.certCompressionAlgs) > 0 {
compressedCertMsg, ok := msg.(*utlsCompressedCertificateMsg)
if ok {
- hs.transcript.Write(compressedCertMsg.marshal())
+ if err = transcriptMsg(compressedCertMsg, hs.transcript); err != nil {
+ return nil, err
+ }
msg, err = hs.decompressCert(*compressedCertMsg)
if err != nil {
return nil, fmt.Errorf("tls: failed to decompress certificate message: %w", err)
@@ -128,8 +130,7 @@ func (hs *clientHandshakeStateTLS13) sendClientEncryptedExtensions() error {
if c.utls.hasApplicationSettings {
clientEncryptedExtensions.hasApplicationSettings = true
clientEncryptedExtensions.applicationSettings = c.utls.localApplicationSettings
- hs.transcript.Write(clientEncryptedExtensions.marshal())
- if _, err := c.writeRecord(recordTypeHandshake, clientEncryptedExtensions.marshal()); err != nil {
+ if _, err := c.writeHandshakeRecord(clientEncryptedExtensions, hs.transcript); err != nil {
return err
}
}
diff --git a/u_handshake_messages.go b/u_handshake_messages.go
index 20f9b4e..e7ebb15 100644
--- a/u_handshake_messages.go
+++ b/u_handshake_messages.go
@@ -20,9 +20,9 @@ type utlsCompressedCertificateMsg struct {
compressedCertificateMessage []byte
}
-func (m *utlsCompressedCertificateMsg) marshal() []byte {
+func (m *utlsCompressedCertificateMsg) marshal() ([]byte, error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var b cryptobyte.Builder
@@ -35,8 +35,9 @@ func (m *utlsCompressedCertificateMsg) marshal() []byte {
})
})
- m.raw = b.BytesOrPanic()
- return m.raw
+ var err error
+ m.raw, err = b.Bytes()
+ return m.raw, err
}
func (m *utlsCompressedCertificateMsg) unmarshal(data []byte) bool {
@@ -74,9 +75,9 @@ type utlsClientEncryptedExtensionsMsg struct {
customExtension []byte
}
-func (m *utlsClientEncryptedExtensionsMsg) marshal() (x []byte) {
+func (m *utlsClientEncryptedExtensionsMsg) marshal() (x []byte, err error) {
if m.raw != nil {
- return m.raw
+ return m.raw, nil
}
var builder cryptobyte.Builder
@@ -98,8 +99,8 @@ func (m *utlsClientEncryptedExtensionsMsg) marshal() (x []byte) {
})
})
- m.raw = builder.BytesOrPanic()
- return m.raw
+ m.raw, err = builder.Bytes()
+ return m.raw, err
}
func (m *utlsClientEncryptedExtensionsMsg) unmarshal(data []byte) bool {
diff --git a/u_parrots.go b/u_parrots.go
index cf2fe39..e2765b8 100644
--- a/u_parrots.go
+++ b/u_parrots.go
@@ -508,6 +508,77 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&UtlsPaddingExtension{GetPaddingLen: BoringPaddingStyle},
},
}, nil
+ case HelloChrome_100_PSK:
+ return ClientHelloSpec{
+ CipherSuites: []uint16{
+ GREASE_PLACEHOLDER,
+ TLS_AES_128_GCM_SHA256,
+ TLS_AES_256_GCM_SHA384,
+ TLS_CHACHA20_POLY1305_SHA256,
+ TLS_ECDHE_ECDSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_ECDHE_ECDSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_ECDHE_ECDSA_WITH_CHACHA20_POLY1305,
+ TLS_ECDHE_RSA_WITH_CHACHA20_POLY1305,
+ TLS_ECDHE_RSA_WITH_AES_128_CBC_SHA,
+ TLS_ECDHE_RSA_WITH_AES_256_CBC_SHA,
+ TLS_RSA_WITH_AES_128_GCM_SHA256,
+ TLS_RSA_WITH_AES_256_GCM_SHA384,
+ TLS_RSA_WITH_AES_128_CBC_SHA,
+ TLS_RSA_WITH_AES_256_CBC_SHA,
+ },
+ CompressionMethods: []byte{
+ 0x00, // compressionNone
+ },
+ Extensions: []TLSExtension{
+ &UtlsGREASEExtension{},
+ &SNIExtension{},
+ &UtlsExtendedMasterSecretExtension{},
+ &RenegotiationInfoExtension{Renegotiation: RenegotiateOnceAsClient},
+ &SupportedCurvesExtension{[]CurveID{
+ GREASE_PLACEHOLDER,
+ X25519,
+ CurveP256,
+ CurveP384,
+ }},
+ &SupportedPointsExtension{SupportedPoints: []byte{
+ 0x00, // pointFormatUncompressed
+ }},
+ &SessionTicketExtension{},
+ &ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}},
+ &StatusRequestExtension{},
+ &SignatureAlgorithmsExtension{SupportedSignatureAlgorithms: []SignatureScheme{
+ ECDSAWithP256AndSHA256,
+ PSSWithSHA256,
+ PKCS1WithSHA256,
+ ECDSAWithP384AndSHA384,
+ PSSWithSHA384,
+ PKCS1WithSHA384,
+ PSSWithSHA512,
+ PKCS1WithSHA512,
+ }},
+ &SCTExtension{},
+ &KeyShareExtension{[]KeyShare{
+ {Group: CurveID(GREASE_PLACEHOLDER), Data: []byte{0}},
+ {Group: X25519},
+ }},
+ &PSKKeyExchangeModesExtension{[]uint8{
+ PskModeDHE,
+ }},
+ &SupportedVersionsExtension{[]uint16{
+ GREASE_PLACEHOLDER,
+ VersionTLS13,
+ VersionTLS12,
+ }},
+ &UtlsCompressCertExtension{[]CertCompressionAlgo{
+ CertCompressionBrotli,
+ }},
+ &ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}},
+ &UtlsGREASEExtension{},
+ &FakePreSharedKeyExtension{},
+ },
+ }, nil
case HelloChrome_106_Shuffle:
chs, err := utlsIdToSpec(HelloChrome_102)
if err != nil {
@@ -515,7 +586,17 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
}
// Chrome 107 started shuffling the order of extensions
- return shuffleExtensions(chs)
+ shuffleExtensions(&chs)
+ return chs, err
+ case HelloChrome_112_PSK_Shuf:
+ chs, err := utlsIdToSpec(HelloChrome_100_PSK)
+ if err != nil {
+ return chs, err
+ }
+
+ // Chrome 112 started shuffling the order of extensions
+ shuffleExtensions(&chs)
+ return chs, err
case HelloFirefox_55, HelloFirefox_56:
return ClientHelloSpec{
TLSVersMax: VersionTLS12,
@@ -680,8 +761,8 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SessionTicketExtension{},
&ALPNExtension{AlpnProtocols: []string{"h2", "http/1.1"}}, //application_layer_protocol_negotiation
&StatusRequestExtension{},
- &DelegatedCredentialsExtension{
- AlgorithmsSignature: []SignatureScheme{ //signature_algorithms
+ &FakeDelegatedCredentialsExtension{
+ SupportedSignatureAlgorithms: []SignatureScheme{ //signature_algorithms
ECDSAWithP256AndSHA256,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
@@ -761,8 +842,8 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
&SessionTicketExtension{},
&ALPNExtension{AlpnProtocols: []string{"h2"}}, //application_layer_protocol_negotiation
&StatusRequestExtension{},
- &DelegatedCredentialsExtension{
- AlgorithmsSignature: []SignatureScheme{ //signature_algorithms
+ &FakeDelegatedCredentialsExtension{
+ SupportedSignatureAlgorithms: []SignatureScheme{ //signature_algorithms
ECDSAWithP256AndSHA256,
ECDSAWithP384AndSHA384,
ECDSAWithP521AndSHA512,
@@ -1847,65 +1928,55 @@ func utlsIdToSpec(id ClientHelloID) (ClientHelloSpec, error) {
default:
if id.Client == helloRandomized || id.Client == helloRandomizedALPN || id.Client == helloRandomizedNoALPN {
// Use empty values as they can be filled later by UConn.ApplyPreset or manually.
- return generateRandomizedSpec(id, "", nil, nil)
+ return generateRandomizedSpec(&id, "", nil, nil)
}
return ClientHelloSpec{}, errors.New("ClientHello ID " + id.Str() + " is unknown")
}
}
-func shuffleExtensions(chs ClientHelloSpec) (ClientHelloSpec, error) {
+func shuffleExtensions(chs *ClientHelloSpec) error {
// Shuffle extensions to avoid fingerprinting -- introduced in Chrome 106
- // GREASE, padding will remain in place (if present)
+ var err error = nil
- // Find indexes of GREASE and padding extensions
- var greaseIdx []int
- var paddingIdx []int
- var otherExtensions []TLSExtension
-
- for i, ext := range chs.Extensions {
- switch ext.(type) {
+ // unshufCheck checks:
+ // - if the exts[idx] is a GREASE extension, then it should not be shuffled
+ // - if the exts[idx] is a padding/pre_shared_key extension, then it should be the
+ // last extension in the list and should not be shuffled
+ var unshufCheck = func(idx int, exts []TLSExtension) (donotshuf bool, userErr error) {
+ switch exts[idx].(type) {
case *UtlsGREASEExtension:
- greaseIdx = append(greaseIdx, i)
- case *UtlsPaddingExtension:
- paddingIdx = append(paddingIdx, i)
+ donotshuf = true
+ case *UtlsPaddingExtension, *FakePreSharedKeyExtension:
+ donotshuf = true
+ if idx != len(chs.Extensions)-1 {
+ userErr = errors.New("UtlsPaddingExtension or FakePreSharedKeyExtension must be the last extension")
+ }
default:
- otherExtensions = append(otherExtensions, ext)
+ donotshuf = false
}
+ return
}
// Shuffle other extensions
- rand.Shuffle(len(otherExtensions), func(i, j int) {
- otherExtensions[i], otherExtensions[j] = otherExtensions[j], otherExtensions[i]
- })
-
- // Rebuild extensions slice
- otherExtIdx := 0
-SHUF_EXTENSIONS:
- for i := 0; i < len(chs.Extensions); i++ {
- // if current index is in greaseIdx or paddingIdx, add GREASE or padding extension
- for _, idx := range greaseIdx {
- if i == idx {
- chs.Extensions[i] = &UtlsGREASEExtension{}
- continue SHUF_EXTENSIONS
+ rand.Shuffle(len(chs.Extensions), func(i, j int) {
+ if unshuf, shuferr := unshufCheck(i, chs.Extensions); unshuf {
+ if shuferr != nil {
+ err = shuferr
}
+ return
}
- for _, idx := range paddingIdx {
- if i == idx {
- chs.Extensions[i] = &UtlsPaddingExtension{
- GetPaddingLen: BoringPaddingStyle,
- }
- break SHUF_EXTENSIONS
+
+ if unshuf, shuferr := unshufCheck(j, chs.Extensions); unshuf {
+ if shuferr != nil {
+ err = shuferr
}
+ return
}
- // otherwise add other extension
- chs.Extensions[i] = otherExtensions[otherExtIdx]
- otherExtIdx++
- }
- if otherExtIdx != len(otherExtensions) {
- return ClientHelloSpec{}, errors.New("shuffleExtensions: otherExtIdx != len(otherExtensions)")
- }
- return chs, nil
+ chs.Extensions[i], chs.Extensions[j] = chs.Extensions[j], chs.Extensions[i]
+ })
+
+ return err
}
func (uconn *UConn) applyPresetByID(id ClientHelloID) (err error) {
@@ -1948,6 +2019,7 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
}
uconn.HandshakeState.Hello = privateHello.getPublicPtr()
uconn.HandshakeState.State13.EcdheParams = ecdheParams
+ uconn.HandshakeState.State13.KeySharesEcdheParams = make(KeySharesEcdheParameters, 2)
hello := uconn.HandshakeState.Hello
session := uconn.HandshakeState.Session
@@ -1988,7 +2060,7 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
hello.CipherSuites = make([]uint16, len(p.CipherSuites))
copy(hello.CipherSuites, p.CipherSuites)
for i := range hello.CipherSuites {
- if hello.CipherSuites[i] == GREASE_PLACEHOLDER {
+ if isGREASEUint16(hello.CipherSuites[i]) { // just in case the user set a GREASE value instead of unGREASEd
hello.CipherSuites[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_cipher)
}
}
@@ -2029,7 +2101,7 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
}
case *SupportedCurvesExtension:
for i := range ext.Curves {
- if ext.Curves[i] == GREASE_PLACEHOLDER {
+ if isGREASEUint16(uint16(ext.Curves[i])) {
ext.Curves[i] = CurveID(GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_group))
}
}
@@ -2037,7 +2109,7 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
preferredCurveIsSet := false
for i := range ext.KeyShares {
curveID := ext.KeyShares[i].Group
- if curveID == GREASE_PLACEHOLDER {
+ if isGREASEUint16(uint16(curveID)) { // just in case the user set a GREASE value instead of unGREASEd
ext.KeyShares[i].Group = CurveID(GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_group))
continue
}
@@ -2050,6 +2122,7 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
return fmt.Errorf("unsupported Curve in KeyShareExtension: %v."+
"To mimic it, fill the Data(key) field manually", curveID)
}
+ uconn.HandshakeState.State13.KeySharesEcdheParams.AddEcdheParams(curveID, ecdheParams)
ext.KeyShares[i].Data = ecdheParams.PublicKey()
if !preferredCurveIsSet {
// only do this once for the first non-grease curve
@@ -2059,7 +2132,7 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
}
case *SupportedVersionsExtension:
for i := range ext.Versions {
- if ext.Versions[i] == GREASE_PLACEHOLDER {
+ if isGREASEUint16(ext.Versions[i]) { // just in case the user set a GREASE value instead of unGREASEd
ext.Versions[i] = GetBoringGREASEValue(uconn.greaseSeed, ssl_grease_version)
}
}
@@ -2076,11 +2149,11 @@ func (uconn *UConn) ApplyPreset(p *ClientHelloSpec) error {
}
func (uconn *UConn) generateRandomizedSpec() (ClientHelloSpec, error) {
- return generateRandomizedSpec(uconn.ClientHelloID, uconn.serverName, uconn.HandshakeState.Session, uconn.config.NextProtos)
+ return generateRandomizedSpec(&uconn.ClientHelloID, uconn.serverName, uconn.HandshakeState.Session, uconn.config.NextProtos)
}
func generateRandomizedSpec(
- id ClientHelloID,
+ id *ClientHelloID,
serverName string,
session *ClientSessionState,
nextProtos []string,
@@ -2100,6 +2173,10 @@ func generateRandomizedSpec(
return p, err
}
+ if id.Weights == nil {
+ id.Weights = &DefaultWeights
+ }
+
var WithALPN bool
switch id.Client {
case helloRandomizedALPN:
@@ -2107,7 +2184,7 @@ func generateRandomizedSpec(
case helloRandomizedNoALPN:
WithALPN = false
case helloRandomized:
- if r.FlipWeightedCoin(0.7) {
+ if r.FlipWeightedCoin(id.Weights.Extensions_Append_ALPN) {
WithALPN = true
} else {
WithALPN = false
@@ -2123,7 +2200,7 @@ func generateRandomizedSpec(
return p, err
}
- if r.FlipWeightedCoin(0.4) {
+ if r.FlipWeightedCoin(id.Weights.TLSVersMax_Set_VersionTLS13) {
p.TLSVersMin = VersionTLS10
p.TLSVersMax = VersionTLS13
tls13ciphers := make([]uint16, len(defaultCipherSuitesTLS13))
@@ -2141,7 +2218,7 @@ func generateRandomizedSpec(
p.TLSVersMax = VersionTLS12
}
- p.CipherSuites = removeRandomCiphers(r, shuffledSuites, 0.4)
+ p.CipherSuites = removeRandomCiphers(r, shuffledSuites, id.Weights.CipherSuites_Remove_RandomCiphers)
sni := SNIExtension{serverName}
sessionTicket := SessionTicketExtension{Session: session}
@@ -2155,16 +2232,16 @@ func generateRandomizedSpec(
PKCS1WithSHA512,
}
- if r.FlipWeightedCoin(0.63) {
+ if r.FlipWeightedCoin(id.Weights.SigAndHashAlgos_Append_ECDSAWithSHA1) {
sigAndHashAlgos = append(sigAndHashAlgos, ECDSAWithSHA1)
}
- if r.FlipWeightedCoin(0.59) {
+ if r.FlipWeightedCoin(id.Weights.SigAndHashAlgos_Append_ECDSAWithP521AndSHA512) {
sigAndHashAlgos = append(sigAndHashAlgos, ECDSAWithP521AndSHA512)
}
- if r.FlipWeightedCoin(0.51) || p.TLSVersMax == VersionTLS13 {
+ if r.FlipWeightedCoin(id.Weights.SigAndHashAlgos_Append_PSSWithSHA256) || p.TLSVersMax == VersionTLS13 {
// https://tools.ietf.org/html/rfc8446 says "...RSASSA-PSS (which is mandatory in TLS 1.3)..."
sigAndHashAlgos = append(sigAndHashAlgos, PSSWithSHA256)
- if r.FlipWeightedCoin(0.9) {
+ if r.FlipWeightedCoin(id.Weights.SigAndHashAlgos_Append_PSSWithSHA384_PSSWithSHA512) {
// these usually go together
sigAndHashAlgos = append(sigAndHashAlgos, PSSWithSHA384)
sigAndHashAlgos = append(sigAndHashAlgos, PSSWithSHA512)
@@ -2182,11 +2259,11 @@ func generateRandomizedSpec(
points := SupportedPointsExtension{SupportedPoints: []byte{pointFormatUncompressed}}
curveIDs := []CurveID{}
- if r.FlipWeightedCoin(0.71) || p.TLSVersMax == VersionTLS13 {
+ if r.FlipWeightedCoin(id.Weights.CurveIDs_Append_X25519) || p.TLSVersMax == VersionTLS13 {
curveIDs = append(curveIDs, X25519)
}
curveIDs = append(curveIDs, CurveP256, CurveP384)
- if r.FlipWeightedCoin(0.46) {
+ if r.FlipWeightedCoin(id.Weights.CurveIDs_Append_CurveP521) {
curveIDs = append(curveIDs, CurveP521)
}
@@ -2212,28 +2289,28 @@ func generateRandomizedSpec(
p.Extensions = append(p.Extensions, &alpn)
}
- if r.FlipWeightedCoin(0.62) || p.TLSVersMax == VersionTLS13 {
+ if r.FlipWeightedCoin(id.Weights.Extensions_Append_Padding) || p.TLSVersMax == VersionTLS13 {
// always include for TLS 1.3, since TLS 1.3 ClientHellos are often over 256 bytes
// and that's when padding is required to work around buggy middleboxes
p.Extensions = append(p.Extensions, &padding)
}
- if r.FlipWeightedCoin(0.74) {
+ if r.FlipWeightedCoin(id.Weights.Extensions_Append_Status) {
p.Extensions = append(p.Extensions, &status)
}
- if r.FlipWeightedCoin(0.46) {
+ if r.FlipWeightedCoin(id.Weights.Extensions_Append_SCT) {
p.Extensions = append(p.Extensions, &sct)
}
- if r.FlipWeightedCoin(0.75) {
+ if r.FlipWeightedCoin(id.Weights.Extensions_Append_Reneg) {
p.Extensions = append(p.Extensions, &reneg)
}
- if r.FlipWeightedCoin(0.77) {
+ if r.FlipWeightedCoin(id.Weights.Extensions_Append_EMS) {
p.Extensions = append(p.Extensions, &ems)
}
if p.TLSVersMax == VersionTLS13 {
ks := KeyShareExtension{[]KeyShare{
{Group: X25519}, // the key for the group will be generated later
}}
- if r.FlipWeightedCoin(0.25) {
+ if r.FlipWeightedCoin(id.Weights.FirstKeyShare_Set_CurveP256) {
// do not ADD second keyShare because crypto/tls does not support multiple ecdheParams
// TODO: add it back when they implement multiple keyShares, or implement it oursevles
// ks.KeyShares = append(ks.KeyShares, KeyShare{Group: CurveP256})
@@ -2260,7 +2337,7 @@ func generateRandomizedSpec(
if err != nil {
return p, err
}
- if r.FlipWeightedCoin(0.33) {
+ if r.FlipWeightedCoin(id.Weights.Extensions_Append_ALPS) {
// As with the ALPN case above, default to something popular
// (unlike ALPN, ALPS can't yet be specified in uconn.config).
alps := &ApplicationSettingsExtension{SupportedProtocols: []string{"h2"}}
diff --git a/u_public.go b/u_public.go
index e27cfd7..a9d39e2 100644
--- a/u_public.go
+++ b/u_public.go
@@ -34,15 +34,16 @@ type PubClientHandshakeState struct {
// TLS 1.3 only
type TLS13OnlyState struct {
- Suite *PubCipherSuiteTLS13
- EcdheParams EcdheParameters
- EarlySecret []byte
- BinderKey []byte
- CertReq *CertificateRequestMsgTLS13
- UsingPSK bool
- SentDummyCCS bool
- Transcript hash.Hash
- TrafficSecret []byte // client_application_traffic_secret_0
+ Suite *PubCipherSuiteTLS13
+ EcdheParams EcdheParameters
+ KeySharesEcdheParams KeySharesEcdheParameters
+ EarlySecret []byte
+ BinderKey []byte
+ CertReq *CertificateRequestMsgTLS13
+ UsingPSK bool
+ SentDummyCCS bool
+ Transcript hash.Hash
+ TrafficSecret []byte // client_application_traffic_secret_0
}
// TLS 1.2 and before only
@@ -56,10 +57,11 @@ func (chs *PubClientHandshakeState) toPrivate13() *clientHandshakeStateTLS13 {
return nil
} else {
return &clientHandshakeStateTLS13{
- c: chs.C,
- serverHello: chs.ServerHello.getPrivatePtr(),
- hello: chs.Hello.getPrivatePtr(),
- ecdheParams: chs.State13.EcdheParams,
+ c: chs.C,
+ serverHello: chs.ServerHello.getPrivatePtr(),
+ hello: chs.Hello.getPrivatePtr(),
+ ecdheParams: chs.State13.EcdheParams,
+ keySharesEcdheParams: chs.State13.KeySharesEcdheParams,
session: chs.Session,
earlySecret: chs.State13.EarlySecret,
@@ -83,15 +85,16 @@ func (chs13 *clientHandshakeStateTLS13) toPublic13() *PubClientHandshakeState {
return nil
} else {
tls13State := TLS13OnlyState{
- EcdheParams: chs13.ecdheParams,
- EarlySecret: chs13.earlySecret,
- BinderKey: chs13.binderKey,
- CertReq: chs13.certReq.toPublic(),
- UsingPSK: chs13.usingPSK,
- SentDummyCCS: chs13.sentDummyCCS,
- Suite: chs13.suite.toPublic(),
- TrafficSecret: chs13.trafficSecret,
- Transcript: chs13.transcript,
+ KeySharesEcdheParams: chs13.keySharesEcdheParams,
+ EcdheParams: chs13.ecdheParams,
+ EarlySecret: chs13.earlySecret,
+ BinderKey: chs13.binderKey,
+ CertReq: chs13.certReq.toPublic(),
+ UsingPSK: chs13.usingPSK,
+ SentDummyCCS: chs13.sentDummyCCS,
+ Suite: chs13.suite.toPublic(),
+ TrafficSecret: chs13.trafficSecret,
+ Transcript: chs13.transcript,
}
return &PubClientHandshakeState{
C: chs13.c,
@@ -344,7 +347,7 @@ type PubClientHelloMsg struct {
KeyShares []KeyShare
EarlyData bool
PskModes []uint8
- PskIdentities []pskIdentity
+ PskIdentities []PskIdentity
PskBinders [][]byte
}
@@ -379,7 +382,7 @@ func (chm *PubClientHelloMsg) getPrivatePtr() *clientHelloMsg {
keyShares: KeyShares(chm.KeyShares).ToPrivate(),
earlyData: chm.EarlyData,
pskModes: chm.PskModes,
- pskIdentities: chm.PskIdentities,
+ pskIdentities: PskIdentities(chm.PskIdentities).ToPrivate(),
pskBinders: chm.PskBinders,
}
}
@@ -416,7 +419,7 @@ func (chm *clientHelloMsg) getPublicPtr() *PubClientHelloMsg {
KeyShares: keyShares(chm.keyShares).ToPublic(),
EarlyData: chm.earlyData,
PskModes: chm.pskModes,
- PskIdentities: chm.pskIdentities,
+ PskIdentities: pskIdentities(chm.pskIdentities).ToPublic(),
PskBinders: chm.pskBinders,
}
}
@@ -434,7 +437,7 @@ func UnmarshalClientHello(data []byte) *PubClientHelloMsg {
// Marshal allows external code to convert a ClientHello object back into
// raw bytes.
-func (chm *PubClientHelloMsg) Marshal() []byte {
+func (chm *PubClientHelloMsg) Marshal() ([]byte, error) {
return chm.getPrivatePtr().marshal()
}
@@ -540,8 +543,8 @@ func (fh *finishedHash) getPublicObj() FinishedHash {
// TLS 1.3 Key Share. See RFC 8446, Section 4.2.8.
type KeyShare struct {
- Group CurveID
- Data []byte
+ Group CurveID `json:"group"`
+ Data []byte `json:"key_exchange,omitempty"` // optional
}
type KeyShares []KeyShare
@@ -562,6 +565,32 @@ func (KSS KeyShares) ToPrivate() []keyShare {
return kss
}
+// TLS 1.3 PSK Identity. Can be a Session Ticket, or a reference to a saved
+// session. See RFC 8446, Section 4.2.11.
+type PskIdentity struct {
+ Label []byte `json:"identity"`
+ ObfuscatedTicketAge uint32 `json:"obfuscated_ticket_age"`
+}
+
+type PskIdentities []PskIdentity
+type pskIdentities []pskIdentity
+
+func (pss pskIdentities) ToPublic() []PskIdentity {
+ var PSS []PskIdentity
+ for _, ps := range pss {
+ PSS = append(PSS, PskIdentity{Label: ps.label, ObfuscatedTicketAge: ps.obfuscatedTicketAge})
+ }
+ return PSS
+}
+
+func (PSS PskIdentities) ToPrivate() []pskIdentity {
+ var pss []pskIdentity
+ for _, PS := range PSS {
+ pss = append(pss, pskIdentity{label: PS.Label, obfuscatedTicketAge: PS.ObfuscatedTicketAge})
+ }
+ return pss
+}
+
// ClientSessionState is public, but all its fields are private. Let's add setters, getters and constructor
// ClientSessionState contains the state needed by clients to resume TLS sessions.
diff --git a/u_tls_extensions.go b/u_tls_extensions.go
index f386b86..22cfdc2 100644
--- a/u_tls_extensions.go
+++ b/u_tls_extensions.go
@@ -5,10 +5,84 @@
package tls
import (
+ "encoding/json"
"errors"
+ "fmt"
"io"
+ "strings"
+
+ "github.com/gaukas/godicttls"
+ "golang.org/x/crypto/cryptobyte"
)
+// ExtensionFromID returns a TLSExtension for the given extension ID.
+func ExtensionFromID(id uint16) TLSExtension {
+ // deep copy
+ switch id {
+ case extensionServerName:
+ return &SNIExtension{}
+ case extensionStatusRequest:
+ return &StatusRequestExtension{}
+ case extensionSupportedCurves:
+ return &SupportedCurvesExtension{}
+ case extensionSupportedPoints:
+ return &SupportedPointsExtension{}
+ case extensionSignatureAlgorithms:
+ return &SignatureAlgorithmsExtension{}
+ case extensionALPN:
+ return &ALPNExtension{}
+ case extensionStatusRequestV2:
+ return &StatusRequestV2Extension{}
+ case extensionSCT:
+ return &SCTExtension{}
+ case utlsExtensionPadding:
+ return &UtlsPaddingExtension{}
+ case utlsExtensionExtendedMasterSecret:
+ return &UtlsExtendedMasterSecretExtension{}
+ case fakeExtensionTokenBinding:
+ return &FakeTokenBindingExtension{}
+ case utlsExtensionCompressCertificate:
+ return &UtlsCompressCertExtension{}
+ case fakeExtensionDelegatedCredentials:
+ return &FakeDelegatedCredentialsExtension{}
+ case extensionSessionTicket:
+ return &SessionTicketExtension{}
+ case fakeExtensionPreSharedKey:
+ return &FakePreSharedKeyExtension{}
+ // case extensionEarlyData:
+ // return &EarlyDataExtension{}
+ case extensionSupportedVersions:
+ return &SupportedVersionsExtension{}
+ // case extensionCookie:
+ // return &CookieExtension{}
+ case extensionPSKModes:
+ return &PSKKeyExchangeModesExtension{}
+ // case extensionCertificateAuthorities:
+ // return &CertificateAuthoritiesExtension{}
+ case extensionSignatureAlgorithmsCert:
+ return &SignatureAlgorithmsCertExtension{}
+ case extensionKeyShare:
+ return &KeyShareExtension{}
+ case extensionNextProtoNeg:
+ return &NPNExtension{}
+ case utlsExtensionApplicationSettings:
+ return &ApplicationSettingsExtension{}
+ case fakeOldExtensionChannelID:
+ return &FakeChannelIDExtension{true}
+ case fakeExtensionChannelID:
+ return &FakeChannelIDExtension{}
+ case fakeRecordSizeLimit:
+ return &FakeRecordSizeLimitExtension{}
+ case extensionRenegotiationInfo:
+ return &RenegotiationInfoExtension{}
+ default:
+ if isGREASEUint16(id) {
+ return &UtlsGREASEExtension{}
+ }
+ return nil // not returning GenericExtension, it should be handled by caller
+ }
+}
+
type TLSExtension interface {
writeToUConn(*UConn) error
@@ -19,42 +93,28 @@ type TLSExtension interface {
Read(p []byte) (n int, err error) // implements io.Reader
}
-type NPNExtension struct {
- NextProtos []string
-}
+// TLSExtensionWriter is an interface allowing a TLS extension to be
+// auto-constucted/recovered by reading in a byte stream.
+type TLSExtensionWriter interface {
+ TLSExtension
-func (e *NPNExtension) writeToUConn(uc *UConn) error {
- uc.config.NextProtos = e.NextProtos
- uc.HandshakeState.Hello.NextProtoNeg = true
- return nil
+ // Write writes up to len(b) bytes from b.
+ // It returns the number of bytes written (0 <= n <= len(b)) and any error encountered.
+ Write(b []byte) (n int, err error)
}
-func (e *NPNExtension) Len() int {
- return 4
-}
+type TLSExtensionJSON interface {
+ TLSExtension
-func (e *NPNExtension) Read(b []byte) (int, error) {
- if len(b) < e.Len() {
- return 0, io.ErrShortBuffer
- }
- b[0] = byte(extensionNextProtoNeg >> 8)
- b[1] = byte(extensionNextProtoNeg & 0xff)
- // The length is always 0
- return e.Len(), io.EOF
+ // UnmarshalJSON unmarshals the JSON-encoded data into the extension.
+ UnmarshalJSON([]byte) error
}
+// SNIExtension implements server_name (0)
type SNIExtension struct {
ServerName string // not an array because go crypto/tls doesn't support multiple SNIs
}
-func (e *SNIExtension) writeToUConn(uc *UConn) error {
- uc.config.ServerName = e.ServerName
- hostName := hostnameInSNI(e.ServerName)
- uc.HandshakeState.Hello.ServerName = hostName
-
- return nil
-}
-
func (e *SNIExtension) Len() int {
// Literal IP addresses, absolute FQDNs, and empty strings are not permitted as SNI values.
// See RFC 6066, Section 3.
@@ -89,14 +149,58 @@ func (e *SNIExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
-type StatusRequestExtension struct {
+func (e *SNIExtension) UnmarshalJSON(_ []byte) error {
+ return nil // no-op
}
-func (e *StatusRequestExtension) writeToUConn(uc *UConn) error {
- uc.HandshakeState.Hello.OcspStapling = true
+// Write is a no-op for StatusRequestExtension.
+// SNI should not be fingerprinted and is user controlled.
+func (e *SNIExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 6066, Section 3
+ var nameList cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&nameList) || nameList.Empty() {
+ return fullLen, errors.New("unable to read server name extension data")
+ }
+ var serverName string
+ for !nameList.Empty() {
+ var nameType uint8
+ var serverNameBytes cryptobyte.String
+ if !nameList.ReadUint8(&nameType) ||
+ !nameList.ReadUint16LengthPrefixed(&serverNameBytes) ||
+ serverNameBytes.Empty() {
+ return fullLen, errors.New("unable to read server name extension data")
+ }
+ if nameType != 0 {
+ continue
+ }
+ if len(serverName) != 0 {
+ return fullLen, errors.New("multiple names of the same name_type in server name extension are prohibited")
+ }
+ serverName = string(serverNameBytes)
+ if strings.HasSuffix(serverName, ".") {
+ return fullLen, errors.New("SNI value may not include a trailing dot")
+ }
+ }
+ // clientHelloSpec.Extensions = append(clientHelloSpec.Extensions, &SNIExtension{}) // gaukas moved this line out from the loop.
+
+ // don't copy SNI from ClientHello to ClientHelloSpec!
+ return fullLen, nil
+}
+
+func (e *SNIExtension) writeToUConn(uc *UConn) error {
+ uc.config.ServerName = e.ServerName
+ hostName := hostnameInSNI(e.ServerName)
+ uc.HandshakeState.Hello.ServerName = hostName
+
return nil
}
+// StatusRequestExtension implements status_request (5)
+type StatusRequestExtension struct {
+}
+
func (e *StatusRequestExtension) Len() int {
return 9
}
@@ -115,46 +219,40 @@ func (e *StatusRequestExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
-type StatusRequestV2Extension struct {
+func (e *StatusRequestExtension) UnmarshalJSON(_ []byte) error {
+ return nil // no-op
}
-func (e *StatusRequestV2Extension) writeToUConn(uc *UConn) error {
- uc.HandshakeState.Hello.OcspStapling = true
- return nil
-}
+// Write is a no-op for StatusRequestExtension. No data for this extension.
+func (e *StatusRequestExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 4366, Section 3.6
+ var statusType uint8
+ var ignored cryptobyte.String
+ if !extData.ReadUint8(&statusType) ||
+ !extData.ReadUint16LengthPrefixed(&ignored) ||
+ !extData.ReadUint16LengthPrefixed(&ignored) {
+ return fullLen, errors.New("unable to read status request extension data")
+ }
-func (e *StatusRequestV2Extension) Len() int {
- return 13
+ if statusType != statusTypeOCSP {
+ return fullLen, errors.New("status request extension statusType is not statusTypeOCSP(1)")
+ }
+
+ return fullLen, nil
}
-func (e *StatusRequestV2Extension) Read(b []byte) (int, error) {
- if len(b) < e.Len() {
- return 0, io.ErrShortBuffer
- }
- // RFC 4366, section 3.6
- b[0] = byte(extensionStatusRequestV2 >> 8)
- b[1] = byte(extensionStatusRequestV2)
- b[2] = 0
- b[3] = 9
- b[4] = 0
- b[5] = 7
- b[6] = 2 // OCSP type
- b[7] = 0
- b[8] = 4
- // Two zero valued uint16s for the two lengths.
- return e.Len(), io.EOF
+func (e *StatusRequestExtension) writeToUConn(uc *UConn) error {
+ uc.HandshakeState.Hello.OcspStapling = true
+ return nil
}
+// SupportedCurvesExtension implements supported_groups (renamed from "elliptic_curves") (10)
type SupportedCurvesExtension struct {
Curves []CurveID
}
-func (e *SupportedCurvesExtension) writeToUConn(uc *UConn) error {
- uc.config.CurvePreferences = e.Curves
- uc.HandshakeState.Hello.SupportedCurves = e.Curves
- return nil
-}
-
func (e *SupportedCurvesExtension) Len() int {
return 6 + 2*len(e.Curves)
}
@@ -177,15 +275,60 @@ func (e *SupportedCurvesExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
-type SupportedPointsExtension struct {
- SupportedPoints []uint8
+func (e *SupportedCurvesExtension) UnmarshalJSON(data []byte) error {
+ var namedGroups struct {
+ NamedGroupList []string `json:"named_group_list"`
+ }
+ if err := json.Unmarshal(data, &namedGroups); err != nil {
+ return err
+ }
+
+ for _, namedGroup := range namedGroups.NamedGroupList {
+ if namedGroup == "GREASE" {
+ e.Curves = append(e.Curves, GREASE_PLACEHOLDER)
+ continue
+ }
+
+ if group, ok := godicttls.DictSupportedGroupsNameIndexed[namedGroup]; ok {
+ e.Curves = append(e.Curves, CurveID(group))
+ } else {
+ return fmt.Errorf("unknown named group: %s", namedGroup)
+ }
+ }
+ return nil
}
-func (e *SupportedPointsExtension) writeToUConn(uc *UConn) error {
- uc.HandshakeState.Hello.SupportedPoints = e.SupportedPoints
+func (e *SupportedCurvesExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 4492, sections 5.1.1 and RFC 8446, Section 4.2.7
+ var curvesBytes cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&curvesBytes) || curvesBytes.Empty() {
+ return 0, errors.New("unable to read supported curves extension data")
+ }
+ curves := []CurveID{}
+ for !curvesBytes.Empty() {
+ var curve uint16
+ if !curvesBytes.ReadUint16(&curve) {
+ return 0, errors.New("unable to read supported curves extension data")
+ }
+ curves = append(curves, CurveID(unGREASEUint16(curve)))
+ }
+ e.Curves = curves
+ return fullLen, nil
+}
+
+func (e *SupportedCurvesExtension) writeToUConn(uc *UConn) error {
+ uc.config.CurvePreferences = e.Curves
+ uc.HandshakeState.Hello.SupportedCurves = e.Curves
return nil
}
+// SupportedPointsExtension implements ec_point_formats (11)
+type SupportedPointsExtension struct {
+ SupportedPoints []uint8
+}
+
func (e *SupportedPointsExtension) Len() int {
return 5 + len(e.SupportedPoints)
}
@@ -206,15 +349,47 @@ func (e *SupportedPointsExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
-type SignatureAlgorithmsExtension struct {
- SupportedSignatureAlgorithms []SignatureScheme
+func (e *SupportedPointsExtension) UnmarshalJSON(data []byte) error {
+ var pointFormatList struct {
+ ECPointFormatList []string `json:"ec_point_format_list"`
+ }
+ if err := json.Unmarshal(data, &pointFormatList); err != nil {
+ return err
+ }
+
+ for _, pointFormat := range pointFormatList.ECPointFormatList {
+ if format, ok := godicttls.DictECPointFormatNameIndexed[pointFormat]; ok {
+ e.SupportedPoints = append(e.SupportedPoints, format)
+ } else {
+ return fmt.Errorf("unknown point format: %s", pointFormat)
+ }
+ }
+ return nil
}
-func (e *SignatureAlgorithmsExtension) writeToUConn(uc *UConn) error {
- uc.HandshakeState.Hello.SupportedSignatureAlgorithms = e.SupportedSignatureAlgorithms
+func (e *SupportedPointsExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 4492, Section 5.1.2
+ supportedPoints := []uint8{}
+ if !readUint8LengthPrefixed(&extData, &supportedPoints) ||
+ len(supportedPoints) == 0 {
+ return 0, errors.New("unable to read supported points extension data")
+ }
+ e.SupportedPoints = supportedPoints
+ return fullLen, nil
+}
+
+func (e *SupportedPointsExtension) writeToUConn(uc *UConn) error {
+ uc.HandshakeState.Hello.SupportedPoints = e.SupportedPoints
return nil
}
+// SignatureAlgorithmsExtension implements signature_algorithms (13)
+type SignatureAlgorithmsExtension struct {
+ SupportedSignatureAlgorithms []SignatureScheme
+}
+
func (e *SignatureAlgorithmsExtension) Len() int {
return 6 + 2*len(e.SupportedSignatureAlgorithms)
}
@@ -230,22 +405,124 @@ func (e *SignatureAlgorithmsExtension) Read(b []byte) (int, error) {
b[3] = byte(2 + 2*len(e.SupportedSignatureAlgorithms))
b[4] = byte((2 * len(e.SupportedSignatureAlgorithms)) >> 8)
b[5] = byte(2 * len(e.SupportedSignatureAlgorithms))
- for i, sigAndHash := range e.SupportedSignatureAlgorithms {
- b[6+2*i] = byte(sigAndHash >> 8)
- b[7+2*i] = byte(sigAndHash)
+ for i, sigScheme := range e.SupportedSignatureAlgorithms {
+ b[6+2*i] = byte(sigScheme >> 8)
+ b[7+2*i] = byte(sigScheme)
}
return e.Len(), io.EOF
}
-type SignatureAlgorithmsCertExtension struct {
- SupportedSignatureAlgorithms []SignatureScheme
+func (e *SignatureAlgorithmsExtension) UnmarshalJSON(data []byte) error {
+ var signatureAlgorithms struct {
+ Algorithms []string `json:"supported_signature_algorithms"`
+ }
+ if err := json.Unmarshal(data, &signatureAlgorithms); err != nil {
+ return err
+ }
+
+ for _, sigScheme := range signatureAlgorithms.Algorithms {
+ if sigScheme == "GREASE" {
+ e.SupportedSignatureAlgorithms = append(e.SupportedSignatureAlgorithms, GREASE_PLACEHOLDER)
+ continue
+ }
+
+ if scheme, ok := godicttls.DictSignatureSchemeNameIndexed[sigScheme]; ok {
+ e.SupportedSignatureAlgorithms = append(e.SupportedSignatureAlgorithms, SignatureScheme(scheme))
+ } else {
+ return fmt.Errorf("unknown signature scheme: %s", sigScheme)
+ }
+ }
+ return nil
}
-func (e *SignatureAlgorithmsCertExtension) writeToUConn(uc *UConn) error {
+func (e *SignatureAlgorithmsExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 5246, Section 7.4.1.4.1
+ var sigAndAlgs cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
+ return 0, errors.New("unable to read signature algorithms extension data")
+ }
+ supportedSignatureAlgorithms := []SignatureScheme{}
+ for !sigAndAlgs.Empty() {
+ var sigAndAlg uint16
+ if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+ return 0, errors.New("unable to read signature algorithms extension data")
+ }
+ supportedSignatureAlgorithms = append(
+ supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
+ }
+ e.SupportedSignatureAlgorithms = supportedSignatureAlgorithms
+ return fullLen, nil
+}
+
+func (e *SignatureAlgorithmsExtension) writeToUConn(uc *UConn) error {
uc.HandshakeState.Hello.SupportedSignatureAlgorithms = e.SupportedSignatureAlgorithms
return nil
}
+// StatusRequestV2Extension implements status_request_v2 (17)
+type StatusRequestV2Extension struct {
+}
+
+func (e *StatusRequestV2Extension) writeToUConn(uc *UConn) error {
+ uc.HandshakeState.Hello.OcspStapling = true
+ return nil
+}
+
+func (e *StatusRequestV2Extension) Len() int {
+ return 13
+}
+
+func (e *StatusRequestV2Extension) Read(b []byte) (int, error) {
+ if len(b) < e.Len() {
+ return 0, io.ErrShortBuffer
+ }
+ // RFC 4366, section 3.6
+ b[0] = byte(extensionStatusRequestV2 >> 8)
+ b[1] = byte(extensionStatusRequestV2)
+ b[2] = 0
+ b[3] = 9
+ b[4] = 0
+ b[5] = 7
+ b[6] = 2 // OCSP type
+ b[7] = 0
+ b[8] = 4
+ // Two zero valued uint16s for the two lengths.
+ return e.Len(), io.EOF
+}
+
+// Write is a no-op for StatusRequestV2Extension. No data for this extension.
+func (e *StatusRequestV2Extension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 4366, Section 3.6
+ var statusType uint8
+ var ignored cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&ignored) ||
+ !extData.ReadUint8(&statusType) ||
+ !extData.ReadUint16LengthPrefixed(&ignored) ||
+ !extData.ReadUint16LengthPrefixed(&ignored) ||
+ !extData.ReadUint16LengthPrefixed(&ignored) {
+ return fullLen, errors.New("unable to read status request v2 extension data")
+ }
+
+ if statusType != statusV2TypeOCSP {
+ return fullLen, errors.New("status request v2 extension statusType is not statusV2TypeOCSP(2)")
+ }
+
+ return fullLen, nil
+}
+
+func (e *StatusRequestV2Extension) UnmarshalJSON(_ []byte) error {
+ return nil // no-op
+}
+
+// SignatureAlgorithmsCertExtension implements signature_algorithms_cert (50)
+type SignatureAlgorithmsCertExtension struct {
+ SupportedSignatureAlgorithms []SignatureScheme
+}
+
func (e *SignatureAlgorithmsCertExtension) Len() int {
return 6 + 2*len(e.SupportedSignatureAlgorithms)
}
@@ -268,48 +545,60 @@ func (e *SignatureAlgorithmsCertExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
-type RenegotiationInfoExtension struct {
- // Renegotiation field limits how many times client will perform renegotiation: no limit, once, or never.
- // The extension still will be sent, even if Renegotiation is set to RenegotiateNever.
- Renegotiation RenegotiationSupport
-}
+// Copied from SignatureAlgorithmsExtension.UnmarshalJSON
+func (e *SignatureAlgorithmsCertExtension) UnmarshalJSON(data []byte) error {
+ var signatureAlgorithms struct {
+ Algorithms []string `json:"supported_signature_algorithms"`
+ }
+ if err := json.Unmarshal(data, &signatureAlgorithms); err != nil {
+ return err
+ }
-func (e *RenegotiationInfoExtension) writeToUConn(uc *UConn) error {
- uc.config.Renegotiation = e.Renegotiation
- switch e.Renegotiation {
- case RenegotiateOnceAsClient:
- fallthrough
- case RenegotiateFreelyAsClient:
- uc.HandshakeState.Hello.SecureRenegotiationSupported = true
- case RenegotiateNever:
- default:
+ for _, sigScheme := range signatureAlgorithms.Algorithms {
+ if sigScheme == "GREASE" {
+ e.SupportedSignatureAlgorithms = append(e.SupportedSignatureAlgorithms, GREASE_PLACEHOLDER)
+ continue
+ }
+
+ if scheme, ok := godicttls.DictSignatureSchemeNameIndexed[sigScheme]; ok {
+ e.SupportedSignatureAlgorithms = append(e.SupportedSignatureAlgorithms, SignatureScheme(scheme))
+ } else {
+ return fmt.Errorf("unknown cert signature scheme: %s", sigScheme)
+ }
}
return nil
}
-func (e *RenegotiationInfoExtension) Len() int {
- return 5
-}
-
-func (e *RenegotiationInfoExtension) Read(b []byte) (int, error) {
- if len(b) < e.Len() {
- return 0, io.ErrShortBuffer
+// Write implementation copied from SignatureAlgorithmsExtension.Write
+//
+// Warning: not tested.
+func (e *SignatureAlgorithmsCertExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 8446, Section 4.2.3
+ var sigAndAlgs cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&sigAndAlgs) || sigAndAlgs.Empty() {
+ return 0, errors.New("unable to read signature algorithms extension data")
}
+ supportedSignatureAlgorithms := []SignatureScheme{}
+ for !sigAndAlgs.Empty() {
+ var sigAndAlg uint16
+ if !sigAndAlgs.ReadUint16(&sigAndAlg) {
+ return 0, errors.New("unable to read signature algorithms extension data")
+ }
+ supportedSignatureAlgorithms = append(
+ supportedSignatureAlgorithms, SignatureScheme(sigAndAlg))
+ }
+ e.SupportedSignatureAlgorithms = supportedSignatureAlgorithms
+ return fullLen, nil
+}
- var extInnerBody []byte // inner body is empty
- innerBodyLen := len(extInnerBody)
- extBodyLen := innerBodyLen + 1
-
- b[0] = byte(extensionRenegotiationInfo >> 8)
- b[1] = byte(extensionRenegotiationInfo & 0xff)
- b[2] = byte(extBodyLen >> 8)
- b[3] = byte(extBodyLen)
- b[4] = byte(innerBodyLen)
- copy(b[5:], extInnerBody)
-
- return e.Len(), io.EOF
+func (e *SignatureAlgorithmsCertExtension) writeToUConn(uc *UConn) error {
+ uc.HandshakeState.Hello.SupportedSignatureAlgorithms = e.SupportedSignatureAlgorithms
+ return nil
}
+// ALPNExtension implements application_layer_protocol_negotiation (16)
type ALPNExtension struct {
AlpnProtocols []string
}
@@ -356,6 +645,40 @@ func (e *ALPNExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+func (e *ALPNExtension) UnmarshalJSON(b []byte) error {
+ var protocolNames struct {
+ ProtocolNameList []string `json:"protocol_name_list"`
+ }
+
+ if err := json.Unmarshal(b, &protocolNames); err != nil {
+ return err
+ }
+
+ e.AlpnProtocols = protocolNames.ProtocolNameList
+ return nil
+}
+
+func (e *ALPNExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // RFC 7301, Section 3.1
+ var protoList cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
+ return 0, errors.New("unable to read ALPN extension data")
+ }
+ alpnProtocols := []string{}
+ for !protoList.Empty() {
+ var proto cryptobyte.String
+ if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
+ return 0, errors.New("unable to read ALPN extension data")
+ }
+ alpnProtocols = append(alpnProtocols, string(proto))
+
+ }
+ e.AlpnProtocols = alpnProtocols
+ return fullLen, nil
+}
+
// ApplicationSettingsExtension represents the TLS ALPS extension.
// At the time of this writing, this extension is currently a draft:
// https://datatracker.ietf.org/doc/html/draft-vvv-tls-alps-01
@@ -405,6 +728,42 @@ func (e *ApplicationSettingsExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+func (e *ApplicationSettingsExtension) UnmarshalJSON(b []byte) error {
+ var applicationSettingsSupport struct {
+ SupportedProtocols []string `json:"supported_protocols"`
+ }
+
+ if err := json.Unmarshal(b, &applicationSettingsSupport); err != nil {
+ return err
+ }
+
+ e.SupportedProtocols = applicationSettingsSupport.SupportedProtocols
+ return nil
+}
+
+// Write implementation copied from ALPNExtension.Write
+func (e *ApplicationSettingsExtension) Write(b []byte) (int, error) {
+ fullLen := len(b)
+ extData := cryptobyte.String(b)
+ // https://datatracker.ietf.org/doc/html/draft-vvv-tls-alps-01
+ var protoList cryptobyte.String
+ if !extData.ReadUint16LengthPrefixed(&protoList) || protoList.Empty() {
+ return 0, errors.New("unable to read ALPN extension data")
+ }
+ alpnProtocols := []string{}
+ for !protoList.Empty() {
+ var proto cryptobyte.String
+ if !protoList.ReadUint8LengthPrefixed(&proto) || proto.Empty() {
+ return 0, errors.New("unable to read ALPN extension data")
+ }
+ alpnProtocols = append(alpnProtocols, string(proto))
+
+ }
+ e.SupportedProtocols = alpnProtocols
+ return fullLen, nil
+}
+
+// SCTExtension implements signed_certificate_timestamp (18)
type SCTExtension struct {
}
@@ -428,6 +787,15 @@ func (e *SCTExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+func (e *SCTExtension) UnmarshalJSON(_ []byte) error {
+ return nil // no-op
+}
+
+func (e *SCTExtension) Write(_ []byte) (int, error) {
+ return 0, nil
+}
+
+// SessionTicketExtension implements session_ticket (35)
type SessionTicketExtension struct {
Session *ClientSessionState
}
@@ -464,7 +832,18 @@ func (e *SessionTicketExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+func (e *SessionTicketExtension) UnmarshalJSON(_ []byte) error {
+ return nil // no-op
+}
+
+func (e *SessionTicketExtension) Write(_ []byte) (int, error) {
+ // RFC 5077, Section 3.2
+ return 0, nil
+}
+
// GenericExtension allows to include in ClientHello arbitrary unsupported extensions.
+// It is not defined in TLS RFCs nor by IANA.
+// If a server echoes this extension back, the handshake will likely fail due to no further support.
type GenericExtension struct {
Id uint16
Data []byte
@@ -493,6 +872,26 @@ func (e *GenericExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+func (e *GenericExtension) UnmarshalJSON(b []byte) error {
+ var genericExtension struct {
+ Name string `json:"name"`
+ Data []byte `json:"data"`
+ }
+ if err := json.Unmarshal(b, &genericExtension); err != nil {
+ return err
+ }
+
+ // lookup extension ID by name
+ if id, ok := godicttls.DictExtTypeNameIndexed[genericExtension.Name]; ok {
+ e.Id = id
+ } else {
+ return fmt.Errorf("unknown extension name %s", genericExtension.Name)
+ }
+ e.Data = genericExtension.Data
+ return nil
+}
+
+// UtlsExtendedMasterSecretExtension implements extended_master_secret (23)
type UtlsExtendedMasterSecretExtension struct {
}
@@ -518,6 +917,15 @@ func (e *UtlsExtendedMasterSecretExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+func (e *UtlsExtendedMasterSecretExtension) UnmarshalJSON(_ []byte) error {
+ return nil // no-op
+}
+
+func (e *UtlsExtendedMasterSecretExtension) Write(_ []byte) (int, error) {
+ // https://tools.ietf.org/html/rfc7627
+ return 0, nil
+}
+
var extendedMasterSecretLabel = []byte("extended master secret")
// extendedMasterFromPreMasterSecret generates the master secret from the pre-master
@@ -580,6 +988,43 @@ func (e *UtlsGREASEExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+// Write resets e.Value to GREASE_PLACEHOLDER and stores a private copy of b
+// as the extension body, so later changes to b do not affect e. All of b is
+// reported as consumed.
+func (e *UtlsGREASEExtension) Write(b []byte) (int, error) {
+	body := make([]byte, len(b))
+	copy(body, b)
+	e.Value = GREASE_PLACEHOLDER
+	e.Body = body
+	return len(body), nil
+}
+
+// UnmarshalJSON configures e from {"id", "data", "keep_id", "keep_data"}.
+// A missing or zero id leaves e untouched. A non-zero id must be a GREASE
+// value (checked with isGREASEUint16); it is copied to e.Value only when
+// keep_id is true, and data is copied to e.Body only when keep_data is true.
+func (e *UtlsGREASEExtension) UnmarshalJSON(b []byte) error {
+	var cfg struct {
+		ID       uint16 `json:"id"`
+		Data     []byte `json:"data"`
+		KeepID   bool   `json:"keep_id"`
+		KeepData bool   `json:"keep_data"`
+	}
+	if err := json.Unmarshal(b, &cfg); err != nil {
+		return err
+	}
+
+	switch {
+	case cfg.ID == 0:
+		return nil
+	case !isGREASEUint16(cfg.ID):
+		return errors.New("GREASE extension id must be a GREASE value")
+	}
+
+	if cfg.KeepID {
+		e.Value = cfg.ID
+	}
+	if cfg.KeepData {
+		e.Body = cfg.Data
+	}
+	return nil
+}
+
+// UtlsPaddingExtension implements padding (21)
type UtlsPaddingExtension struct {
PaddingLen int
WillPad bool // set to false to disable extension
@@ -614,16 +1059,54 @@ func (e *UtlsPaddingExtension) Read(b []byte) (int, error) {
if len(b) < e.Len() {
return 0, io.ErrShortBuffer
}
- // https://tools.ietf.org/html/rfc7627
- b[0] = byte(utlsExtensionPadding >> 8)
- b[1] = byte(utlsExtensionPadding)
- b[2] = byte(e.PaddingLen >> 8)
- b[3] = byte(e.PaddingLen)
- return e.Len(), io.EOF
+ // https://tools.ietf.org/html/rfc7627
+ b[0] = byte(utlsExtensionPadding >> 8)
+ b[1] = byte(utlsExtensionPadding)
+ b[2] = byte(e.PaddingLen >> 8)
+ b[3] = byte(e.PaddingLen)
+ return e.Len(), io.EOF
+}
+
+// UnmarshalJSON reads {"len": N}. A missing or zero N selects dynamic,
+// BoringSSL-style padding by installing BoringPaddingStyle as GetPaddingLen.
+// Any other N sets a fixed PaddingLen of N and enables padding (WillPad).
+func (e *UtlsPaddingExtension) UnmarshalJSON(b []byte) error {
+	var cfg struct {
+		Length uint `json:"len"`
+	}
+	if err := json.Unmarshal(b, &cfg); err != nil {
+		return err
+	}
+
+	if cfg.Length != 0 {
+		e.PaddingLen = int(cfg.Length)
+		e.WillPad = true
+		return nil
+	}
+	e.GetPaddingLen = BoringPaddingStyle
+	return nil
+}
+
+// Write ignores its argument and switches e to BoringSSL-style padding by
+// setting GetPaddingLen to BoringPaddingStyle. It reports zero bytes consumed.
+func (e *UtlsPaddingExtension) Write(_ []byte) (int, error) {
+ e.GetPaddingLen = BoringPaddingStyle
+ return 0, nil
+}
+
+// https://github.com/google/boringssl/blob/7d7554b6b3c79e707e25521e61e066ce2b996e4c/ssl/t1_lib.c#L2803
+func BoringPaddingStyle(unpaddedLen int) (int, bool) {
+ if unpaddedLen > 0xff && unpaddedLen < 0x200 {
+ paddingLen := 0x200 - unpaddedLen
+ if paddingLen >= 4+1 {
+ paddingLen -= 4
+ } else {
+ paddingLen = 1
+ }
+ return paddingLen, true
+ }
+ return 0, false
}
-// UtlsCompressCertExtension is only implemented client-side, for server certificates. Alternate
-// certificate message formats (https://datatracker.ietf.org/doc/html/rfc7250) are not supported.
+// UtlsCompressCertExtension implements compress_certificate (27) and is only implemented client-side
+// for server certificates. Alternate certificate message formats
+// (https://datatracker.ietf.org/doc/html/rfc7250) are not supported.
//
// See https://datatracker.ietf.org/doc/html/rfc8879#section-3
type UtlsCompressCertExtension struct {
@@ -667,21 +1150,45 @@ func (e *UtlsCompressCertExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
-// https://github.com/google/boringssl/blob/7d7554b6b3c79e707e25521e61e066ce2b996e4c/ssl/t1_lib.c#L2803
-func BoringPaddingStyle(unpaddedLen int) (int, bool) {
- if unpaddedLen > 0xff && unpaddedLen < 0x200 {
- paddingLen := 0x200 - unpaddedLen
- if paddingLen >= 4+1 {
- paddingLen -= 4
+// Write parses a compress_certificate extension body (RFC 8879, Section 3):
+// a uint8 length-prefixed list of uint16 algorithm IDs. On success it
+// replaces e.Algorithms (with an empty, non-nil slice for an empty list) and
+// reports all of b as consumed. Bytes after the list are not inspected.
+func (e *UtlsCompressCertExtension) Write(b []byte) (int, error) {
+	var (
+		body = cryptobyte.String(b)
+		list cryptobyte.String
+		algs = []CertCompressionAlgo{}
+	)
+	if !body.ReadUint8LengthPrefixed(&list) {
+		return 0, errors.New("unable to read cert compression algorithms extension data")
+	}
+	for !list.Empty() {
+		var id uint16
+		if !list.ReadUint16(&id) {
+			return 0, errors.New("unable to read cert compression algorithms extension data")
+		}
+		algs = append(algs, CertCompressionAlgo(id))
+	}
+	e.Algorithms = algs
+	return len(b), nil
+}
+
+// UnmarshalJSON reads {"algorithms": [names...]}. It resolves each name with
+// godicttls.DictCertificateCompressionAlgorithmNameIndexed and appends the
+// result to e.Algorithms, keeping any existing entries. It stops with an
+// error at the first unknown name; entries appended before that name remain.
+func (e *UtlsCompressCertExtension) UnmarshalJSON(b []byte) error {
+ var certificateCompressionAlgorithms struct {
+ Algorithms []string `json:"algorithms"`
+ }
+ if err := json.Unmarshal(b, &certificateCompressionAlgorithms); err != nil {
+ return err
+ }
+
+ for _, algorithm := range certificateCompressionAlgorithms.Algorithms {
+ if alg, ok := godicttls.DictCertificateCompressionAlgorithmNameIndexed[algorithm]; ok {
+ e.Algorithms = append(e.Algorithms, CertCompressionAlgo(alg))
} else {
- paddingLen = 1
+ return fmt.Errorf("unknown certificate compression algorithm %s", algorithm)
}
- return paddingLen, true
}
- return 0, false
+ return nil
}
-/* TLS 1.3 */
+// KeyShareExtension implements key_share (51) and is for TLS 1.3 only.
type KeyShareExtension struct {
+ // KeyShares is copied to the ClientHello by writeToUConn.
KeyShares []KeyShare
}
@@ -724,11 +1231,74 @@ func (e *KeyShareExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+// Write parses a ClientHello key_share body (RFC 8446, Section 4.2.8): a
+// uint16 length-prefixed list of (group uint16, key_exchange opaque<1..>)
+// entries. Each group ID is passed through unGREASEUint16. Entries whose
+// group then equals GREASE_PLACEHOLDER keep their key_exchange bytes. All
+// other entries have Data cleared, because real key shares must be generated
+// per connection. On success it replaces e.KeyShares and reports all of b as
+// consumed.
+func (e *KeyShareExtension) Write(b []byte) (int, error) {
+	body := cryptobyte.String(b)
+	var shares cryptobyte.String
+	if !body.ReadUint16LengthPrefixed(&shares) {
+		return 0, errors.New("unable to read key share extension data")
+	}
+
+	parsed := []KeyShare{}
+	for !shares.Empty() {
+		var (
+			rawGroup uint16
+			share    KeyShare
+		)
+		ok := shares.ReadUint16(&rawGroup) &&
+			readUint16LengthPrefixed(&shares, &share.Data) &&
+			len(share.Data) != 0
+		if !ok {
+			return 0, errors.New("unable to read key share extension data")
+		}
+
+		share.Group = CurveID(unGREASEUint16(rawGroup))
+		if share.Group != GREASE_PLACEHOLDER {
+			share.Data = nil
+		}
+		parsed = append(parsed, share)
+	}
+
+	e.KeyShares = parsed
+	return len(b), nil
+}
+
func (e *KeyShareExtension) writeToUConn(uc *UConn) error {
uc.HandshakeState.Hello.KeyShares = e.KeyShares
return nil
}
+// UnmarshalJSON reads {"client_shares": [{"group": name, "key_exchange": bytes}, ...]}
+// and appends one KeyShare per entry to e.KeyShares. The group name "GREASE"
+// maps to GREASE_PLACEHOLDER. Other names are resolved through
+// godicttls.DictSupportedGroupsNameIndexed, and an unknown name is an error.
+func (e *KeyShareExtension) UnmarshalJSON(b []byte) error {
+	var hello struct {
+		ClientShares []struct {
+			Group       string  `json:"group"`
+			KeyExchange []uint8 `json:"key_exchange"`
+		} `json:"client_shares"`
+	}
+	if err := json.Unmarshal(b, &hello); err != nil {
+		return err
+	}
+
+	for _, share := range hello.ClientShares {
+		var group CurveID
+		if share.Group == "GREASE" {
+			group = GREASE_PLACEHOLDER
+		} else if id, ok := godicttls.DictSupportedGroupsNameIndexed[share.Group]; ok {
+			group = CurveID(id)
+		} else {
+			return fmt.Errorf("unknown group %s", share.Group)
+		}
+		e.KeyShares = append(e.KeyShares, KeyShare{Group: group, Data: share.KeyExchange})
+	}
+	return nil
+}
+
+// PSKKeyExchangeModesExtension implements psk_key_exchange_modes (45).
type PSKKeyExchangeModesExtension struct {
+ // Modes is copied to the ClientHello's PskModes by writeToUConn.
Modes []uint8
}
@@ -761,11 +1331,46 @@ func (e *PSKKeyExchangeModesExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+// Write parses a psk_key_exchange_modes body (RFC 8446, Section 4.2.9): a
+// uint8 length-prefixed list of one-byte modes. On success it replaces
+// e.Modes and reports all of b as consumed.
+//
+// TODO: PSK Modes have their own form of GREASE-ing which is not currently implemented
+// the current functionality will NOT re-GREASE/re-randomize these values when using a fingerprinted spec
+// https://github.com/refraction-networking/utls/pull/58#discussion_r522354105
+// https://tools.ietf.org/html/draft-ietf-tls-grease-01#section-2
+func (e *PSKKeyExchangeModesExtension) Write(b []byte) (int, error) {
+	body := cryptobyte.String(b)
+	modes := []uint8{}
+	if !readUint8LengthPrefixed(&body, &modes) {
+		return 0, errors.New("unable to read PSK extension data")
+	}
+	e.Modes = modes
+	return len(b), nil
+}
+
func (e *PSKKeyExchangeModesExtension) writeToUConn(uc *UConn) error {
uc.HandshakeState.Hello.PskModes = e.Modes
return nil
}
+// UnmarshalJSON reads {"ke_modes": [names...]}. It resolves each name via
+// godicttls.DictPSKKeyExchangeModeNameIndexed and appends it to e.Modes,
+// stopping with an error at the first unknown name.
+func (e *PSKKeyExchangeModesExtension) UnmarshalJSON(b []byte) error {
+	var cfg struct {
+		Modes []string `json:"ke_modes"`
+	}
+	if err := json.Unmarshal(b, &cfg); err != nil {
+		return err
+	}
+
+	for _, name := range cfg.Modes {
+		id, ok := godicttls.DictPSKKeyExchangeModeNameIndexed[name]
+		if !ok {
+			return fmt.Errorf("unknown PSK Key Exchange Mode %s", name)
+		}
+		e.Modes = append(e.Modes, id)
+	}
+	return nil
+}
+
+// SupportedVersionsExtension implements supported_versions (43).
type SupportedVersionsExtension struct {
+ // Versions lists protocol version IDs. UnmarshalJSON stores "GREASE" as
+ // GREASE_PLACEHOLDER; Write passes each parsed value through unGREASEUint16.
Versions []uint16
}
@@ -803,6 +1408,57 @@ func (e *SupportedVersionsExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+// Write parses a ClientHello supported_versions body (RFC 8446, Section
+// 4.2.1): a non-empty uint8 length-prefixed list of uint16 versions, each
+// passed through unGREASEUint16. On success it replaces e.Versions and
+// reports all of b as consumed.
+func (e *SupportedVersionsExtension) Write(b []byte) (int, error) {
+	body := cryptobyte.String(b)
+	var list cryptobyte.String
+	if ok := body.ReadUint8LengthPrefixed(&list); !ok || list.Empty() {
+		return 0, errors.New("unable to read supported versions extension data")
+	}
+
+	versions := make([]uint16, 0, len(list)/2)
+	for !list.Empty() {
+		var v uint16
+		if !list.ReadUint16(&v) {
+			return 0, errors.New("unable to read supported versions extension data")
+		}
+		versions = append(versions, unGREASEUint16(v))
+	}
+	e.Versions = versions
+	return len(b), nil
+}
+
+// UnmarshalJSON reads {"versions": [names...]} and appends the matching
+// version IDs to e.Versions. The accepted names are "GREASE", "TLS 1.3",
+// "TLS 1.2", "TLS 1.1" and "TLS 1.0". "SSL 3.0" is rejected as deprecated,
+// and any other name is an error.
+func (e *SupportedVersionsExtension) UnmarshalJSON(b []byte) error {
+	var cfg struct {
+		Versions []string `json:"versions"`
+	}
+	if err := json.Unmarshal(b, &cfg); err != nil {
+		return err
+	}
+
+	for _, name := range cfg.Versions {
+		var v uint16
+		switch name {
+		case "GREASE":
+			v = GREASE_PLACEHOLDER
+		case "TLS 1.3":
+			v = VersionTLS13
+		case "TLS 1.2":
+			v = VersionTLS12
+		case "TLS 1.1":
+			v = VersionTLS11
+		case "TLS 1.0":
+			v = VersionTLS10
+		case "SSL 3.0":
+			// Deliberately unsupported (VersionSSL30 is not offered).
+			return errors.New("SSL 3.0 is deprecated")
+		default:
+			return fmt.Errorf("unknown version %s", name)
+		}
+		e.Versions = append(e.Versions, v)
+	}
+	return nil
+}
+
+// CookieExtension implements cookie (44).
// MUST NOT be part of initial ClientHello
type CookieExtension struct {
Cookie []byte
@@ -831,6 +1487,122 @@ func (e *CookieExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+// UnmarshalJSON reads {"cookie": bytes} (a base64 string under encoding/json)
+// and stores it as e.Cookie; a missing key sets e.Cookie to nil.
+func (e *CookieExtension) UnmarshalJSON(data []byte) error {
+	var payload struct {
+		Cookie []byte `json:"cookie"`
+	}
+	if err := json.Unmarshal(data, &payload); err != nil {
+		return err
+	}
+	e.Cookie = payload.Cookie
+	return nil
+}
+
+// NPNExtension implements next_protocol_negotiation (Not IANA assigned).
+type NPNExtension struct {
+ // NextProtos is copied into the connection config by writeToUConn; it is
+ // not serialized into the extension, whose ClientHello body is empty.
+ NextProtos []string
+}
+
+// writeToUConn copies NextProtos into uc.config and sets the ClientHello's
+// NextProtoNeg flag.
+func (e *NPNExtension) writeToUConn(uc *UConn) error {
+ uc.config.NextProtos = e.NextProtos
+ uc.HandshakeState.Hello.NextProtoNeg = true
+ return nil
+}
+
+// Len returns 4: the 2-byte extension type plus the 2-byte body length. The
+// body itself is always empty.
+func (e *NPNExtension) Len() int {
+ return 4
+}
+
+// Read writes the next_protocol_negotiation extension into b: the extension
+// type followed by a zero body length, since the ClientHello form of this
+// extension must be empty. It returns io.ErrShortBuffer if b is shorter
+// than Len(), and otherwise Len() with io.EOF.
+func (e *NPNExtension) Read(b []byte) (int, error) {
+	if len(b) < e.Len() {
+		return 0, io.ErrShortBuffer
+	}
+	b[0] = byte(extensionNextProtoNeg >> 8)
+	b[1] = byte(extensionNextProtoNeg & 0xff)
+	// The body length is always 0. Write it explicitly rather than relying on
+	// the caller handing in a zeroed buffer.
+	b[2] = 0
+	b[3] = 0
+	return e.Len(), io.EOF
+}
+
+// Write is a no-op for NPNExtension: it ignores its argument and reports zero
+// bytes consumed. NextProtos are not included in the ClientHello.
+func (e *NPNExtension) Write(_ []byte) (int, error) {
+ return 0, nil
+}
+
+// UnmarshalJSON is a no-op. Per draft-agl-tls-nextprotoneg-04, the
+// "extension_data" field of a "next_protocol_negotiation" extension in a
+// "ClientHello" MUST be empty, so there is nothing to configure.
+func (e *NPNExtension) UnmarshalJSON(_ []byte) error {
+ return nil
+}
+
+// RenegotiationInfoExtension implements renegotiation_info (65281).
+type RenegotiationInfoExtension struct {
+ // Renegotiation limits how many times the client will renegotiate: never,
+ // once, or freely. writeToUConn copies it into the connection config.
+ // The extension is still sent even when it is RenegotiateNever.
+ Renegotiation RenegotiationSupport // [UTLS] added for internal use only
+
+ // RenegotiatedConnection is not supported yet; the field below is
+ // commented out, so the extension always carries an empty
+ // "renegotiated_connection".
+ //
+ // If this is the initial handshake for a connection, then the
+ // "renegotiated_connection" field is of zero length in both the
+ // ClientHello and the ServerHello.
+ // RenegotiatedConnection []byte
+}
+
+// Len returns 5: the 4-byte extension header plus the 1-byte length of the
+// always-empty renegotiated_connection field.
+func (e *RenegotiationInfoExtension) Len() int {
+ return 5 // + len(e.RenegotiatedConnection)
+}
+
+func (e *RenegotiationInfoExtension) Read(b []byte) (int, error) {
+ if len(b) < e.Len() {
+ return 0, io.ErrShortBuffer
+ }
+
+ // dataLen := len(e.RenegotiatedConnection)
+ extBodyLen := 1 // + len(dataLen)
+
+ b[0] = byte(extensionRenegotiationInfo >> 8)
+ b[1] = byte(extensionRenegotiationInfo & 0xff)
+ b[2] = byte(extBodyLen >> 8)
+ b[3] = byte(extBodyLen)
+ // b[4] = byte(dataLen)
+ // copy(b[5:], e.RenegotiatedConnection)
+
+ return e.Len(), io.EOF
+}
+
+// UnmarshalJSON ignores its input and sets Renegotiation to
+// RenegotiateOnceAsClient.
+func (e *RenegotiationInfoExtension) UnmarshalJSON(_ []byte) error {
+ e.Renegotiation = RenegotiateOnceAsClient
+ return nil
+}
+
+// Write ignores its argument, sets Renegotiation to RenegotiateOnceAsClient,
+// and reports zero bytes consumed. Parsing of renegotiated_connection is kept
+// below, commented out, until RenegotiatedConnection is supported.
+func (e *RenegotiationInfoExtension) Write(_ []byte) (int, error) {
+ e.Renegotiation = RenegotiateOnceAsClient // non-empty or other modes are unsupported
+ // extData := cryptobyte.String(b)
+ // var renegotiatedConnection cryptobyte.String
+ // if !extData.ReadUint8LengthPrefixed(&renegotiatedConnection) || !extData.Empty() {
+ // return 0, errors.New("unable to read renegotiation info extension data")
+ // }
+ // e.RenegotiatedConnection = make([]byte, len(renegotiatedConnection))
+ // copy(e.RenegotiatedConnection, renegotiatedConnection)
+ return 0, nil
+}
+
+// writeToUConn copies Renegotiation into uc.config. For the client-side
+// renegotiation modes (once or freely) it also marks the ClientHello as
+// supporting secure renegotiation. RenegotiateNever and any other value leave
+// that flag untouched.
+func (e *RenegotiationInfoExtension) writeToUConn(uc *UConn) error {
+	uc.config.Renegotiation = e.Renegotiation
+	switch e.Renegotiation {
+	case RenegotiateOnceAsClient, RenegotiateFreelyAsClient:
+		uc.HandshakeState.Hello.SecureRenegotiationSupported = true
+	}
+	return nil
+}
+
/*
FAKE EXTENSIONS
*/
@@ -863,6 +1635,16 @@ func (e *FakeChannelIDExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+// Write ignores its argument and reports zero bytes consumed.
+func (e *FakeChannelIDExtension) Write(_ []byte) (int, error) {
+ return 0, nil
+}
+
+// UnmarshalJSON ignores its input; e is left unchanged.
+func (e *FakeChannelIDExtension) UnmarshalJSON(_ []byte) error {
+ return nil
+}
+
+// FakeRecordSizeLimitExtension implements record_size_limit (28) but with
+// no support: it only carries the limit value.
type FakeRecordSizeLimitExtension struct {
+ // Limit is the record_size_limit value. Write reads it as a uint16 body;
+ // UnmarshalJSON reads it from the "record_size_limit" key.
Limit uint16
}
@@ -891,37 +1673,30 @@ func (e *FakeRecordSizeLimitExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
-type DelegatedCredentialsExtension struct {
- AlgorithmsSignature []SignatureScheme
-}
-
-func (e *DelegatedCredentialsExtension) writeToUConn(uc *UConn) error {
- return nil
-}
-
-func (e *DelegatedCredentialsExtension) Len() int {
- return 6 + 2*len(e.AlgorithmsSignature)
+// Write parses a record_size_limit body, a single uint16, into e.Limit. It
+// reports all of b as consumed; bytes after the limit are not inspected.
+func (e *FakeRecordSizeLimitExtension) Write(b []byte) (int, error) {
+	body := cryptobyte.String(b)
+	if ok := body.ReadUint16(&e.Limit); !ok {
+		return 0, errors.New("unable to read record size limit extension data")
+	}
+	return len(b), nil
}
-func (e *DelegatedCredentialsExtension) Read(b []byte) (int, error) {
- if len(b) < e.Len() {
- return 0, io.ErrShortBuffer
+// UnmarshalJSON reads {"record_size_limit": N} into e.Limit. A missing key
+// sets e.Limit to 0.
+func (e *FakeRecordSizeLimitExtension) UnmarshalJSON(data []byte) error {
+ var limitAccepter struct {
+ Limit uint16 `json:"record_size_limit"`
}
- b[0] = byte(extensionDelegatedCredentials >> 8)
- b[1] = byte(extensionDelegatedCredentials)
- b[2] = byte((2 + 2*len(e.AlgorithmsSignature)) >> 8)
- b[3] = byte(2 + 2*len(e.AlgorithmsSignature))
- b[4] = byte((2 * len(e.AlgorithmsSignature)) >> 8)
- b[5] = byte(2 * len(e.AlgorithmsSignature))
- for i, sigAndHash := range e.AlgorithmsSignature {
- b[6+2*i] = byte(sigAndHash >> 8)
- b[7+2*i] = byte(sigAndHash)
+ if err := json.Unmarshal(data, &limitAccepter); err != nil {
+ return err
}
- return e.Len(), io.EOF
+
+ e.Limit = limitAccepter.Limit
+ return nil
}
-// https://tools.ietf.org/html/rfc8472#section-2
+// DelegatedCredentialsExtension is an alias of FakeDelegatedCredentialsExtension,
+// retained so the former standalone type name remains usable.
+type DelegatedCredentialsExtension = FakeDelegatedCredentialsExtension
+
+// https://tools.ietf.org/html/rfc8472#section-2
type FakeTokenBindingExtension struct {
MajorVersion, MinorVersion uint8
KeyParameters []uint8
@@ -954,6 +1729,48 @@ func (e *FakeTokenBindingExtension) Read(b []byte) (int, error) {
return e.Len(), io.EOF
}
+// Write parses a token_binding body (RFC 8472, Section 2): a major and a minor
+// version byte followed by a uint8 length-prefixed list of key parameters.
+// The version bytes are stored directly into e as they are read. On success
+// KeyParameters references the parameter bytes inside b, and all of b is
+// reported as consumed.
+func (e *FakeTokenBindingExtension) Write(b []byte) (int, error) {
+	body := cryptobyte.String(b)
+	var params cryptobyte.String
+	ok := body.ReadUint8(&e.MajorVersion) &&
+		body.ReadUint8(&e.MinorVersion) &&
+		body.ReadUint8LengthPrefixed(&params)
+	if !ok {
+		return 0, errors.New("unable to read token binding extension data")
+	}
+	e.KeyParameters = params
+	return len(b), nil
+}
+
+// UnmarshalJSON reads {"token_binding_version": {"major", "minor"},
+// "key_parameters_list": [names...]}. It sets the version fields, then appends
+// one ID per named key parameter: rsa2048_pkcs1.5 → 0, rsa2048_pss → 1,
+// ecdsap256 → 2. Any other name is an error.
+func (e *FakeTokenBindingExtension) UnmarshalJSON(data []byte) error {
+	var cfg struct {
+		Version struct {
+			Major uint8 `json:"major"`
+			Minor uint8 `json:"minor"`
+		} `json:"token_binding_version"`
+		KeyParameters []string `json:"key_parameters_list"`
+	}
+	if err := json.Unmarshal(data, &cfg); err != nil {
+		return err
+	}
+
+	e.MajorVersion = cfg.Version.Major
+	e.MinorVersion = cfg.Version.Minor
+	for _, name := range cfg.KeyParameters {
+		var id uint8
+		switch name {
+		case "rsa2048_pkcs1.5":
+			id = 0
+		case "rsa2048_pss":
+			id = 1
+		case "ecdsap256":
+			id = 2
+		default:
+			return fmt.Errorf("unknown token binding key parameter: %s", name)
+		}
+		e.KeyParameters = append(e.KeyParameters, id)
+	}
+	return nil
+}
+
// https://datatracker.ietf.org/doc/html/draft-ietf-tls-subcerts-15#section-4.1.1
type FakeDelegatedCredentialsExtension struct {
@@ -985,3 +1802,220 @@ func (e *FakeDelegatedCredentialsExtension) Read(b []byte) (int, error) {
}
return e.Len(), io.EOF
}
+
+// Write parses a delegated_credential body
+// (https://datatracker.ietf.org/doc/html/draft-ietf-tls-subcerts-15#section-4.1.1):
+// a non-empty uint16 length-prefixed list of uint16 signature schemes. On
+// success it replaces e.SupportedSignatureAlgorithms and reports all of b as
+// consumed.
+func (e *FakeDelegatedCredentialsExtension) Write(b []byte) (int, error) {
+	body := cryptobyte.String(b)
+	var list cryptobyte.String
+	if ok := body.ReadUint16LengthPrefixed(&list); !ok || list.Empty() {
+		return 0, errors.New("unable to read signature algorithms extension data")
+	}
+
+	schemes := make([]SignatureScheme, 0, len(list)/2)
+	for !list.Empty() {
+		var scheme uint16
+		if !list.ReadUint16(&scheme) {
+			return 0, errors.New("unable to read signature algorithms extension data")
+		}
+		schemes = append(schemes, SignatureScheme(scheme))
+	}
+	e.SupportedSignatureAlgorithms = schemes
+	return len(b), nil
+}
+
+// UnmarshalJSON reads {"supported_signature_algorithms": [names...]} and
+// appends each scheme to e.SupportedSignatureAlgorithms. "GREASE" maps to
+// GREASE_PLACEHOLDER. Other names are resolved via
+// godicttls.DictSignatureSchemeNameIndexed, and an unknown name is an error.
+// (Mirrors SignatureAlgorithmsExtension.UnmarshalJSON.)
+func (e *FakeDelegatedCredentialsExtension) UnmarshalJSON(data []byte) error {
+	var cfg struct {
+		Algorithms []string `json:"supported_signature_algorithms"`
+	}
+	if err := json.Unmarshal(data, &cfg); err != nil {
+		return err
+	}
+
+	for _, name := range cfg.Algorithms {
+		var scheme SignatureScheme
+		if name == "GREASE" {
+			scheme = GREASE_PLACEHOLDER
+		} else if id, ok := godicttls.DictSignatureSchemeNameIndexed[name]; ok {
+			scheme = SignatureScheme(id)
+		} else {
+			return fmt.Errorf("unknown delegated credentials signature scheme: %s", name)
+		}
+		e.SupportedSignatureAlgorithms = append(e.SupportedSignatureAlgorithms, scheme)
+	}
+	return nil
+}
+
+// FakePreSharedKeyExtension is an extension used to set the PSK extension in the
+// ClientHello.
+//
+// Unfortunately, even when the PSK extension is set, there will be no PSK-based
+// resumption since crypto/tls does not implement PSK.
+//
+// The JSON tags apply to encoding only; UnmarshalJSON decodes through its own
+// struct using the same keys.
+type FakePreSharedKeyExtension struct {
+ // PskIdentities are copied into the ClientHello by writeToUConn, but only
+ // when a cached session exists for the connection.
+ PskIdentities []PskIdentity `json:"identities"`
+ // PskBinders are copied into the ClientHello alongside PskIdentities.
+ PskBinders [][]byte `json:"binders"`
+}
+
+// writeToUConn copies PskIdentities and PskBinders into the ClientHello, but
+// only when uc has a ClientSessionCache holding a non-nil session for this
+// connection's cache key. Otherwise the hello is left untouched. It never
+// returns an error.
+func (e *FakePreSharedKeyExtension) writeToUConn(uc *UConn) error {
+	cache := uc.config.ClientSessionCache
+	if cache == nil {
+		return nil
+	}
+
+	cacheKey := clientSessionCacheKey(uc.conn.RemoteAddr(), uc.config)
+	session, found := cache.Get(cacheKey)
+	if !found || session == nil {
+		return nil
+	}
+
+	uc.HandshakeState.Hello.PskIdentities = e.PskIdentities
+	uc.HandshakeState.Hello.PskBinders = e.PskBinders
+	return nil
+}
+
+// pskIdentitiesLen returns the encoded size of the identities vector contents:
+// for each identity, a 2-byte label length, the label, and a 4-byte
+// obfuscated_ticket_age (RFC 8446, Section 4.2.11).
+func (e *FakePreSharedKeyExtension) pskIdentitiesLen() int {
+	n := 0
+	for _, identity := range e.PskIdentities {
+		n += 2 + len(identity.Label) + 4
+	}
+	return n
+}
+
+// pskBindersLen returns the encoded size of the binders vector contents. Each
+// PskBinderEntry is an opaque<32..255>, so it carries a 1-byte length prefix.
+func (e *FakePreSharedKeyExtension) pskBindersLen() int {
+	n := 0
+	for _, binder := range e.PskBinders {
+		n += 1 + len(binder)
+	}
+	return n
+}
+
+// Len returns the number of bytes Read writes. That is the 4-byte extension
+// header, the 2-byte identities length plus identities, and the 2-byte
+// binders length plus binders. Each binder includes its 1-byte length prefix,
+// matching the layout Write parses.
+func (e *FakePreSharedKeyExtension) Len() int {
+	return 4 + 2 + e.pskIdentitiesLen() + 2 + e.pskBindersLen()
+}
+
+// Read serializes the whole extension (type, length and OfferedPsks body)
+// into b and returns Len() with io.EOF. It returns io.ErrShortBuffer if b is
+// shorter than Len(). It returns an error if a binder does not fit its 1-byte
+// length prefix.
+func (e *FakePreSharedKeyExtension) Read(b []byte) (int, error) {
+	fullLen := e.Len()
+	if len(b) < fullLen {
+		return 0, io.ErrShortBuffer
+	}
+	for _, binder := range e.PskBinders {
+		if len(binder) > 0xff {
+			return 0, errors.New("tls: PSK binder too long")
+		}
+	}
+
+	extBodyLen := fullLen - 4
+	b[0] = byte(extensionPreSharedKey >> 8)
+	b[1] = byte(extensionPreSharedKey)
+	b[2] = byte(extBodyLen >> 8)
+	b[3] = byte(extBodyLen)
+
+	// identities<7..2^16-1>
+	identitiesLen := e.pskIdentitiesLen()
+	b[4] = byte(identitiesLen >> 8)
+	b[5] = byte(identitiesLen)
+	offset := 6
+	for _, identity := range e.PskIdentities {
+		b[offset] = byte(len(identity.Label) >> 8)
+		b[offset+1] = byte(len(identity.Label))
+		offset += 2
+		offset += copy(b[offset:], identity.Label)
+		b[offset] = byte(identity.ObfuscatedTicketAge >> 24)
+		b[offset+1] = byte(identity.ObfuscatedTicketAge >> 16)
+		b[offset+2] = byte(identity.ObfuscatedTicketAge >> 8)
+		b[offset+3] = byte(identity.ObfuscatedTicketAge)
+		offset += 4
+	}
+
+	// binders<33..2^16-1>: every PskBinderEntry is prefixed by its 1-byte length.
+	bindersLen := e.pskBindersLen()
+	b[offset] = byte(bindersLen >> 8)
+	b[offset+1] = byte(bindersLen)
+	offset += 2
+	for _, binder := range e.PskBinders {
+		b[offset] = byte(len(binder))
+		offset++
+		offset += copy(b[offset:], binder)
+	}
+
+	return fullLen, io.EOF
+}
+
+// Write parses a pre_shared_key extension body in the OfferedPsks layout of
+// RFC 8446, Section 4.2.11:
+//
+//	identities<7..2^16-1>: per entry, uint16-prefixed label + uint32 obfuscated_ticket_age
+//	binders<33..2^16-1>:   per entry, uint8-prefixed binder
+//
+// The parsed identities and binders are appended to e.PskIdentities and
+// e.PskBinders only if the whole body parses; on error e is left unchanged.
+// Parsed values reference b rather than copying it. On success all of b is
+// reported as consumed.
+func (e *FakePreSharedKeyExtension) Write(b []byte) (n int, err error) {
+	fullLen := len(b)
+	s := cryptobyte.String(b)
+
+	// Each vector is read through its own length prefix. The inner loops are
+	// then bounded by that vector, so malformed lengths cannot make the
+	// parser run past it.
+	var identitiesRaw, bindersRaw cryptobyte.String
+	if !s.ReadUint16LengthPrefixed(&identitiesRaw) || !s.ReadUint16LengthPrefixed(&bindersRaw) {
+		return 0, errors.New("tls: invalid PSK extension")
+	}
+
+	var identities []PskIdentity
+	for !identitiesRaw.Empty() {
+		var label cryptobyte.String
+		var obfuscatedTicketAge uint32
+		if !identitiesRaw.ReadUint16LengthPrefixed(&label) || !identitiesRaw.ReadUint32(&obfuscatedTicketAge) {
+			return 0, errors.New("tls: invalid PSK extension")
+		}
+		identities = append(identities, PskIdentity{
+			Label:               []byte(label),
+			ObfuscatedTicketAge: obfuscatedTicketAge,
+		})
+	}
+
+	var binders [][]byte
+	for !bindersRaw.Empty() {
+		var binder cryptobyte.String
+		if !bindersRaw.ReadUint8LengthPrefixed(&binder) {
+			return 0, errors.New("tls: invalid PSK extension")
+		}
+		binders = append(binders, []byte(binder))
+	}
+
+	e.PskIdentities = append(e.PskIdentities, identities...)
+	e.PskBinders = append(e.PskBinders, binders...)
+	return fullLen, nil
+}
+
+// UnmarshalJSON reads {"identities": [...], "binders": [...]} and replaces
+// both PskIdentities and PskBinders. A missing key sets that field to nil.
+// On a decode error e is left unchanged.
+func (e *FakePreSharedKeyExtension) UnmarshalJSON(data []byte) error {
+	var parsed struct {
+		PskIdentities []PskIdentity `json:"identities"`
+		PskBinders    [][]byte      `json:"binders"`
+	}
+	if err := json.Unmarshal(data, &parsed); err != nil {
+		return err
+	}
+	e.PskIdentities, e.PskBinders = parsed.PskIdentities, parsed.PskBinders
+	return nil
+}