Compare commits
4 commits
Author | SHA1 | Date | |
---|---|---|---|
|
6d5299715e | ||
|
ba8646fa83 | ||
|
c1ed78ffc7 | ||
|
cf3b7ec2fa |
11 changed files with 581 additions and 53 deletions
|
@ -27,7 +27,7 @@ type ConnectionState struct {
|
|||
ready bool
|
||||
}
|
||||
|
||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern noise.HandshakePattern, psk []byte, pskStage int) *ConnectionState {
|
||||
func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, p []byte) (*ConnectionState, error) {
|
||||
cs := noise.NewCipherSuite(noise.DH25519, noise.CipherAESGCM, noise.HashSHA256)
|
||||
if f.cipher == "chachapoly" {
|
||||
cs = noise.NewCipherSuite(noise.DH25519, noise.CipherChaChaPoly, noise.HashSHA256)
|
||||
|
@ -43,14 +43,15 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
|||
hs, err := noise.NewHandshakeState(noise.Config{
|
||||
CipherSuite: cs,
|
||||
Random: rand.Reader,
|
||||
Pattern: pattern,
|
||||
Pattern: noise.HandshakeIX,
|
||||
Initiator: initiator,
|
||||
StaticKeypair: static,
|
||||
PresharedKey: psk,
|
||||
PresharedKeyPlacement: pskStage,
|
||||
PresharedKey: p,
|
||||
PresharedKeyPlacement: 0,
|
||||
})
|
||||
|
||||
if err != nil {
|
||||
return nil
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// The queue and ready params prevent a counter race that would happen when
|
||||
|
@ -63,7 +64,7 @@ func (f *Interface) newConnectionState(l *logrus.Logger, initiator bool, pattern
|
|||
certState: curCertState,
|
||||
}
|
||||
|
||||
return ci
|
||||
return ci, nil
|
||||
}
|
||||
|
||||
func (cs *ConnectionState) MarshalJSON() ([]byte, error) {
|
||||
|
|
|
@ -18,8 +18,8 @@ import (
|
|||
|
||||
func TestGoodHandshake(t *testing.T) {
|
||||
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1})
|
||||
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
|
||||
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
|
||||
// Put their info in our lighthouse
|
||||
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
|
||||
|
@ -70,9 +70,9 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||
// The IPs here are chosen on purpose:
|
||||
// The current remote handling will sort by preference, public, and then lexically.
|
||||
// So we need them to have a higher address than evil (we could apply a preference though)
|
||||
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100})
|
||||
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99})
|
||||
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2})
|
||||
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 100}, nil)
|
||||
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 99}, nil)
|
||||
evilControl, evilVpnIp, evilUdpAddr := newSimpleServer(ca, caKey, "evil", net.IP{10, 0, 0, 2}, nil)
|
||||
|
||||
// Add their real udp addr, which should be tried after evil.
|
||||
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
|
||||
|
@ -130,8 +130,8 @@ func TestWrongResponderHandshake(t *testing.T) {
|
|||
|
||||
func Test_Case1_Stage1Race(t *testing.T) {
|
||||
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1})
|
||||
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2})
|
||||
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me ", net.IP{10, 0, 0, 1}, nil)
|
||||
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, nil)
|
||||
|
||||
// Put their info in our lighthouse and vice versa
|
||||
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
|
||||
|
@ -183,3 +183,151 @@ func Test_Case1_Stage1Race(t *testing.T) {
|
|||
}
|
||||
|
||||
//TODO: add a test with many lies
|
||||
|
||||
func TestPSK(t *testing.T) {
|
||||
tests := []struct {
|
||||
name string
|
||||
myPskMode nebula.PskMode
|
||||
theirPskMode nebula.PskMode
|
||||
}{
|
||||
// None and transitional-accepting both ways
|
||||
{
|
||||
name: "none to transitional-accepting",
|
||||
myPskMode: nebula.PskNone,
|
||||
theirPskMode: nebula.PskTransitionalAccepting,
|
||||
},
|
||||
{
|
||||
name: "transitional-accepting to none",
|
||||
myPskMode: nebula.PskTransitionalAccepting,
|
||||
theirPskMode: nebula.PskNone,
|
||||
},
|
||||
|
||||
// All transitional-accepting
|
||||
{
|
||||
name: "both transitional-accepting",
|
||||
myPskMode: nebula.PskTransitionalAccepting,
|
||||
theirPskMode: nebula.PskTransitionalAccepting,
|
||||
},
|
||||
|
||||
// transitional-accepting and transitional-sending both ways
|
||||
{
|
||||
name: "transitional-accepting to transitional-sending",
|
||||
myPskMode: nebula.PskTransitionalAccepting,
|
||||
theirPskMode: nebula.PskTransitionalSending,
|
||||
},
|
||||
{
|
||||
name: "transitional-sending to transitional-accepting",
|
||||
myPskMode: nebula.PskTransitionalSending,
|
||||
theirPskMode: nebula.PskTransitionalAccepting,
|
||||
},
|
||||
|
||||
// All transitional-sending
|
||||
{
|
||||
name: "transitional-sending to transitional-sending",
|
||||
myPskMode: nebula.PskTransitionalSending,
|
||||
theirPskMode: nebula.PskTransitionalSending,
|
||||
},
|
||||
|
||||
// enforced and transitional-sending both ways
|
||||
{
|
||||
name: "enforced to transitional-sending",
|
||||
myPskMode: nebula.PskEnforced,
|
||||
theirPskMode: nebula.PskTransitionalSending,
|
||||
},
|
||||
{
|
||||
name: "transitional-sending to enforced",
|
||||
myPskMode: nebula.PskTransitionalSending,
|
||||
theirPskMode: nebula.PskEnforced,
|
||||
},
|
||||
|
||||
// All enforced
|
||||
{
|
||||
name: "both enforced",
|
||||
myPskMode: nebula.PskEnforced,
|
||||
theirPskMode: nebula.PskEnforced,
|
||||
},
|
||||
|
||||
// Enforced can technically handshake with a traditional-accepting but it is bad to be in this state
|
||||
{
|
||||
name: "enforced to traditional-accepting",
|
||||
myPskMode: nebula.PskEnforced,
|
||||
theirPskMode: nebula.PskTransitionalAccepting,
|
||||
},
|
||||
}
|
||||
|
||||
for _, test := range tests {
|
||||
t.Run(test.name, func(t *testing.T) {
|
||||
var myPskSettings, theirPskSettings *m
|
||||
|
||||
switch test.myPskMode {
|
||||
case nebula.PskNone:
|
||||
myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "none"}}}
|
||||
case nebula.PskTransitionalAccepting:
|
||||
myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional-accepting", "keys": []string{"this is a key"}}}}
|
||||
case nebula.PskTransitionalSending:
|
||||
myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional-sending", "keys": []string{"this is a key"}}}}
|
||||
case nebula.PskEnforced:
|
||||
myPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key"}}}}
|
||||
}
|
||||
|
||||
switch test.theirPskMode {
|
||||
case nebula.PskNone:
|
||||
theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "none"}}}
|
||||
case nebula.PskTransitionalAccepting:
|
||||
theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional-accepting", "keys": []string{"this is a key"}}}}
|
||||
case nebula.PskTransitionalSending:
|
||||
theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "transitional-sending", "keys": []string{"this is a key"}}}}
|
||||
case nebula.PskEnforced:
|
||||
theirPskSettings = &m{"handshakes": &m{"psk": &m{"mode": "enforced", "keys": []string{"this is a key"}}}}
|
||||
}
|
||||
|
||||
ca, _, caKey, _ := newTestCaCert(time.Now(), time.Now().Add(10*time.Minute), []*net.IPNet{}, []*net.IPNet{}, []string{})
|
||||
myControl, myVpnIp, myUdpAddr := newSimpleServer(ca, caKey, "me", net.IP{10, 0, 0, 1}, myPskSettings)
|
||||
theirControl, theirVpnIp, theirUdpAddr := newSimpleServer(ca, caKey, "them", net.IP{10, 0, 0, 2}, theirPskSettings)
|
||||
|
||||
myControl.InjectLightHouseAddr(theirVpnIp, theirUdpAddr)
|
||||
r := router.NewR(myControl, theirControl)
|
||||
|
||||
// Start the servers
|
||||
myControl.Start()
|
||||
theirControl.Start()
|
||||
|
||||
t.Log("Route until we see our cached packet flow")
|
||||
myControl.InjectTunUDPPacket(theirVpnIp, 80, 80, []byte("Hi from me"))
|
||||
r.RouteForAllExitFunc(func(p *udp.Packet, c *nebula.Control) router.ExitType {
|
||||
h := &header.H{}
|
||||
err := h.Parse(p.Data)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
// If this is the stage 1 handshake packet and I am configured to send with a psk, my cert name should
|
||||
// not appear. It would likely be more obvious to unmarshal the payload and check but this works fine for now
|
||||
if test.myPskMode == nebula.PskEnforced || test.myPskMode == nebula.PskTransitionalSending {
|
||||
if h.Type == 0 && h.MessageCounter == 1 {
|
||||
assert.NotContains(t, string(p.Data), "test me")
|
||||
}
|
||||
}
|
||||
|
||||
if p.ToIp.Equal(theirUdpAddr.IP) && p.ToPort == uint16(theirUdpAddr.Port) && h.Type == 1 {
|
||||
return router.RouteAndExit
|
||||
}
|
||||
|
||||
return router.KeepRouting
|
||||
})
|
||||
|
||||
t.Log("My cached packet should be received by them")
|
||||
myCachedPacket := theirControl.GetFromTun(true)
|
||||
assertUdpPacket(t, []byte("Hi from me"), myCachedPacket, myVpnIp, theirVpnIp, 80, 80)
|
||||
|
||||
t.Log("Test the tunnel with them")
|
||||
assertHostInfoPair(t, myUdpAddr, theirUdpAddr, myVpnIp, theirVpnIp, myControl, theirControl)
|
||||
assertTunnel(t, myVpnIp, theirVpnIp, myControl, theirControl, r)
|
||||
|
||||
myControl.Stop()
|
||||
theirControl.Stop()
|
||||
//TODO: assert hostmaps
|
||||
})
|
||||
}
|
||||
|
||||
}
|
||||
|
|
|
@ -15,6 +15,7 @@ import (
|
|||
|
||||
"github.com/google/gopacket"
|
||||
"github.com/google/gopacket/layers"
|
||||
"github.com/imdario/mergo"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula"
|
||||
"github.com/slackhq/nebula/cert"
|
||||
|
@ -30,7 +31,7 @@ import (
|
|||
type m map[string]interface{}
|
||||
|
||||
// newSimpleServer creates a nebula instance with many assumptions
|
||||
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP) (*nebula.Control, net.IP, *net.UDPAddr) {
|
||||
func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, udpIp net.IP, customConfig *m) (*nebula.Control, net.IP, *net.UDPAddr) {
|
||||
l := NewTestLogger()
|
||||
|
||||
vpnIpNet := &net.IPNet{IP: make([]byte, len(udpIp)), Mask: net.IPMask{255, 255, 255, 0}}
|
||||
|
@ -40,7 +41,7 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
|
|||
IP: udpIp,
|
||||
Port: 4242,
|
||||
}
|
||||
_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
|
||||
_, _, myPrivKey, myPEM := newTestCert(caCrt, caKey, "test "+name, time.Now(), time.Now().Add(5*time.Minute), vpnIpNet, nil, []string{})
|
||||
|
||||
caB, err := caCrt.MarshalToPEM()
|
||||
if err != nil {
|
||||
|
@ -86,6 +87,24 @@ func newSimpleServer(caCrt *cert.NebulaCertificate, caKey []byte, name string, u
|
|||
c := config.NewC(l)
|
||||
c.LoadString(string(cb))
|
||||
|
||||
if customConfig != nil {
|
||||
ccb, err := yaml.Marshal(customConfig)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
ccm := map[interface{}]interface{}{}
|
||||
err = yaml.Unmarshal(ccb, &ccm)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
|
||||
err = mergo.Merge(&c.Settings, ccm, mergo.WithAppendSlice)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
control, err := nebula.Main(c, false, "e2e-test", l, nil)
|
||||
|
||||
if err != nil {
|
||||
|
|
|
@ -215,17 +215,45 @@ logging:
|
|||
# e.g.: `lighthouse.rx.HostQuery`
|
||||
#lighthouse_metrics: false
|
||||
|
||||
# Handshake Manager Settings
|
||||
#handshakes:
|
||||
# Handshake Manger Settings
|
||||
handshakes:
|
||||
# Handshakes are sent to all known addresses at each interval with a linear backoff,
|
||||
# Wait try_interval after the 1st attempt, 2 * try_interval after the 2nd, etc, until the handshake is older than timeout
|
||||
# A 100ms interval with the default 10 retries will give a handshake 5.5 seconds to resolve before timing out
|
||||
#try_interval: 100ms
|
||||
#retries: 20
|
||||
|
||||
# trigger_buffer is the size of the buffer channel for quickly sending handshakes
|
||||
# after receiving the response for lighthouse queries
|
||||
#trigger_buffer: 64
|
||||
|
||||
# psk can be used to mask the contents of handshakes and makes handshaking with unintended recipients more difficult
|
||||
# all settings respond to a reload
|
||||
psk:
|
||||
# mode defines how the pre shared keys can be used in a handshake
|
||||
# `none` (the default) does not send or receive using a psk. Ideally `enforced` is used
|
||||
# `transitional-accepting` will send handshakes without using a psk and can receive handshakes using a psk we know about
|
||||
# `transitional-sending` will send handshakes using a psk but will still accept handshakes without them
|
||||
# `enforced` enforces the use of a psk for all tunnels. Any node not also using `enforced` or `transitional-sending` can not handshake with us
|
||||
#
|
||||
# When moving from `none` to `enforced` you will want to change every node in the mesh to `transitional-accepting` and reload
|
||||
# then move every node to `transitional-sending` then reload, and finally `enforced` then reload. This allows you to
|
||||
# avoid stopping the world to use psk. You must ensure at `transitional-accepting` that all nodes have the same psks.
|
||||
#mode: none
|
||||
|
||||
# In `transitional-accepting`, `transitional-sending` and `enforced` modes, the keys provided here are sent through
|
||||
# hkdf with the intended recipients ip used in the info section. This helps guard against handshaking with the wrong
|
||||
# host if your static_host_map or lighthouse(s) has incorrect information.
|
||||
#
|
||||
# Setting keys if mode is `none` has no effect.
|
||||
#
|
||||
# Only the first key is used for outbound handshakes but all keys provided will be tried in the order specified, on
|
||||
# incoming handshakes. This is to allow for psk rotation.
|
||||
#keys:
|
||||
# - shared secret string, this one is used in all outbound handshakes
|
||||
# - this is a fallback key, received handshakes can use this
|
||||
# - another fallback, received handshakes can use this one too
|
||||
# - "\x68\x65\x6c\x6c\x6f\x20\x66\x72\x69\x65\x6e\x64\x73" # for raw bytes if you desire
|
||||
|
||||
# Nebula security group configuration
|
||||
firewall:
|
||||
|
|
|
@ -4,7 +4,6 @@ import (
|
|||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/golang/protobuf/proto"
|
||||
"github.com/slackhq/nebula/header"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
|
@ -71,28 +70,51 @@ func ixHandshakeStage0(f *Interface, vpnIp iputil.VpnIp, hostinfo *HostInfo) {
|
|||
}
|
||||
|
||||
func ixHandshakeStage1(f *Interface, addr *udp.Addr, packet []byte, h *header.H) {
|
||||
ci := f.newConnectionState(f.l, false, noise.HandshakeIX, []byte{}, 0)
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
msg, _, _, err := ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed to call noise.ReadMessage")
|
||||
return
|
||||
}
|
||||
var (
|
||||
err error
|
||||
ci *ConnectionState
|
||||
msg []byte
|
||||
)
|
||||
|
||||
hs := &NebulaHandshake{}
|
||||
err = proto.Unmarshal(msg, hs)
|
||||
/*
|
||||
l.Debugln("GOT INDEX: ", hs.Details.InitiatorIndex)
|
||||
*/
|
||||
if err != nil || hs.Details == nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).Error("Failed unmarshal handshake message")
|
||||
|
||||
// Handle multiple possible psk options, ensure the protobuf comes out clean too
|
||||
for _, p := range f.psk.Cache {
|
||||
//TODO: benchmark generation time of makePsk
|
||||
ci, err = f.newConnectionState(f.l, false, p)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).Error("Failed to get a new connection state")
|
||||
continue
|
||||
}
|
||||
|
||||
msg, _, _, err = ci.H.ReadMessage(nil, packet[header.Len:])
|
||||
if err != nil {
|
||||
// Calls to ReadMessage with an incorrect psk should fail, try the next one if we have one
|
||||
continue
|
||||
}
|
||||
|
||||
// Sometimes ReadMessage returns fine with a nil psk even if the handshake is using a psk, ensure our protobuf
|
||||
// comes out clean as well
|
||||
err = proto.Unmarshal(msg, hs)
|
||||
if err == nil {
|
||||
// There was no error, we can continue with this handshake
|
||||
break
|
||||
}
|
||||
|
||||
// The unmarshal failed, try the next psk if we have one
|
||||
}
|
||||
|
||||
// We finished with an error, log it and get out
|
||||
if err != nil {
|
||||
// We aren't logging the error here because we can't be sure of the failure when using psk
|
||||
f.l.WithField("udpAddr", addr).WithField("handshake", m{"stage": 1, "style": "ix_psk0"}).
|
||||
Error("Was unable to decrypt the handshake")
|
||||
return
|
||||
}
|
||||
|
||||
// Mark packet 1 as seen so it doesn't show up as missed
|
||||
ci.window.Update(f.l, 1)
|
||||
|
||||
remoteCert, err := RecombineCertAndValidate(ci.H, hs.Details.Cert, f.caPool)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("udpAddr", addr).
|
||||
|
|
24
inside.go
24
inside.go
|
@ -3,7 +3,6 @@ package nebula
|
|||
import (
|
||||
"sync/atomic"
|
||||
|
||||
"github.com/flynn/noise"
|
||||
"github.com/sirupsen/logrus"
|
||||
"github.com/slackhq/nebula/firewall"
|
||||
"github.com/slackhq/nebula/header"
|
||||
|
@ -79,7 +78,6 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
|||
}
|
||||
hostinfo, err := f.hostMap.PromoteBestQueryVpnIp(vpnIp, f)
|
||||
|
||||
//if err != nil || hostinfo.ConnectionState == nil {
|
||||
if err != nil {
|
||||
hostinfo, err = f.handshakeManager.pendingHostMap.QueryVpnIp(vpnIp)
|
||||
if err != nil {
|
||||
|
@ -102,21 +100,27 @@ func (f *Interface) getOrHandshake(vpnIp iputil.VpnIp) *HostInfo {
|
|||
return hostinfo
|
||||
}
|
||||
|
||||
// Create a connection state if we don't have one yet
|
||||
if ci == nil {
|
||||
// if we don't have a connection state, then send a handshake initiation
|
||||
ci = f.newConnectionState(f.l, true, noise.HandshakeIX, []byte{}, 0)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//ci = f.newConnectionState(true, noise.HandshakeXX, []byte{}, 0)
|
||||
// Generate a PSK based on our config, this may be nil
|
||||
p, err := f.psk.MakeFor(vpnIp)
|
||||
if err != nil {
|
||||
//TODO: This isn't fatal specifically but it's pretty bad
|
||||
f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a PSK KDF")
|
||||
return hostinfo
|
||||
}
|
||||
|
||||
ci, err = f.newConnectionState(f.l, true, p)
|
||||
if err != nil {
|
||||
f.l.WithError(err).WithField("vpnIp", vpnIp).Error("Failed to get a connection state")
|
||||
return hostinfo
|
||||
}
|
||||
hostinfo.ConnectionState = ci
|
||||
} else if ci.eKey == nil {
|
||||
// if we don't have any state at all, create it
|
||||
}
|
||||
|
||||
// If we have already created the handshake packet, we don't want to call the function at all.
|
||||
if !hostinfo.HandshakeReady {
|
||||
ixHandshakeStage0(f, vpnIp, hostinfo)
|
||||
// FIXME: Maybe make XX selectable, but probably not since psk makes it nearly pointless for us.
|
||||
//xx_handshakeStage0(f, ip, hostinfo)
|
||||
|
||||
// If this is a static host, we don't need to wait for the HostQueryReply
|
||||
// We can trigger the handshake right now
|
||||
|
|
25
interface.go
25
interface.go
|
@ -7,7 +7,9 @@ import (
|
|||
"net"
|
||||
"os"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
"unsafe"
|
||||
|
||||
"github.com/rcrowley/go-metrics"
|
||||
"github.com/sirupsen/logrus"
|
||||
|
@ -48,6 +50,7 @@ type InterfaceConfig struct {
|
|||
version string
|
||||
caPool *cert.NebulaCAPool
|
||||
disconnectInvalid bool
|
||||
psk *Psk
|
||||
|
||||
ConntrackCacheTimeout time.Duration
|
||||
l *logrus.Logger
|
||||
|
@ -78,9 +81,9 @@ type Interface struct {
|
|||
version string
|
||||
|
||||
conntrackCacheTimeout time.Duration
|
||||
|
||||
writers []*udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
psk *Psk
|
||||
writers []*udp.Conn
|
||||
readers []io.ReadWriteCloser
|
||||
|
||||
metricHandshakes metrics.Histogram
|
||||
messageMetrics *MessageMetrics
|
||||
|
@ -104,6 +107,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
}
|
||||
|
||||
myVpnIp := iputil.Ip2VpnIp(c.certState.certificate.Details.Ips[0].IP)
|
||||
|
||||
ifce := &Interface{
|
||||
hostMap: c.HostMap,
|
||||
outside: c.Outside,
|
||||
|
@ -124,6 +128,7 @@ func NewInterface(ctx context.Context, c *InterfaceConfig) (*Interface, error) {
|
|||
readers: make([]io.ReadWriteCloser, c.routines),
|
||||
caPool: c.caPool,
|
||||
disconnectInvalid: c.disconnectInvalid,
|
||||
psk: c.psk,
|
||||
myVpnIp: myVpnIp,
|
||||
|
||||
conntrackCacheTimeout: c.ConntrackCacheTimeout,
|
||||
|
@ -234,6 +239,7 @@ func (f *Interface) RegisterConfigChangeCallbacks(c *config.C) {
|
|||
for _, udpConn := range f.writers {
|
||||
c.RegisterReloadCallback(udpConn.ReloadConfig)
|
||||
}
|
||||
c.RegisterReloadCallback(f.reloadPSKs)
|
||||
}
|
||||
|
||||
func (f *Interface) reloadCA(c *config.C) {
|
||||
|
@ -308,6 +314,19 @@ func (f *Interface) reloadFirewall(c *config.C) {
|
|||
Info("New firewall has been installed")
|
||||
}
|
||||
|
||||
func (f *Interface) reloadPSKs(c *config.C) {
|
||||
psk, err := NewPskFromConfig(c, f.myVpnIp)
|
||||
if err != nil {
|
||||
f.l.WithError(err).Error("Error while reloading PSKs")
|
||||
return
|
||||
}
|
||||
|
||||
atomic.StorePointer((*unsafe.Pointer)(unsafe.Pointer(&f.psk)), unsafe.Pointer(psk))
|
||||
|
||||
f.l.WithField("pskMode", psk.mode).WithField("keysLen", len(psk.Cache)).
|
||||
Info("New psks are in use")
|
||||
}
|
||||
|
||||
func (f *Interface) emitStats(ctx context.Context, i time.Duration) {
|
||||
ticker := time.NewTicker(i)
|
||||
defer ticker.Stop()
|
||||
|
|
10
main.go
10
main.go
|
@ -95,6 +95,11 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
}
|
||||
}
|
||||
|
||||
psk, err := NewPskFromConfig(c, iputil.Ip2VpnIp(tunCidr.IP))
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Failed to create psk", nil, err)
|
||||
}
|
||||
|
||||
////////////////////////////////////////////////////////////////////////////////////////////////////////////////////
|
||||
// All non system modifying configuration consumption should live above this line
|
||||
// tun config, listeners, anything modifying the computer should be below
|
||||
|
@ -356,10 +361,6 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
handshakeManager := NewHandshakeManager(l, tunCidr, preferredRanges, hostMap, lightHouse, udpConns[0], handshakeConfig)
|
||||
lightHouse.handshakeTrigger = handshakeManager.trigger
|
||||
|
||||
//TODO: These will be reused for psk
|
||||
//handshakeMACKey := config.GetString("handshake_mac.key", "")
|
||||
//handshakeAcceptedMACKeys := config.GetStringSlice("handshake_mac.accepted_keys", []string{})
|
||||
|
||||
serveDns := false
|
||||
if c.GetBool("lighthouse.serve_dns", false) {
|
||||
if c.GetBool("lighthouse.am_lighthouse", false) {
|
||||
|
@ -390,6 +391,7 @@ func Main(c *config.C, configTest bool, buildVersion string, logger *logrus.Logg
|
|||
version: buildVersion,
|
||||
caPool: caPool,
|
||||
disconnectInvalid: c.GetBool("pki.disconnect_invalid", false),
|
||||
psk: psk,
|
||||
|
||||
ConntrackCacheTimeout: conntrackCacheTimeout,
|
||||
l: l,
|
||||
|
|
1
noise.go
1
noise.go
|
@ -22,7 +22,6 @@ type NebulaCipherState struct {
|
|||
|
||||
func NewNebulaCipherState(s *noise.CipherState) *NebulaCipherState {
|
||||
return &NebulaCipherState{c: s.Cipher()}
|
||||
|
||||
}
|
||||
|
||||
func (s *NebulaCipherState) EncryptDanger(out, ad, plaintext []byte, n uint64, nb []byte) ([]byte, error) {
|
||||
|
|
183
psk.go
Normal file
183
psk.go
Normal file
|
@ -0,0 +1,183 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"github.com/slackhq/nebula/config"
|
||||
"github.com/slackhq/nebula/iputil"
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
var ErrNotAPskMode = errors.New("not a psk mode")
|
||||
var ErrKeyTooShort = errors.New("key is too short")
|
||||
var ErrNotEnoughPskKeys = errors.New("at least 1 key is required")
|
||||
|
||||
// The minimum length that we accept for a user defined psk, the choice is arbitrary
|
||||
const MinPskLength = 8
|
||||
|
||||
type PskMode int
|
||||
|
||||
func (p PskMode) String() string {
|
||||
switch p {
|
||||
case PskNone:
|
||||
return "none"
|
||||
case PskTransitionalAccepting:
|
||||
return "transitional-accepting"
|
||||
case PskTransitionalSending:
|
||||
return "transitional-sending"
|
||||
case PskEnforced:
|
||||
return "enforced"
|
||||
}
|
||||
|
||||
return "unknown"
|
||||
}
|
||||
|
||||
func NewPskMode(m string) (PskMode, error) {
|
||||
switch m {
|
||||
case "none":
|
||||
return PskNone, nil
|
||||
case "transitional-accepting":
|
||||
return PskTransitionalAccepting, nil
|
||||
case "transitional-sending":
|
||||
return PskTransitionalSending, nil
|
||||
case "enforced":
|
||||
return PskEnforced, nil
|
||||
}
|
||||
return PskNone, ErrNotAPskMode
|
||||
}
|
||||
|
||||
const (
|
||||
PskNone PskMode = 0
|
||||
PskTransitionalAccepting PskMode = 1
|
||||
PskTransitionalSending PskMode = 2
|
||||
PskEnforced PskMode = 3
|
||||
)
|
||||
|
||||
type Psk struct {
|
||||
// pskMode sets how psk works, ignored, allowed for incoming, or enforced for all
|
||||
mode PskMode
|
||||
|
||||
// Cache holds all pre-computed psk hkdfs
|
||||
// Handshakes iterate this directly
|
||||
Cache [][]byte
|
||||
|
||||
// The key has already been extracted and is ready to be expanded for use
|
||||
// MakeFor does the final expand step mixing in the intended recipients vpn ip
|
||||
key []byte
|
||||
}
|
||||
|
||||
// NewPskFromConfig is a helper for initial boot and config reloading.
|
||||
func NewPskFromConfig(c *config.C, myVpnIp iputil.VpnIp) (*Psk, error) {
|
||||
sMode := c.GetString("handshakes.psk.mode", "none")
|
||||
mode, err := NewPskMode(sMode)
|
||||
if err != nil {
|
||||
return nil, NewContextualError("Could not parse handshakes.psk.mode", m{"mode": mode}, err)
|
||||
}
|
||||
|
||||
return NewPsk(
|
||||
mode,
|
||||
c.GetStringSlice("handshakes.psk.keys", nil),
|
||||
myVpnIp,
|
||||
)
|
||||
}
|
||||
|
||||
// NewPsk creates a new Psk object and handles the caching of all accepted keys and preparation of the primary key
|
||||
func NewPsk(mode PskMode, keys []string, myVpnIp iputil.VpnIp) (*Psk, error) {
|
||||
psk := &Psk{
|
||||
mode: mode,
|
||||
}
|
||||
|
||||
err := psk.preparePrimaryKey(keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
err = psk.cachePsks(myVpnIp, keys)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return psk, nil
|
||||
}
|
||||
|
||||
// MakeFor if we are in enforced mode, the final hkdf expand stage is done on the pre extracted primary key,
|
||||
// mixing in the intended recipients vpn ip and the result is returned.
|
||||
// If we are transitional or not using psks, an empty byte slice is returned
|
||||
func (p *Psk) MakeFor(vpnIp iputil.VpnIp) ([]byte, error) {
|
||||
if p.mode == PskNone || p.mode == PskTransitionalAccepting {
|
||||
return []byte{}, nil
|
||||
}
|
||||
|
||||
hmacKey := make([]byte, sha256.Size)
|
||||
_, err := io.ReadFull(hkdf.Expand(sha256.New, p.key, vpnIp.ToIP()), hmacKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return hmacKey, nil
|
||||
}
|
||||
|
||||
// cachePsks generates all psks we accept and caches them to speed up handshaking
|
||||
func (p *Psk) cachePsks(myVpnIp iputil.VpnIp, keys []string) error {
|
||||
// If PskNone is set then we are using the nil byte array for a psk, we can return
|
||||
if p.mode == PskNone {
|
||||
p.Cache = [][]byte{nil}
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(keys) < 1 {
|
||||
return ErrNotEnoughPskKeys
|
||||
}
|
||||
|
||||
p.Cache = [][]byte{}
|
||||
|
||||
if p.mode == PskTransitionalAccepting || p.mode == PskTransitionalSending {
|
||||
// We are transitional, we accept empty psks
|
||||
p.Cache = append(p.Cache, nil)
|
||||
}
|
||||
|
||||
// We are either PskAuto or PskTransitional, build all possibilities
|
||||
for i, rk := range keys {
|
||||
k, err := sha256KdfFromString(rk, myVpnIp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to generate key for position %v: %w", i, err)
|
||||
}
|
||||
|
||||
p.Cache = append(p.Cache, k)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// preparePrimaryKey if we are in enforced mode, will do an hkdf extract on the first key to benefit
|
||||
// outgoing handshake performance, MakeFor does the final expand step
|
||||
func (p *Psk) preparePrimaryKey(keys []string) error {
|
||||
if p.mode == PskNone || p.mode == PskTransitionalAccepting {
|
||||
// If we aren't enforcing then there is nothing to prepare
|
||||
return nil
|
||||
}
|
||||
|
||||
if len(keys) < 1 {
|
||||
return ErrNotEnoughPskKeys
|
||||
}
|
||||
|
||||
p.key = hkdf.Extract(sha256.New, []byte(keys[0]), nil)
|
||||
return nil
|
||||
}
|
||||
|
||||
// sha256KdfFromString generates a full hkdf
|
||||
func sha256KdfFromString(secret string, vpnIp iputil.VpnIp) ([]byte, error) {
|
||||
if len(secret) < MinPskLength {
|
||||
return nil, ErrKeyTooShort
|
||||
}
|
||||
|
||||
hmacKey := make([]byte, sha256.Size)
|
||||
_, err := io.ReadFull(hkdf.New(sha256.New, []byte(secret), nil, vpnIp.ToIP()), hmacKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return hmacKey, nil
|
||||
}
|
103
psk_test.go
Normal file
103
psk_test.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package nebula
|
||||
|
||||
import (
|
||||
"testing"
|
||||
|
||||
"github.com/stretchr/testify/assert"
|
||||
)
|
||||
|
||||
func TestNewPsk(t *testing.T) {
|
||||
t.Run("mode none", func(t *testing.T) {
|
||||
p, err := NewPsk(PskNone, nil, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PskNone, p.mode)
|
||||
assert.Empty(t, p.key)
|
||||
|
||||
assert.Len(t, p.Cache, 1)
|
||||
assert.Nil(t, p.Cache[0])
|
||||
|
||||
b, err := p.MakeFor(0)
|
||||
assert.Equal(t, []byte{}, b)
|
||||
})
|
||||
|
||||
t.Run("mode transitional-accepting", func(t *testing.T) {
|
||||
p, err := NewPsk(PskTransitionalAccepting, nil, 1)
|
||||
assert.Error(t, ErrNotEnoughPskKeys, err)
|
||||
|
||||
p, err = NewPsk(PskTransitionalAccepting, []string{"1234567"}, 1)
|
||||
assert.Error(t, ErrKeyTooShort)
|
||||
|
||||
p, err = NewPsk(PskTransitionalAccepting, []string{"hi there friends"}, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PskTransitionalAccepting, p.mode)
|
||||
assert.Empty(t, p.key)
|
||||
|
||||
assert.Len(t, p.Cache, 2)
|
||||
assert.Nil(t, p.Cache[0])
|
||||
|
||||
expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17}
|
||||
assert.Equal(t, expectedCache, p.Cache[1])
|
||||
|
||||
b, err := p.MakeFor(0)
|
||||
assert.Equal(t, []byte{}, b)
|
||||
})
|
||||
|
||||
t.Run("mode transitional-sending", func(t *testing.T) {
|
||||
p, err := NewPsk(PskTransitionalSending, nil, 1)
|
||||
assert.Error(t, ErrNotEnoughPskKeys, err)
|
||||
|
||||
p, err = NewPsk(PskTransitionalSending, []string{"1234567"}, 1)
|
||||
assert.Error(t, ErrKeyTooShort)
|
||||
|
||||
p, err = NewPsk(PskTransitionalSending, []string{"hi there friends"}, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PskTransitionalSending, p.mode)
|
||||
|
||||
expectedKey := []byte{0x9c, 0x67, 0xab, 0x58, 0x79, 0x5c, 0x8a, 0xf0, 0xaa, 0xf0, 0x4c, 0x6c, 0x9a, 0x42, 0x6b, 0xe, 0xe2, 0x94, 0xb1, 0x0, 0x28, 0x1c, 0xdc, 0x88, 0x44, 0x35, 0x3f, 0xb7, 0xd5, 0x9, 0xc0, 0xda}
|
||||
assert.Equal(t, expectedKey, p.key)
|
||||
|
||||
assert.Len(t, p.Cache, 2)
|
||||
assert.Nil(t, p.Cache[0])
|
||||
|
||||
expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17}
|
||||
assert.Equal(t, expectedCache, p.Cache[1])
|
||||
|
||||
expectedPsk := []byte{0xd9, 0x16, 0xa3, 0x66, 0x6a, 0x20, 0x26, 0xcf, 0x5d, 0x93, 0xad, 0xa3, 0x88, 0x2d, 0x57, 0xac, 0x9b, 0xc3, 0x5a, 0xb7, 0x8f, 0x6, 0x71, 0xc4, 0x3e, 0x5, 0x9e, 0xbc, 0x4e, 0xc8, 0x24, 0x17}
|
||||
b, err := p.MakeFor(0)
|
||||
assert.Equal(t, expectedPsk, b)
|
||||
})
|
||||
|
||||
t.Run("mode enforced", func(t *testing.T) {
|
||||
p, err := NewPsk(PskEnforced, nil, 1)
|
||||
assert.Error(t, ErrNotEnoughPskKeys, err)
|
||||
|
||||
p, err = NewPsk(PskEnforced, []string{"hi there friends"}, 1)
|
||||
assert.NoError(t, err)
|
||||
assert.Equal(t, PskEnforced, p.mode)
|
||||
|
||||
expectedKey := []byte{156, 103, 171, 88, 121, 92, 138, 240, 170, 240, 76, 108, 154, 66, 107, 14, 226, 148, 177, 0, 40, 28, 220, 136, 68, 53, 63, 183, 213, 9, 192, 218}
|
||||
assert.Equal(t, expectedKey, p.key)
|
||||
|
||||
assert.Len(t, p.Cache, 1)
|
||||
expectedCache := []byte{146, 120, 135, 31, 158, 102, 45, 189, 128, 190, 37, 101, 58, 254, 6, 166, 91, 209, 148, 131, 27, 193, 24, 25, 170, 65, 130, 189, 7, 179, 255, 17}
|
||||
assert.Equal(t, expectedCache, p.Cache[0])
|
||||
|
||||
expectedPsk := []byte{0xd9, 0x16, 0xa3, 0x66, 0x6a, 0x20, 0x26, 0xcf, 0x5d, 0x93, 0xad, 0xa3, 0x88, 0x2d, 0x57, 0xac, 0x9b, 0xc3, 0x5a, 0xb7, 0x8f, 0x6, 0x71, 0xc4, 0x3e, 0x5, 0x9e, 0xbc, 0x4e, 0xc8, 0x24, 0x17}
|
||||
b, err := p.MakeFor(0)
|
||||
assert.Equal(t, expectedPsk, b)
|
||||
|
||||
// Make sure different vpn ips generate different psks
|
||||
expectedPsk = []byte{0x92, 0x78, 0x87, 0x1f, 0x9e, 0x66, 0x2d, 0xbd, 0x80, 0xbe, 0x25, 0x65, 0x3a, 0xfe, 0x6, 0xa6, 0x5b, 0xd1, 0x94, 0x83, 0x1b, 0xc1, 0x18, 0x19, 0xaa, 0x41, 0x82, 0xbd, 0x7, 0xb3, 0xff, 0x11}
|
||||
b, err = p.MakeFor(1)
|
||||
assert.Equal(t, expectedPsk, b)
|
||||
})
|
||||
}
|
||||
|
||||
func BenchmarkPsk_MakeFor(b *testing.B) {
|
||||
p, err := NewPsk(PskEnforced, []string{"hi there friends"}, 1)
|
||||
assert.NoError(b, err)
|
||||
|
||||
for n := 0; n < b.N; n++ {
|
||||
p.MakeFor(99)
|
||||
}
|
||||
}
|
Loading…
Add table
Add a link
Reference in a new issue