diff --git a/cert/cert.go b/cert/cert.go index 4246571..38a2528 100644 --- a/cert/cert.go +++ b/cert/cert.go @@ -113,10 +113,10 @@ func (cc *CachedCertificate) String() string { return cc.Certificate.String() } -// RecombineAndValidate will attempt to unmarshal a certificate received in a handshake. +// Recombine will attempt to unmarshal a certificate received in a handshake. // Handshakes save space by placing the peers public key in a different part of the packet, we have to // reassemble the actual certificate structure with that in mind. -func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve, caPool *CAPool) (*CachedCertificate, error) { +func Recombine(v Version, rawCertBytes, publicKey []byte, curve Curve) (Certificate, error) { if publicKey == nil { return nil, ErrNoPeerStaticKey } @@ -125,29 +125,15 @@ func RecombineAndValidate(v Version, rawCertBytes, publicKey []byte, curve Curve return nil, ErrNoPayload } - c, err := unmarshalCertificateFromHandshake(v, rawCertBytes, publicKey, curve) - if err != nil { - return nil, fmt.Errorf("error unmarshaling cert: %w", err) - } - - cc, err := caPool.VerifyCertificate(time.Now(), c) - if err != nil { - return nil, fmt.Errorf("certificate validation failed: %w", err) - } - - return cc, nil -} - -func unmarshalCertificateFromHandshake(v Version, b []byte, publicKey []byte, curve Curve) (Certificate, error) { var c Certificate var err error switch v { // Implementations must ensure the result is a valid cert! case VersionPre1, Version1: - c, err = unmarshalCertificateV1(b, publicKey) + c, err = unmarshalCertificateV1(rawCertBytes, publicKey) case Version2: - c, err = unmarshalCertificateV2(b, publicKey, curve) + c, err = unmarshalCertificateV2(rawCertBytes, publicKey, curve) default: //TODO: CERT-V2 make a static var return nil, fmt.Errorf("unknown certificate version %d", v) diff --git a/handshake_ix.go b/handshake_ix.go index 9b8b3e9..daea526 100644 --- a/handshake_ix.go +++ b/handshake_ix.go @@ -132,13 +132,28 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet return } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) + rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - e := f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) + f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("Handshake did not contain a certificate") + return + } - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) + remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) + if err != nil { + fp, err := rc.Fingerprint() + if err != nil { + fp = "" + } + + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + WithField("certVpnNetworks", rc.Networks()). + WithField("certFingerprint", fp) + + if f.l.Level >= logrus.DebugLevel { + e = e.WithField("cert", rc) } e.Info("Invalid certificate from host") @@ -160,14 +175,10 @@ func ixHandshakeStage1(f *Interface, addr netip.AddrPort, via *ViaSender, packet } if len(remoteCert.Certificate.Networks()) == 0 { - e := f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 1, "style": "ix_psk0"}) - - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) - } - - e.Info("Invalid vpn ip from host") + f.l.WithError(err).WithField("udpAddr", addr). + WithField("cert", remoteCert). + WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). + Info("No networks in certificate") return } @@ -487,30 +498,42 @@ func ixHandshakeStage2(f *Interface, addr netip.AddrPort, via *ViaSender, hh *Ha return true } - remoteCert, err := cert.RecombineAndValidate(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve(), f.pki.GetCAPool()) + rc, err := cert.Recombine(cert.Version(hs.Details.CertVersion), hs.Details.Cert, ci.H.PeerStatic(), ci.Curve()) if err != nil { - e := f.l.WithError(err).WithField("vpnAddrs", hostinfo.vpnAddrs).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) + f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("Handshake did not contain a certificate") + return true + } - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) + remoteCert, err := f.pki.GetCAPool().VerifyCertificate(time.Now(), rc) + if err != nil { + fp, err := rc.Fingerprint() + if err != nil { + fp = "" } - e.Error("Invalid certificate from host") + e := f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + WithField("certFingerprint", fp). + WithField("certVpnNetworks", rc.Networks()) - // The handshake state machine is complete, if things break now there is no chance to recover. Tear down and start again + if f.l.Level >= logrus.DebugLevel { + e = e.WithField("cert", rc) + } + + e.Info("Invalid certificate from host") return true } if len(remoteCert.Certificate.Networks()) == 0 { - e := f.l.WithError(err).WithField("udpAddr", addr). - WithField("handshake", m{"stage": 2, "style": "ix_psk0"}) - - if f.l.Level > logrus.DebugLevel { - e = e.WithField("cert", remoteCert) - } - - e.Info("Empty networks from host") + f.l.WithError(err).WithField("udpAddr", addr). + WithField("vpnAddrs", hostinfo.vpnAddrs). + WithField("cert", remoteCert). + WithField("handshake", m{"stage": 2, "style": "ix_psk0"}). + Info("No networks in certificate") return true } diff --git a/handshake_manager.go b/handshake_manager.go index 6d3ed12..6f95402 100644 --- a/handshake_manager.go +++ b/handshake_manager.go @@ -257,7 +257,7 @@ func (hm *HandshakeManager) handleOutbound(vpnIp netip.Addr, lighthouseTriggered WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}). Info("Handshake message sent") - } else if hm.l.IsLevelEnabled(logrus.DebugLevel) { + } else if hm.l.Level >= logrus.DebugLevel { hostinfo.logger(hm.l).WithField("udpAddrs", sentTo). WithField("initiatorIndex", hostinfo.localIndexId). WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).