Enable running testifylint in CI (#1350)

This commit is contained in:
Caleb Jasik 2025-03-10 17:38:14 -05:00 committed by GitHub
parent 612637f529
commit 088af8edb2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
25 changed files with 451 additions and 416 deletions

View file

@ -31,6 +31,11 @@ jobs:
- name: Vet - name: Vet
run: make vet run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.64
- name: Test - name: Test
run: make test run: make test
@ -109,6 +114,11 @@ jobs:
- name: Vet - name: Vet
run: make vet run: make vet
- name: golangci-lint
uses: golangci/golangci-lint-action@v6
with:
version: v1.64
- name: Test - name: Test
run: make test run: make test

9
.golangci.yaml Normal file
View file

@ -0,0 +1,9 @@
# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json
linters:
# Disable all linters.
# Default: false
disable-all: true
# Enable specific linter
# https://golangci-lint.run/usage/linters/#enabled-by-default
enable:
- testifylint

View file

@ -9,6 +9,7 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewAllowListFromConfig(t *testing.T) { func TestNewAllowListFromConfig(t *testing.T) {
@ -18,21 +19,21 @@ func TestNewAllowListFromConfig(t *testing.T) {
"192.168.0.0": true, "192.168.0.0": true,
} }
r, err := newAllowListFromConfig(c, "allowlist", nil) r, err := newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'")
assert.Nil(t, r) assert.Nil(t, r)
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": "abc", "192.168.0.0/16": "abc",
} }
r, err = newAllowListFromConfig(c, "allowlist", nil) r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"192.168.0.0/16": true, "192.168.0.0/16": true,
"10.0.0.0/8": false, "10.0.0.0/8": false,
} }
r, err = newAllowListFromConfig(c, "allowlist", nil) r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true, "0.0.0.0/0": true,
@ -42,7 +43,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
"fd00:fd00::/16": false, "fd00:fd00::/16": false,
} }
r, err = newAllowListFromConfig(c, "allowlist", nil) r, err = newAllowListFromConfig(c, "allowlist", nil)
assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"0.0.0.0/0": true, "0.0.0.0/0": true,
@ -75,7 +76,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
}, },
} }
lr, err := NewLocalAllowListFromConfig(c, "allowlist") lr, err := NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{
@ -84,7 +85,7 @@ func TestNewAllowListFromConfig(t *testing.T) {
}, },
} }
lr, err = NewLocalAllowListFromConfig(c, "allowlist") lr, err = NewLocalAllowListFromConfig(c, "allowlist")
assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value")
c.Settings["allowlist"] = map[interface{}]interface{}{ c.Settings["allowlist"] = map[interface{}]interface{}{
"interfaces": map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{

View file

@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err := netip.ParseAddr("10.0.10.182") input, err := netip.ParseAddr("10.0.10.182")
assert.NoError(t, err) require.NoError(t, err)
expected, err := netip.ParseAddr("192.168.1.182") expected, err := netip.ParseAddr("192.168.1.182")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input))
@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) {
require.NoError(t, err) require.NoError(t, err)
input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input))
} }

View file

@ -6,6 +6,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewCAPoolFromBytes(t *testing.T) { func TestNewCAPoolFromBytes(t *testing.T) {
@ -82,12 +83,12 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
} }
p, err := NewCAPoolFromPEM([]byte(noNewLines)) p, err := NewCAPoolFromPEM([]byte(noNewLines))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
pp, err := NewCAPoolFromPEM([]byte(withNewLines)) pp, err := NewCAPoolFromPEM([]byte(withNewLines))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name)
assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name)
@ -105,7 +106,7 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe
assert.Len(t, pppp.CAs, 3) assert.Len(t, pppp.CAs, 3)
ppppp, err := NewCAPoolFromPEM([]byte(p256)) ppppp, err := NewCAPoolFromPEM([]byte(p256))
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name)
assert.Len(t, ppppp.CAs, 1) assert.Len(t, ppppp.CAs, 1)
} }
@ -115,21 +116,21 @@ func TestCertificateV1_Verify(t *testing.T) {
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.NoError(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
@ -138,11 +139,11 @@ func TestCertificateV1_Verify(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@ -150,9 +151,9 @@ func TestCertificateV1_Verify(t *testing.T) {
}) })
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCertificateV1_VerifyP256(t *testing.T) { func TestCertificateV1_VerifyP256(t *testing.T) {
@ -160,21 +161,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.NoError(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
@ -183,11 +184,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@ -196,7 +197,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) {
c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCertificateV1_Verify_IPs(t *testing.T) { func TestCertificateV1_Verify_IPs(t *testing.T) {
@ -205,11 +206,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@ -245,25 +246,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) {
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCertificateV1_Verify_Subnets(t *testing.T) { func TestCertificateV1_Verify_Subnets(t *testing.T) {
@ -272,11 +273,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@ -311,27 +312,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) {
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCertificateV2_Verify(t *testing.T) { func TestCertificateV2_Verify(t *testing.T) {
@ -339,21 +340,21 @@ func TestCertificateV2_Verify(t *testing.T) {
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.NoError(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil)
@ -362,11 +363,11 @@ func TestCertificateV2_Verify(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@ -374,9 +375,9 @@ func TestCertificateV2_Verify(t *testing.T) {
}) })
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCertificateV2_VerifyP256(t *testing.T) { func TestCertificateV2_VerifyP256(t *testing.T) {
@ -384,21 +385,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil)
caPool := NewCAPool() caPool := NewCAPool()
assert.NoError(t, caPool.AddCA(ca)) require.NoError(t, caPool.AddCA(ca))
f, err := c.Fingerprint() f, err := c.Fingerprint()
assert.NoError(t, err) require.NoError(t, err)
caPool.BlocklistFingerprint(f) caPool.BlocklistFingerprint(f)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.EqualError(t, err, "certificate is in the block list") require.EqualError(t, err, "certificate is in the block list")
caPool.ResetCertBlocklist() caPool.ResetCertBlocklist()
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c)
assert.EqualError(t, err, "root certificate is expired") require.EqualError(t, err, "root certificate is expired")
assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() {
NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
@ -407,11 +408,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
// Test group assertion // Test group assertion
ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool = NewCAPool() caPool = NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() {
@ -420,7 +421,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) {
c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCertificateV2_Verify_IPs(t *testing.T) { func TestCertificateV2_Verify_IPs(t *testing.T) {
@ -429,11 +430,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@ -469,25 +470,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) {
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"})
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestCertificateV2_Verify_Subnets(t *testing.T) { func TestCertificateV2_Verify_Subnets(t *testing.T) {
@ -496,11 +497,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
caPem, err := ca.MarshalPEM() caPem, err := ca.MarshalPEM()
assert.NoError(t, err) require.NoError(t, err)
caPool := NewCAPool() caPool := NewCAPool()
b, err := caPool.AddCAFromPEM(caPem) b, err := caPool.AddCAFromPEM(caPem)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
// ip is outside the network // ip is outside the network
@ -535,25 +536,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) {
cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp1 = mustParsePrefixUnmapped("10.0.1.0/16")
cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25")
c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches // Exact matches
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed // Exact matches reversed
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
// Exact matches reversed with just 1 // Exact matches reversed with just 1
c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"})
assert.NoError(t, err) require.NoError(t, err)
_, err = caPool.VerifyCertificate(time.Now(), c) _, err = caPool.VerifyCertificate(time.Now(), c)
assert.NoError(t, err) require.NoError(t, err)
} }

View file

@ -39,11 +39,11 @@ func TestCertificateV1_Marshal(t *testing.T) {
} }
b, err := nc.Marshal() b, err := nc.Marshal()
assert.NoError(t, err) require.NoError(t, err)
//t.Log("Cert size:", len(b)) //t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV1(b, nil) nc2, err := unmarshalCertificateV1(b, nil)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, Version1, nc.Version()) assert.Equal(t, Version1, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, Curve_CURVE25519, nc.Curve())
@ -99,7 +99,7 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
} }
b, err := nc.MarshalJSON() b, err := nc.MarshalJSON()
assert.NoError(t, err) require.NoError(t, err)
assert.JSONEq( assert.JSONEq(
t, t,
"{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}",
@ -110,47 +110,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) {
func TestCertificateV1_VerifyPrivateKey(t *testing.T) { func TestCertificateV1_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.NoError(t, err) require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.Error(t, err) require.Error(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
assert.NoError(t, err) require.NoError(t, err)
_, priv2 := X25519Keypair() _, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
assert.Error(t, err) require.Error(t, err)
} }
func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey) err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.NoError(t, err) require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2) err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.Error(t, err) require.Error(t, err)
c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv) err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.NoError(t, err) require.NoError(t, err)
_, priv2 := P256Keypair() _, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2) err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.Error(t, err) require.Error(t, err)
} }
// Ensure that upgrading the protobuf library does not change how certificates // Ensure that upgrading the protobuf library does not change how certificates
@ -186,7 +186,7 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) {
assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b))
b, err = proto.Marshal(nc.getRawDetails()) b, err = proto.Marshal(nc.getRawDetails())
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b))
} }
@ -201,7 +201,7 @@ func TestUnmarshalCertificateV1(t *testing.T) {
// Test that we don't panic with an invalid certificate (#332) // Test that we don't panic with an invalid certificate (#332)
data := []byte("\x98\x00\x00") data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV1(data, nil) _, err := unmarshalCertificateV1(data, nil)
assert.EqualError(t, err, "encoded Details was nil") require.EqualError(t, err, "encoded Details was nil")
} }
func appendByteSlices(b ...[]byte) []byte { func appendByteSlices(b ...[]byte) []byte {

View file

@ -49,7 +49,7 @@ func TestCertificateV2_Marshal(t *testing.T) {
//t.Log("Cert size:", len(b)) //t.Log("Cert size:", len(b))
nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, Version2, nc.Version()) assert.Equal(t, Version2, nc.Version())
assert.Equal(t, Curve_CURVE25519, nc.Curve()) assert.Equal(t, Curve_CURVE25519, nc.Curve())
@ -114,14 +114,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
} }
b, err := nc.MarshalJSON() b, err := nc.MarshalJSON()
assert.ErrorIs(t, err, ErrMissingDetails) require.ErrorIs(t, err, ErrMissingDetails)
rd, err := nc.details.Marshal() rd, err := nc.details.Marshal()
assert.NoError(t, err) require.NoError(t, err)
nc.rawDetails = rd nc.rawDetails = rd
b, err = nc.MarshalJSON() b, err = nc.MarshalJSON()
assert.NoError(t, err) require.NoError(t, err)
assert.JSONEq( assert.JSONEq(
t, t,
"{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}",
@ -132,85 +132,85 @@ func TestCertificateV2_MarshalJSON(t *testing.T) {
func TestCertificateV2_VerifyPrivateKey(t *testing.T) { func TestCertificateV2_VerifyPrivateKey(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.NoError(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
_, caKey2, err := ed25519.GenerateKey(rand.Reader) _, caKey2, err := ed25519.GenerateKey(rand.Reader)
require.NoError(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2)
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv)
assert.NoError(t, err) require.NoError(t, err)
_, priv2 := X25519Keypair() _, priv2 := X25519Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2) err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2)
assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch)
err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
ac, ok := c.(*certificateV2) ac, ok := c.(*certificateV2)
require.True(t, ok) require.True(t, ok)
ac.curve = Curve(99) ac.curve = Curve(99)
err = c.VerifyPrivateKey(Curve(99), priv2) err = c.VerifyPrivateKey(Curve(99), priv2)
assert.EqualError(t, err, "invalid curve: 99") require.EqualError(t, err, "invalid curve: 99")
ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey)
assert.NoError(t, err) require.NoError(t, err)
err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv)
err = c.VerifyPrivateKey(Curve_P256, priv[:16]) err = c.VerifyPrivateKey(Curve_P256, priv[:16])
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
err = c.VerifyPrivateKey(Curve_P256, priv) err = c.VerifyPrivateKey(Curve_P256, priv)
assert.ErrorIs(t, err, ErrInvalidPrivateKey) require.ErrorIs(t, err, ErrInvalidPrivateKey)
aCa, ok := ca2.(*certificateV2) aCa, ok := ca2.(*certificateV2)
require.True(t, ok) require.True(t, ok)
aCa.curve = Curve(99) aCa.curve = Curve(99)
err = aCa.VerifyPrivateKey(Curve(99), priv2) err = aCa.VerifyPrivateKey(Curve(99), priv2)
assert.EqualError(t, err, "invalid curve: 99") require.EqualError(t, err, "invalid curve: 99")
} }
func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) {
ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
err := ca.VerifyPrivateKey(Curve_P256, caKey) err := ca.VerifyPrivateKey(Curve_P256, caKey)
assert.NoError(t, err) require.NoError(t, err)
_, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
err = ca.VerifyPrivateKey(Curve_P256, caKey2) err = ca.VerifyPrivateKey(Curve_P256, caKey2)
assert.Error(t, err) require.Error(t, err)
c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil)
rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
err = c.VerifyPrivateKey(Curve_P256, rawPriv) err = c.VerifyPrivateKey(Curve_P256, rawPriv)
assert.NoError(t, err) require.NoError(t, err)
_, priv2 := P256Keypair() _, priv2 := P256Keypair()
err = c.VerifyPrivateKey(Curve_P256, priv2) err = c.VerifyPrivateKey(Curve_P256, priv2)
assert.Error(t, err) require.Error(t, err)
} }
func TestCertificateV2_Copy(t *testing.T) { func TestCertificateV2_Copy(t *testing.T) {
@ -223,7 +223,7 @@ func TestCertificateV2_Copy(t *testing.T) {
func TestUnmarshalCertificateV2(t *testing.T) { func TestUnmarshalCertificateV2(t *testing.T) {
data := []byte("\x98\x00\x00") data := []byte("\x98\x00\x00")
_, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519)
assert.EqualError(t, err, "bad wire format") require.EqualError(t, err, "bad wire format")
} }
func TestCertificateV2_marshalForSigningStability(t *testing.T) { func TestCertificateV2_marshalForSigningStability(t *testing.T) {

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/argon2" "golang.org/x/crypto/argon2"
) )
@ -61,33 +62,33 @@ qrlJ69wer3ZUHFXA
// Success test case // Success test case
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Len(t, k, 64) assert.Len(t, k, 64)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key // Fail due to short key
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
// Fail due to invalid banner // Fail due to invalid banner
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
// Fail due to invalid passphrase // Fail due to invalid passphrase
curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey)
assert.EqualError(t, err, "invalid passphrase or corrupt private key") require.EqualError(t, err, "invalid passphrase or corrupt private key")
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, []byte{}, rest) assert.Equal(t, []byte{}, rest)
} }
@ -99,14 +100,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) {
bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA")
kdfParams := NewArgon2Parameters(64*1024, 4, 3) kdfParams := NewArgon2Parameters(64*1024, 4, 3)
key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams)
assert.NoError(t, err) require.NoError(t, err)
// Verify the "key" can be decrypted successfully // Verify the "key" can be decrypted successfully
curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key)
assert.Len(t, k, 64) assert.Len(t, k, 64)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, []byte{}, rest) assert.Equal(t, []byte{}, rest)
assert.NoError(t, err) require.NoError(t, err)
// EncryptAndMarshalEd25519PrivateKey does not create any errors itself // EncryptAndMarshalEd25519PrivateKey does not create any errors itself
} }

View file

@ -4,6 +4,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestUnmarshalCertificateFromPEM(t *testing.T) { func TestUnmarshalCertificateFromPEM(t *testing.T) {
@ -35,20 +36,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB
cert, rest, err := UnmarshalCertificateFromPEM(certBundle) cert, rest, err := UnmarshalCertificateFromPEM(certBundle)
assert.NotNil(t, cert) assert.NotNil(t, cert)
assert.Equal(t, rest, append(badBanner, invalidPem...)) assert.Equal(t, rest, append(badBanner, invalidPem...))
assert.NoError(t, err) require.NoError(t, err)
// Fail due to invalid banner. // Fail due to invalid banner.
cert, rest, err = UnmarshalCertificateFromPEM(rest) cert, rest, err = UnmarshalCertificateFromPEM(rest)
assert.Nil(t, cert) assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper certificate banner") require.EqualError(t, err, "bytes did not contain a proper certificate banner")
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
cert, rest, err = UnmarshalCertificateFromPEM(rest) cert, rest, err = UnmarshalCertificateFromPEM(rest)
assert.Nil(t, cert) assert.Nil(t, cert)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) {
@ -84,33 +85,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA
assert.Len(t, k, 64) assert.Len(t, k, 64)
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.NoError(t, err) require.NoError(t, err)
// Success test case // Success test case
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
assert.NoError(t, err) require.NoError(t, err)
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner")
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { func TestUnmarshalPrivateKeyFromPEM(t *testing.T) {
@ -146,33 +147,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.NoError(t, err) require.NoError(t, err)
// Success test case // Success test case
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
assert.NoError(t, err) require.NoError(t, err)
// Fail due to short key // Fail due to short key
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "bytes did not contain a proper private key banner") require.EqualError(t, err, "bytes did not contain a proper private key banner")
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
// it's missing the requisite pre-encapsulation boundary. // it's missing the requisite pre-encapsulation boundary.
k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalPublicKeyFromPEM(t *testing.T) { func TestUnmarshalPublicKeyFromPEM(t *testing.T) {
@ -202,7 +203,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
// Fail due to short key // Fail due to short key
@ -210,13 +211,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.EqualError(t, err, "bytes did not contain a proper public key banner") require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
@ -225,7 +226,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }
func TestUnmarshalX25519PublicKey(t *testing.T) { func TestUnmarshalX25519PublicKey(t *testing.T) {
@ -260,14 +261,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
// Success test case // Success test case
k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle)
assert.Len(t, k, 32) assert.Len(t, k, 32)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, Curve_CURVE25519, curve)
// Success test case // Success test case
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Len(t, k, 65) assert.Len(t, k, 65)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem))
assert.Equal(t, Curve_P256, curve) assert.Equal(t, Curve_P256, curve)
@ -275,12 +276,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem))
assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key")
// Fail due to invalid banner // Fail due to invalid banner
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.EqualError(t, err, "bytes did not contain a proper public key banner") require.EqualError(t, err, "bytes did not contain a proper public key banner")
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
// Fail due to ivalid PEM format, because // Fail due to ivalid PEM format, because
@ -288,5 +289,5 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA=
k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest)
assert.Nil(t, k) assert.Nil(t, k)
assert.Equal(t, rest, invalidPem) assert.Equal(t, rest, invalidPem)
assert.EqualError(t, err, "input did not contain a valid PEM encoded block") require.EqualError(t, err, "input did not contain a valid PEM encoded block")
} }

View file

@ -10,6 +10,7 @@ import (
"time" "time"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestCertificateV1_Sign(t *testing.T) { func TestCertificateV1_Sign(t *testing.T) {
@ -37,14 +38,14 @@ func TestCertificateV1_Sign(t *testing.T) {
pub, priv, err := ed25519.GenerateKey(rand.Reader) pub, priv, err := ed25519.GenerateKey(rand.Reader)
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub)) assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal() b, err := c.Marshal()
assert.NoError(t, err) require.NoError(t, err)
uc, err := unmarshalCertificateV1(b, nil) uc, err := unmarshalCertificateV1(b, nil)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, uc) assert.NotNil(t, uc)
} }
@ -73,18 +74,18 @@ func TestCertificateV1_SignP256(t *testing.T) {
} }
priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader)
assert.NoError(t, err) require.NoError(t, err)
pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y)
rawPriv := priv.D.FillBytes(make([]byte, 32)) rawPriv := priv.D.FillBytes(make([]byte, 32))
c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, c) assert.NotNil(t, c)
assert.True(t, c.CheckSignature(pub)) assert.True(t, c.CheckSignature(pub))
b, err := c.Marshal() b, err := c.Marshal()
assert.NoError(t, err) require.NoError(t, err)
uc, err := unmarshalCertificateV1(b, nil) uc, err := unmarshalCertificateV1(b, nil)
assert.NoError(t, err) require.NoError(t, err)
assert.NotNil(t, uc) assert.NotNil(t, uc)
} }

View file

@ -14,6 +14,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_caSummary(t *testing.T) { func Test_caSummary(t *testing.T) {
@ -106,34 +107,34 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// create temp key file // create temp key file
keyF, err := os.CreateTemp("", "test.key") keyF, err := os.CreateTemp("", "test.key")
assert.NoError(t, err) require.NoError(t, err)
assert.NoError(t, os.Remove(keyF.Name())) require.NoError(t, os.Remove(keyF.Name()))
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// create temp cert file // create temp cert file
crtF, err := os.CreateTemp("", "test.crt") crtF, err := os.CreateTemp("", "test.crt")
assert.NoError(t, err) require.NoError(t, err)
assert.NoError(t, os.Remove(crtF.Name())) require.NoError(t, os.Remove(crtF.Name()))
assert.NoError(t, os.Remove(keyF.Name())) require.NoError(t, os.Remove(keyF.Name()))
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.NoError(t, ca(args, ob, eb, nopw)) require.NoError(t, ca(args, ob, eb, nopw))
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -142,13 +143,13 @@ func Test_ca(t *testing.T) {
lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, c) assert.Equal(t, cert.Curve_CURVE25519, c)
assert.Empty(t, b) assert.Empty(t, b)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, lKey, 64) assert.Len(t, lKey, 64)
rb, _ = os.ReadFile(crtF.Name()) rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
assert.Empty(t, b) assert.Empty(t, b)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "test", lCrt.Name())
assert.Empty(t, lCrt.Networks()) assert.Empty(t, lCrt.Networks())
@ -166,7 +167,7 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.NoError(t, ca(args, ob, eb, testpw)) require.NoError(t, ca(args, ob, eb, testpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -174,7 +175,7 @@ func Test_ca(t *testing.T) {
rb, _ = os.ReadFile(keyF.Name()) rb, _ = os.ReadFile(keyF.Name())
k, _ := pem.Decode(rb) k, _ := pem.Decode(rb)
ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes)
assert.NoError(t, err) require.NoError(t, err)
// we won't know salt in advance, so just check start of string // we won't know salt in advance, so just check start of string
assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory)
assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism)
@ -184,7 +185,7 @@ func Test_ca(t *testing.T) {
var curve cert.Curve var curve cert.Curve
curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, b) assert.Empty(t, b)
assert.Len(t, lKey, 64) assert.Len(t, lKey, 64)
@ -194,7 +195,7 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.Error(t, ca(args, ob, eb, errpw)) require.Error(t, ca(args, ob, eb, errpw))
assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, pwPromptOb, ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -204,7 +205,7 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext")
assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -214,13 +215,13 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.NoError(t, ca(args, ob, eb, nopw)) require.NoError(t, ca(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name())
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -229,7 +230,7 @@ func Test_ca(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()}
assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name())
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
os.Remove(keyF.Name()) os.Remove(keyF.Name())

View file

@ -7,6 +7,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_keygenSummary(t *testing.T) { func Test_keygenSummary(t *testing.T) {
@ -47,33 +48,33 @@ func Test_keygen(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// create temp key file // create temp key file
keyF, err := os.CreateTemp("", "test.key") keyF, err := os.CreateTemp("", "test.key")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(keyF.Name()) defer os.Remove(keyF.Name())
// failed pub write // failed pub write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()}
assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
// create temp pub file // create temp pub file
pubF, err := os.CreateTemp("", "test.pub") pubF, err := os.CreateTemp("", "test.pub")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(pubF.Name()) defer os.Remove(pubF.Name())
// test proper keygen // test proper keygen
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()}
assert.NoError(t, keygen(args, ob, eb)) require.NoError(t, keygen(args, ob, eb))
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
@ -82,13 +83,13 @@ func Test_keygen(t *testing.T) {
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Empty(t, b) assert.Empty(t, b)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, lKey, 32) assert.Len(t, lKey, 32)
rb, _ = os.ReadFile(pubF.Name()) rb, _ = os.ReadFile(pubF.Name())
lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Empty(t, b) assert.Empty(t, b)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, lPub, 32) assert.Len(t, lPub, 32)
} }

View file

@ -9,6 +9,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_help(t *testing.T) { func Test_help(t *testing.T) {
@ -79,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) {
t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg))
} }
assert.EqualError(t, err, msg) require.EqualError(t, err, msg)
} }
func optionalPkcs11String(msg string) string { func optionalPkcs11String(msg string) string {

View file

@ -12,6 +12,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_printSummary(t *testing.T) { func Test_printSummary(t *testing.T) {
@ -52,20 +53,20 @@ func Test_printCert(t *testing.T) {
err = printCert([]string{"-path", "does_not_exist"}, ob, eb) err = printCert([]string{"-path", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError)
// invalid cert at path // invalid cert at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
tf, err := os.CreateTemp("", "print-cert") tf, err := os.CreateTemp("", "print-cert")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(tf.Name()) defer os.Remove(tf.Name())
tf.WriteString("-----BEGIN NOPE-----") tf.WriteString("-----BEGIN NOPE-----")
err = printCert([]string{"-path", tf.Name()}, ob, eb) err = printCert([]string{"-path", tf.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block")
// test multiple certs // test multiple certs
ob.Reset() ob.Reset()
@ -84,7 +85,7 @@ func Test_printCert(t *testing.T) {
fp, _ := c.Fingerprint() fp, _ := c.Fingerprint()
pk := hex.EncodeToString(c.PublicKey()) pk := hex.EncodeToString(c.PublicKey())
sig := hex.EncodeToString(c.Signature()) sig := hex.EncodeToString(c.Signature())
assert.NoError(t, err) require.NoError(t, err)
assert.Equal( assert.Equal(
t, t,
//"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n",
@ -169,7 +170,7 @@ func Test_printCert(t *testing.T) {
fp, _ = c.Fingerprint() fp, _ = c.Fingerprint()
pk = hex.EncodeToString(c.PublicKey()) pk = hex.EncodeToString(c.PublicKey())
sig = hex.EncodeToString(c.Signature()) sig = hex.EncodeToString(c.Signature())
assert.NoError(t, err) require.NoError(t, err)
assert.Equal( assert.Equal(
t, t,
`[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}]

View file

@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@ -103,17 +104,17 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError)
// failed to unmarshal key // failed to unmarshal key
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
caKeyF, err := os.CreateTemp("", "sign-cert.key") caKeyF, err := os.CreateTemp("", "sign-cert.key")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(caKeyF.Name()) defer os.Remove(caKeyF.Name())
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -125,7 +126,7 @@ func Test_signCert(t *testing.T) {
// failed to read cert // failed to read cert
args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -133,11 +134,11 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
caCrtF, err := os.CreateTemp("", "sign-cert.crt") caCrtF, err := os.CreateTemp("", "sign-cert.crt")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(caCrtF.Name()) defer os.Remove(caCrtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -148,7 +149,7 @@ func Test_signCert(t *testing.T) {
// failed to read pub // failed to read pub
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -156,11 +157,11 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
inPubF, err := os.CreateTemp("", "in.pub") inPubF, err := os.CreateTemp("", "in.pub")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(inPubF.Name()) defer os.Remove(inPubF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -210,14 +211,14 @@ func Test_signCert(t *testing.T) {
// mismatched ca key // mismatched ca key
_, caPriv2, _ := ed25519.GenerateKey(rand.Reader) _, caPriv2, _ := ed25519.GenerateKey(rand.Reader)
caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") caKeyF2, err := os.CreateTemp("", "sign-cert-2.key")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(caKeyF2.Name()) defer os.Remove(caKeyF2.Name())
caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2))
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -225,34 +226,34 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
// create temp key file // create temp key file
keyF, err := os.CreateTemp("", "test.key") keyF, err := os.CreateTemp("", "test.key")
assert.NoError(t, err) require.NoError(t, err)
os.Remove(keyF.Name()) os.Remove(keyF.Name())
// failed cert write // failed cert write
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError)
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
os.Remove(keyF.Name()) os.Remove(keyF.Name())
// create temp cert file // create temp cert file
crtF, err := os.CreateTemp("", "test.crt") crtF, err := os.CreateTemp("", "test.crt")
assert.NoError(t, err) require.NoError(t, err)
os.Remove(crtF.Name()) os.Remove(crtF.Name())
// test proper cert with removed empty groups and subnets // test proper cert with removed empty groups and subnets
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.NoError(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -261,13 +262,13 @@ func Test_signCert(t *testing.T) {
lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb)
assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Equal(t, cert.Curve_CURVE25519, curve)
assert.Empty(t, b) assert.Empty(t, b)
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, lKey, 32) assert.Len(t, lKey, 32)
rb, _ = os.ReadFile(crtF.Name()) rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb)
assert.Empty(t, b) assert.Empty(t, b)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "test", lCrt.Name())
assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String())
@ -295,7 +296,7 @@ func Test_signCert(t *testing.T) {
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"}
assert.NoError(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -303,7 +304,7 @@ func Test_signCert(t *testing.T) {
rb, _ = os.ReadFile(crtF.Name()) rb, _ = os.ReadFile(crtF.Name())
lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb)
assert.Empty(t, b) assert.Empty(t, b)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, lCrt.PublicKey(), inPub) assert.Equal(t, lCrt.PublicKey(), inPub)
// test refuse to sign cert with duration beyond root // test refuse to sign cert with duration beyond root
@ -312,7 +313,7 @@ func Test_signCert(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate")
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -320,14 +321,14 @@ func Test_signCert(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.NoError(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing key file // test that we won't overwrite existing key file
os.Remove(crtF.Name()) os.Remove(crtF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -335,14 +336,14 @@ func Test_signCert(t *testing.T) {
os.Remove(keyF.Name()) os.Remove(keyF.Name())
os.Remove(crtF.Name()) os.Remove(crtF.Name())
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.NoError(t, signCert(args, ob, eb, nopw)) require.NoError(t, signCert(args, ob, eb, nopw))
// test that we won't overwrite existing certificate file // test that we won't overwrite existing certificate file
os.Remove(keyF.Name()) os.Remove(keyF.Name())
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name())
assert.Empty(t, ob.String()) assert.Empty(t, ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -355,11 +356,11 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
caKeyF, err = os.CreateTemp("", "sign-cert.key") caKeyF, err = os.CreateTemp("", "sign-cert.key")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(caKeyF.Name()) defer os.Remove(caKeyF.Name())
caCrtF, err = os.CreateTemp("", "sign-cert.crt") caCrtF, err = os.CreateTemp("", "sign-cert.crt")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(caCrtF.Name()) defer os.Remove(caCrtF.Name())
// generate the encrypted key // generate the encrypted key
@ -374,7 +375,7 @@ func Test_signCert(t *testing.T) {
// test with the proper password // test with the proper password
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.NoError(t, signCert(args, ob, eb, testpw)) require.NoError(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -384,7 +385,7 @@ func Test_signCert(t *testing.T) {
testpw.password = []byte("invalid password") testpw.password = []byte("invalid password")
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, testpw)) require.Error(t, signCert(args, ob, eb, testpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -393,7 +394,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, nopw)) require.Error(t, signCert(args, ob, eb, nopw))
// normally the user hitting enter on the prompt would add newlines between these // normally the user hitting enter on the prompt would add newlines between these
assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
@ -403,7 +404,7 @@ func Test_signCert(t *testing.T) {
eb.Reset() eb.Reset()
args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"}
assert.Error(t, signCert(args, ob, eb, errpw)) require.Error(t, signCert(args, ob, eb, errpw))
assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Equal(t, "Enter passphrase: ", ob.String())
assert.Empty(t, eb.String()) assert.Empty(t, eb.String())
} }

View file

@ -9,6 +9,7 @@ import (
"github.com/slackhq/nebula/cert" "github.com/slackhq/nebula/cert"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/crypto/ed25519" "golang.org/x/crypto/ed25519"
) )
@ -50,20 +51,20 @@ func Test_verify(t *testing.T) {
err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError)
// invalid ca at path // invalid ca at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
caFile, err := os.CreateTemp("", "verify-ca") caFile, err := os.CreateTemp("", "verify-ca")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(caFile.Name()) defer os.Remove(caFile.Name())
caFile.WriteString("-----BEGIN NOPE-----") caFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block")
// make a ca for later // make a ca for later
caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader)
@ -77,20 +78,20 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError)
// invalid crt at path // invalid crt at path
ob.Reset() ob.Reset()
eb.Reset() eb.Reset()
certFile, err := os.CreateTemp("", "verify-cert") certFile, err := os.CreateTemp("", "verify-cert")
assert.NoError(t, err) require.NoError(t, err)
defer os.Remove(certFile.Name()) defer os.Remove(certFile.Name())
certFile.WriteString("-----BEGIN NOPE-----") certFile.WriteString("-----BEGIN NOPE-----")
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block")
// unverifiable cert at path // unverifiable cert at path
crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
@ -107,7 +108,7 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.ErrorIs(t, err, cert.ErrSignatureMismatch) require.ErrorIs(t, err, cert.ErrSignatureMismatch)
// verified cert at path // verified cert at path
crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil)
@ -119,5 +120,5 @@ func Test_verify(t *testing.T) {
err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Equal(t, "", eb.String()) assert.Equal(t, "", eb.String())
assert.NoError(t, err) require.NoError(t, err)
} }

View file

@ -19,18 +19,18 @@ func TestConfig_Load(t *testing.T) {
// invalid yaml // invalid yaml
c := NewC(l) c := NewC(l)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644)
assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}")
// simple multi config merge // simple multi config merge
c = NewC(l) c = NewC(l)
os.RemoveAll(dir) os.RemoveAll(dir)
os.Mkdir(dir, 0755) os.Mkdir(dir, 0755)
assert.NoError(t, err) require.NoError(t, err)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644)
assert.NoError(t, c.Load(dir)) require.NoError(t, c.Load(dir))
expected := map[interface{}]interface{}{ expected := map[interface{}]interface{}{
"outer": map[interface{}]interface{}{ "outer": map[interface{}]interface{}{
"inner": "override", "inner": "override",
@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
done := make(chan bool, 1) done := make(chan bool, 1)
dir, err := os.MkdirTemp("", "config-test") dir, err := os.MkdirTemp("", "config-test")
assert.NoError(t, err) require.NoError(t, err)
os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644)
c := NewC(l) c := NewC(l)
assert.NoError(t, c.Load(dir)) require.NoError(t, c.Load(dir))
assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer.inner"))
assert.False(t, c.HasChanged("outer")) assert.False(t, c.HasChanged("outer"))

View file

@ -14,6 +14,7 @@ import (
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func newTestLighthouse() *LightHouse { func newTestLighthouse() *LightHouse {
@ -223,9 +224,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
} }
caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA)
assert.NoError(t, err) require.NoError(t, err)
ncp := cert.NewCAPool() ncp := cert.NewCAPool()
assert.NoError(t, ncp.AddCA(caCert)) require.NoError(t, ncp.AddCA(caCert))
pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader)
tbs = &cert.TBSCertificate{ tbs = &cert.TBSCertificate{
@ -237,7 +238,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) {
PublicKey: pubCrt, PublicKey: pubCrt,
} }
peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA)
assert.NoError(t, err) require.NoError(t, err)
cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert)

View file

@ -19,6 +19,7 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/udp" "github.com/slackhq/nebula/udp"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -771,7 +772,7 @@ func TestRehandshakingRelays(t *testing.T) {
"key": string(myNextPrivKey), "key": string(myNextPrivKey),
} }
rc, err := yaml.Marshal(relayConfig.Settings) rc, err := yaml.Marshal(relayConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
relayConfig.ReloadConfigString(string(rc)) relayConfig.ReloadConfigString(string(rc))
for { for {
@ -875,7 +876,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) {
"key": string(myNextPrivKey), "key": string(myNextPrivKey),
} }
rc, err := yaml.Marshal(relayConfig.Settings) rc, err := yaml.Marshal(relayConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
relayConfig.ReloadConfigString(string(rc)) relayConfig.ReloadConfigString(string(rc))
for { for {
@ -970,7 +971,7 @@ func TestRehandshaking(t *testing.T) {
"key": string(myNextPrivKey), "key": string(myNextPrivKey),
} }
rc, err := yaml.Marshal(myConfig.Settings) rc, err := yaml.Marshal(myConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
myConfig.ReloadConfigString(string(rc)) myConfig.ReloadConfigString(string(rc))
for { for {
@ -987,9 +988,9 @@ func TestRehandshaking(t *testing.T) {
r.Log("Got the new cert") r.Log("Got the new cert")
// Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly
rc, err = yaml.Marshal(theirConfig.Settings) rc, err = yaml.Marshal(theirConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
var theirNewConfig m var theirNewConfig m
assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig))
theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{})
theirFirewall["inbound"] = []m{{ theirFirewall["inbound"] = []m{{
"proto": "any", "proto": "any",
@ -997,7 +998,7 @@ func TestRehandshaking(t *testing.T) {
"group": "new group", "group": "new group",
}} }}
rc, err = yaml.Marshal(theirNewConfig) rc, err = yaml.Marshal(theirNewConfig)
assert.NoError(t, err) require.NoError(t, err)
theirConfig.ReloadConfigString(string(rc)) theirConfig.ReloadConfigString(string(rc))
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")
@ -1067,7 +1068,7 @@ func TestRehandshakingLoser(t *testing.T) {
"key": string(theirNextPrivKey), "key": string(theirNextPrivKey),
} }
rc, err := yaml.Marshal(theirConfig.Settings) rc, err := yaml.Marshal(theirConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
theirConfig.ReloadConfigString(string(rc)) theirConfig.ReloadConfigString(string(rc))
for { for {
@ -1083,9 +1084,9 @@ func TestRehandshakingLoser(t *testing.T) {
// Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly
rc, err = yaml.Marshal(myConfig.Settings) rc, err = yaml.Marshal(myConfig.Settings)
assert.NoError(t, err) require.NoError(t, err)
var myNewConfig m var myNewConfig m
assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) require.NoError(t, yaml.Unmarshal(rc, &myNewConfig))
theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{})
theirFirewall["inbound"] = []m{{ theirFirewall["inbound"] = []m{{
"proto": "any", "proto": "any",
@ -1093,7 +1094,7 @@ func TestRehandshakingLoser(t *testing.T) {
"group": "their new group", "group": "their new group",
}} }}
rc, err = yaml.Marshal(myNewConfig) rc, err = yaml.Marshal(myNewConfig)
assert.NoError(t, err) require.NoError(t, err)
myConfig.ReloadConfigString(string(rc)) myConfig.ReloadConfigString(string(rc))
r.Log("Spin until there is only 1 tunnel") r.Log("Spin until there is only 1 tunnel")

View file

@ -66,61 +66,61 @@ func TestFirewall_AddRule(t *testing.T) {
assert.NotNil(t, fw.OutRules) assert.NotNil(t, fw.OutRules)
ti, err := netip.ParsePrefix("1.2.3.4/32") ti, err := netip.ParsePrefix("1.2.3.4/32")
assert.NoError(t, err) require.NoError(t, err)
assert.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
// An empty rule is any // An empty rule is any
assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.True(t, fw.InRules.TCP[1].Any.Any.Any)
assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Groups)
assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Nil(t, fw.InRules.UDP[1].Any.Any)
assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1")
assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) assert.Empty(t, fw.InRules.UDP[1].Any.Hosts)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Nil(t, fw.InRules.ICMP[1].Any.Any)
assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups)
assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", ""))
assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", ""))
assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any)
_, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti)
assert.True(t, ok) assert.True(t, ok)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", ""))
assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha"))
assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha")
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
anyIp, err := netip.ParsePrefix("0.0.0.0/0") anyIp, err := netip.ParsePrefix("0.0.0.0/0")
assert.NoError(t, err) require.NoError(t, err)
assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", ""))
assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any)
// Test error conditions // Test error conditions
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c)
assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
} }
func TestFirewall_Drop(t *testing.T) { func TestFirewall_Drop(t *testing.T) {
@ -155,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) {
h.buildNetworks(c.networks, c.unsafeNetworks) h.buildNetworks(c.networks, c.unsafeNetworks)
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil))
// Allow inbound // Allow inbound
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack // Allow outbound because conntrack
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) require.NoError(t, fw.Drop(p, false, &h, cp, nil))
// test remote mismatch // test remote mismatch
oldRemote := p.RemoteAddr oldRemote := p.RemoteAddr
@ -174,29 +174,29 @@ func TestFirewall_Drop(t *testing.T) {
// ensure signer doesn't get in the way of group checks // ensure signer doesn't get in the way of group checks
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caSha doesn't drop on match // test caSha doesn't drop on match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad"))
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum"))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// ensure ca name doesn't get in the way of group checks // ensure ca name doesn't get in the way of group checks
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule)
// test caName doesn't drop on match // test caName doesn't drop on match
cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}}
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", ""))
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", ""))
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
func BenchmarkFirewallTable_match(b *testing.B) { func BenchmarkFirewallTable_match(b *testing.B) {
@ -350,14 +350,14 @@ func TestFirewall_Drop2(t *testing.T) {
h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// h1/c1 lacks the proper groups // h1/c1 lacks the proper groups
assert.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule)
// c has the proper groups // c has the proper groups
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
} }
func TestFirewall_Drop3(t *testing.T) { func TestFirewall_Drop3(t *testing.T) {
@ -428,23 +428,23 @@ func TestFirewall_Drop3(t *testing.T) {
h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", ""))
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha"))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// c1 should pass because host match // c1 should pass because host match
assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
// c2 should pass because ca sha match // c2 should pass because ca sha match
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) require.NoError(t, fw.Drop(p, true, &h2, cp, nil))
// c3 should fail because no match // c3 should fail because no match
resetConntrack(fw) resetConntrack(fw)
assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule)
// Test a remote address match // Test a remote address match
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", ""))
assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) require.NoError(t, fw.Drop(p, true, &h1, cp, nil))
} }
func TestFirewall_DropConntrackReload(t *testing.T) { func TestFirewall_DropConntrackReload(t *testing.T) {
@ -480,29 +480,29 @@ func TestFirewall_DropConntrackReload(t *testing.T) {
h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks())
fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
cp := cert.NewCAPool() cp := cert.NewCAPool()
// Drop outbound // Drop outbound
assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule)
// Allow inbound // Allow inbound
resetConntrack(fw) resetConntrack(fw)
assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) require.NoError(t, fw.Drop(p, true, &h, cp, nil))
// Allow outbound because conntrack // Allow outbound because conntrack
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) require.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw := fw oldFw := fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
// Allow outbound because conntrack and new rules allow port 10 // Allow outbound because conntrack and new rules allow port 10
assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) require.NoError(t, fw.Drop(p, false, &h, cp, nil))
oldFw = fw oldFw = fw
fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate)
assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", ""))
fw.Conntrack = oldFw.Conntrack fw.Conntrack = oldFw.Conntrack
fw.rulesVersion = oldFw.rulesVersion + 1 fw.rulesVersion = oldFw.rulesVersion + 1
@ -585,42 +585,42 @@ func BenchmarkLookup(b *testing.B) {
func Test_parsePort(t *testing.T) { func Test_parsePort(t *testing.T) {
_, _, err := parsePort("") _, _, err := parsePort("")
assert.EqualError(t, err, "was not a number; ``") require.EqualError(t, err, "was not a number; ``")
_, _, err = parsePort(" ") _, _, err = parsePort(" ")
assert.EqualError(t, err, "was not a number; ` `") require.EqualError(t, err, "was not a number; ` `")
_, _, err = parsePort("-") _, _, err = parsePort("-")
assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") require.EqualError(t, err, "appears to be a range but could not be parsed; `-`")
_, _, err = parsePort(" - ") _, _, err = parsePort(" - ")
assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `")
_, _, err = parsePort("a-b") _, _, err = parsePort("a-b")
assert.EqualError(t, err, "beginning range was not a number; `a`") require.EqualError(t, err, "beginning range was not a number; `a`")
_, _, err = parsePort("1-b") _, _, err = parsePort("1-b")
assert.EqualError(t, err, "ending range was not a number; `b`") require.EqualError(t, err, "ending range was not a number; `b`")
s, e, err := parsePort(" 1 - 2 ") s, e, err := parsePort(" 1 - 2 ")
assert.Equal(t, int32(1), s) assert.Equal(t, int32(1), s)
assert.Equal(t, int32(2), e) assert.Equal(t, int32(2), e)
assert.NoError(t, err) require.NoError(t, err)
s, e, err = parsePort("0-1") s, e, err = parsePort("0-1")
assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e) assert.Equal(t, int32(0), e)
assert.NoError(t, err) require.NoError(t, err)
s, e, err = parsePort("9919") s, e, err = parsePort("9919")
assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), s)
assert.Equal(t, int32(9919), e) assert.Equal(t, int32(9919), e)
assert.NoError(t, err) require.NoError(t, err)
s, e, err = parsePort("any") s, e, err = parsePort("any")
assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), s)
assert.Equal(t, int32(0), e) assert.Equal(t, int32(0), e)
assert.NoError(t, err) require.NoError(t, err)
} }
func TestNewFirewallFromConfig(t *testing.T) { func TestNewFirewallFromConfig(t *testing.T) {
@ -633,53 +633,53 @@ func TestNewFirewallFromConfig(t *testing.T) {
conf := config.NewC(l) conf := config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules")
// Test both port and code // Test both port and code
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided")
// Test missing host, group, cidr, ca_name and ca_sha // Test missing host, group, cidr, ca_name and ca_sha
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided")
// Test code/port error // Test code/port error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`")
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`")
// Test proto error // Test proto error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``")
// Test cidr parse error // Test cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test local_cidr parse error // Test local_cidr parse error
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'")
// Test both group and groups // Test both group and groups
conf = config.NewC(l) conf = config.NewC(l)
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}}
_, err = NewFirewallFromConfig(l, cs, conf) _, err = NewFirewallFromConfig(l, cs, conf)
assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided")
} }
func TestAddFirewallRulesFromConfig(t *testing.T) { func TestAddFirewallRulesFromConfig(t *testing.T) {
@ -688,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf := config.NewC(l) conf := config.NewC(l)
mf := &mockFirewall{} mf := &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding udp rule // Test adding udp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding icmp rule // Test adding icmp rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf))
assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding any rule // Test adding any rule
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with cidr // Test adding rule with cidr
@ -717,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall)
// Test adding rule with local_cidr // Test adding rule with local_cidr
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall)
// Test adding rule with ca_sha // Test adding rule with ca_sha
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall)
// Test adding rule with ca_name // Test adding rule with ca_name
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall)
// Test single group // Test single group
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test single groups // Test single groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test multiple AND groups // Test multiple AND groups
conf = config.NewC(l) conf = config.NewC(l)
mf = &mockFirewall{} mf = &mockFirewall{}
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}}
assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf))
assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall)
// Test Add error // Test Add error
@ -767,7 +767,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) {
mf = &mockFirewall{} mf = &mockFirewall{}
mf.nextCallReturn = errors.New("test error") mf.nextCallReturn = errors.New("test error")
conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}}
assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`")
} }
func TestFirewall_convertRule(t *testing.T) { func TestFirewall_convertRule(t *testing.T) {
@ -782,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err := convertRule(l, c, "test", 1) r, err := convertRule(l, c, "test", 1)
assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
// Ensure group array of > 1 is errord // Ensure group array of > 1 is errord
@ -793,7 +793,7 @@ func TestFirewall_convertRule(t *testing.T) {
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
assert.Equal(t, "", ob.String()) assert.Equal(t, "", ob.String())
assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") require.Error(t, err, "group should contain a single value, an array with more than one entry was provided")
// Make sure a well formed group is alright // Make sure a well formed group is alright
ob.Reset() ob.Reset()
@ -802,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) {
} }
r, err = convertRule(l, c, "test", 1) r, err = convertRule(l, c, "test", 1)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, "group1", r.Group) assert.Equal(t, "group1", r.Group)
} }

View file

@ -5,6 +5,7 @@ import (
"testing" "testing"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
type headerTest struct { type headerTest struct {
@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) {
func TestHeader_MarshalJSON(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) {
b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON()
assert.NoError(t, err) require.NoError(t, err)
assert.Equal( assert.Equal(
t, t,
"{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}",

View file

@ -13,6 +13,7 @@ import (
"github.com/slackhq/nebula/header" "github.com/slackhq/nebula/header"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"gopkg.in/yaml.v2" "gopkg.in/yaml.v2"
) )
@ -21,7 +22,7 @@ func TestOldIPv4Only(t *testing.T) {
b := []byte{8, 129, 130, 132, 80, 16, 10} b := []byte{8, 129, 130, 132, 80, 16, 10}
var m V4AddrPort var m V4AddrPort
err := m.Unmarshal(b) err := m.Unmarshal(b)
assert.NoError(t, err) require.NoError(t, err)
ip := netip.MustParseAddr("10.1.1.1") ip := netip.MustParseAddr("10.1.1.1")
bp := ip.As4() bp := ip.As4()
assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr())
@ -42,14 +43,14 @@ func Test_lhStaticMapping(t *testing.T) {
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
_, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
lh2 := "10.128.0.3" lh2 := "10.128.0.3"
c = config.NewC(l) c = config.NewC(l)
c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}}
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}}
_, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry")
} }
func TestReloadLighthouseInterval(t *testing.T) { func TestReloadLighthouseInterval(t *testing.T) {
@ -71,19 +72,19 @@ func TestReloadLighthouseInterval(t *testing.T) {
c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}}
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
lh.ifce = &mockEncWriter{} lh.ifce = &mockEncWriter{}
// The first one routine is kicked off by main.go currently, lets make sure that one dies // The first one routine is kicked off by main.go currently, lets make sure that one dies
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5"))
assert.Equal(t, int64(5), lh.interval.Load()) assert.Equal(t, int64(5), lh.interval.Load())
// Subsequent calls are killed off by the LightHouse.Reload function // Subsequent calls are killed off by the LightHouse.Reload function
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10"))
assert.Equal(t, int64(10), lh.interval.Load()) assert.Equal(t, int64(10), lh.interval.Load())
// If this completes then nothing is stealing our reload routine // If this completes then nothing is stealing our reload routine
assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11"))
assert.Equal(t, int64(11), lh.interval.Load()) assert.Equal(t, int64(11), lh.interval.Load())
} }
@ -99,9 +100,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
c := config.NewC(l) c := config.NewC(l)
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
if !assert.NoError(b, err) { require.NoError(b, err)
b.Fatal()
}
hAddr := netip.MustParseAddrPort("4.5.6.7:12345") hAddr := netip.MustParseAddrPort("4.5.6.7:12345")
hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346")
@ -145,7 +144,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) require.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, hi, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
} }
@ -160,7 +159,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) {
}, },
} }
p, err := req.Marshal() p, err := req.Marshal()
assert.NoError(b, err) require.NoError(b, err)
for n := 0; n < b.N; n++ { for n := 0; n < b.N; n++ {
lhh.HandleRequest(rAddr, hi, p, mw) lhh.HandleRequest(rAddr, hi, p, mw)
@ -205,7 +204,7 @@ func TestLighthouse_Memory(t *testing.T) {
} }
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
lh.ifce = &mockEncWriter{} lh.ifce = &mockEncWriter{}
assert.NoError(t, err) require.NoError(t, err)
lhh := lh.NewRequestHandler() lhh := lh.NewRequestHandler()
// Test that my first update responds with just that // Test that my first update responds with just that
@ -290,7 +289,7 @@ func TestLighthouse_reload(t *testing.T) {
} }
lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil)
assert.NoError(t, err) require.NoError(t, err)
nc := map[interface{}]interface{}{ nc := map[interface{}]interface{}{
"static_host_map": map[interface{}]interface{}{ "static_host_map": map[interface{}]interface{}{
@ -298,11 +297,11 @@ func TestLighthouse_reload(t *testing.T) {
}, },
} }
rc, err := yaml.Marshal(nc) rc, err := yaml.Marshal(nc)
assert.NoError(t, err) require.NoError(t, err)
c.ReloadConfigString(string(rc)) c.ReloadConfigString(string(rc))
err = lh.reload(c, false) err = lh.reload(c, false)
assert.NoError(t, err) require.NoError(t, err)
} }
func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply {

View file

@ -12,6 +12,7 @@ import (
"github.com/slackhq/nebula/firewall" "github.com/slackhq/nebula/firewall"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
"golang.org/x/net/ipv4" "golang.org/x/net/ipv4"
) )
@ -20,13 +21,13 @@ func Test_newPacket(t *testing.T) {
// length fails // length fails
err := newPacket([]byte{}, true, p) err := newPacket([]byte{}, true, p)
assert.ErrorIs(t, err, ErrPacketTooShort) require.ErrorIs(t, err, ErrPacketTooShort)
err = newPacket([]byte{0x40}, true, p) err = newPacket([]byte{0x40}, true, p)
assert.ErrorIs(t, err, ErrIPv4PacketTooShort) require.ErrorIs(t, err, ErrIPv4PacketTooShort)
err = newPacket([]byte{0x60}, true, p) err = newPacket([]byte{0x60}, true, p)
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// length fail with ip options // length fail with ip options
h := ipv4.Header{ h := ipv4.Header{
@ -39,15 +40,15 @@ func Test_newPacket(t *testing.T) {
b, _ := h.Marshal() b, _ := h.Marshal()
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
// not an ipv4 packet // not an ipv4 packet
err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.ErrorIs(t, err, ErrUnknownIPVersion) require.ErrorIs(t, err, ErrUnknownIPVersion)
// invalid ihl // invalid ihl
err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p)
assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength)
// account for variable ip header length - incoming // account for variable ip header length - incoming
h = ipv4.Header{ h = ipv4.Header{
@ -63,7 +64,7 @@ func Test_newPacket(t *testing.T) {
b = append(b, []byte{0, 3, 0, 4}...) b = append(b, []byte{0, 3, 0, 4}...)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr)
@ -85,7 +86,7 @@ func Test_newPacket(t *testing.T) {
b = append(b, []byte{0, 5, 0, 6}...) b = append(b, []byte{0, 5, 0, 6}...)
err = newPacket(b, false, p) err = newPacket(b, false, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, uint8(2), p.Protocol)
assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr)
@ -111,10 +112,10 @@ func Test_newPacket_v6(t *testing.T) {
FixLengths: false, FixLengths: false,
} }
err := gopacket.SerializeLayers(buffer, opt, &ip) err := gopacket.SerializeLayers(buffer, opt, &ip)
assert.NoError(t, err) require.NoError(t, err)
err = newPacket(buffer.Bytes(), true, p) err = newPacket(buffer.Bytes(), true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good ICMP packet // A good ICMP packet
ip = layers.IPv6{ ip = layers.IPv6{
@ -134,7 +135,7 @@ func Test_newPacket_v6(t *testing.T) {
} }
err = newPacket(buffer.Bytes(), true, p) err = newPacket(buffer.Bytes(), true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@ -146,7 +147,7 @@ func Test_newPacket_v6(t *testing.T) {
b := buffer.Bytes() b := buffer.Bytes()
b[6] = byte(layers.IPProtocolESP) b[6] = byte(layers.IPProtocolESP)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@ -158,7 +159,7 @@ func Test_newPacket_v6(t *testing.T) {
b = buffer.Bytes() b = buffer.Bytes()
b[6] = byte(layers.IPProtocolNoNextHeader) b[6] = byte(layers.IPProtocolNoNextHeader)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@ -170,7 +171,7 @@ func Test_newPacket_v6(t *testing.T) {
b = buffer.Bytes() b = buffer.Bytes()
b[6] = 255 // 255 is a reserved protocol number b[6] = 255 // 255 is a reserved protocol number
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
// A good UDP packet // A good UDP packet
ip = layers.IPv6{ ip = layers.IPv6{
@ -186,7 +187,7 @@ func Test_newPacket_v6(t *testing.T) {
DstPort: layers.UDPPort(22), DstPort: layers.UDPPort(22),
} }
err = udp.SetNetworkLayerForChecksum(&ip) err = udp.SetNetworkLayerForChecksum(&ip)
assert.NoError(t, err) require.NoError(t, err)
buffer.Clear() buffer.Clear()
err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef}))
@ -197,7 +198,7 @@ func Test_newPacket_v6(t *testing.T) {
// incoming // incoming
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@ -207,7 +208,7 @@ func Test_newPacket_v6(t *testing.T) {
// outgoing // outgoing
err = newPacket(b, false, p) err = newPacket(b, false, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
@ -217,14 +218,14 @@ func Test_newPacket_v6(t *testing.T) {
// Too short UDP packet // Too short UDP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good TCP packet // A good TCP packet
b[6] = byte(layers.IPProtocolTCP) b[6] = byte(layers.IPProtocolTCP)
// incoming // incoming
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@ -234,7 +235,7 @@ func Test_newPacket_v6(t *testing.T) {
// outgoing // outgoing
err = newPacket(b, false, p) err = newPacket(b, false, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
@ -244,7 +245,7 @@ func Test_newPacket_v6(t *testing.T) {
// Too short TCP packet // Too short TCP packet
err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
// A good UDP packet with an AH header // A good UDP packet with an AH header
ip = layers.IPv6{ ip = layers.IPv6{
@ -279,7 +280,7 @@ func Test_newPacket_v6(t *testing.T) {
b = append(b, udpHeader...) b = append(b, udpHeader...)
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
@ -290,7 +291,7 @@ func Test_newPacket_v6(t *testing.T) {
// Invalid AH header // Invalid AH header
b = buffer.Bytes() b = buffer.Bytes()
err = newPacket(b, true, p) err = newPacket(b, true, p)
assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload)
} }
func Test_newPacket_ipv6Fragment(t *testing.T) { func Test_newPacket_ipv6Fragment(t *testing.T) {
@ -338,7 +339,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test first fragment incoming // Test first fragment incoming
err = newPacket(firstFrag, true, p) err = newPacket(firstFrag, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@ -348,7 +349,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test first fragment outgoing // Test first fragment outgoing
err = newPacket(firstFrag, false, p) err = newPacket(firstFrag, false, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@ -377,7 +378,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test second fragment incoming // Test second fragment incoming
err = newPacket(secondFrag, true, p) err = newPacket(secondFrag, true, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@ -387,7 +388,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Test second fragment outgoing // Test second fragment outgoing
err = newPacket(secondFrag, false, p) err = newPacket(secondFrag, false, p)
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr)
assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr)
assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol)
@ -397,7 +398,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) {
// Too short of a fragment packet // Too short of a fragment packet
err = newPacket(secondFrag[:len(secondFrag)-10], false, p) err = newPacket(secondFrag[:len(secondFrag)-10], false, p)
assert.ErrorIs(t, err, ErrIPv6PacketTooShort) require.ErrorIs(t, err, ErrIPv6PacketTooShort)
} }
func BenchmarkParseV6(b *testing.B) { func BenchmarkParseV6(b *testing.B) {

View file

@ -8,84 +8,85 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func Test_parseRoutes(t *testing.T) { func Test_parseRoutes(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err) require.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseRoutes(c, []netip.Prefix{n}) routes, err := parseRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, routes) assert.Empty(t, routes)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.routes is not an array") require.EqualError(t, err, "tun.routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, routes) assert.Empty(t, routes)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.routes is invalid") require.EqualError(t, err, "entry 1 in tun.routes is invalid")
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") require.EqualError(t, err, "entry 1.mtu in tun.routes is not present")
// bad mtu // bad mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not present") require.EqualError(t, err, "entry 1.route in tun.routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]")
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]")
// Not in multiple ranges // Not in multiple ranges
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}}
routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]")
// happy case // happy case
c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{
@ -93,7 +94,7 @@ func Test_parseRoutes(t *testing.T) {
map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"},
}} }}
routes, err = parseRoutes(c, []netip.Prefix{n}) routes, err = parseRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
tested := 0 tested := 0
@ -119,36 +120,36 @@ func Test_parseUnsafeRoutes(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err) require.NoError(t, err)
// test no routes config // test no routes config
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, routes) assert.Empty(t, routes)
// not an array // not an array
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "tun.unsafe_routes is not an array") require.EqualError(t, err, "tun.unsafe_routes is not an array")
// no routes // no routes
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Empty(t, routes) assert.Empty(t, routes)
// weird route // weird route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid")
// no via // no via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present")
// invalid via // invalid via
for _, invalidValue := range []interface{}{ for _, invalidValue := range []interface{}{
@ -157,44 +158,44 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue))
} }
// unparsable via // unparsable via
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP")
// missing route // missing route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present")
// unparsable route // unparsable route
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'")
// within network range // within network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24")
// below network range // below network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.NoError(t, err) require.NoError(t, err)
// above network range // above network range
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Len(t, routes, 1) assert.Len(t, routes, 1)
assert.NoError(t, err) require.NoError(t, err)
// no mtu // no mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}}
@ -206,19 +207,19 @@ func Test_parseUnsafeRoutes(t *testing.T) {
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax")
// low mtu // low mtu
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499")
// bad install // bad install
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.Nil(t, routes) assert.Nil(t, routes)
assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax")
// happy case // happy case
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
@ -228,7 +229,7 @@ func Test_parseUnsafeRoutes(t *testing.T) {
map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"},
}} }}
routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err = parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 4) assert.Len(t, routes, 4)
tested := 0 tested := 0
@ -260,38 +261,38 @@ func Test_makeRouteTree(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
n, err := netip.ParsePrefix("10.0.0.0/24") n, err := netip.ParsePrefix("10.0.0.0/24")
assert.NoError(t, err) require.NoError(t, err)
c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{
map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"},
map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"},
}} }}
routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) routes, err := parseUnsafeRoutes(c, []netip.Prefix{n})
assert.NoError(t, err) require.NoError(t, err)
assert.Len(t, routes, 2) assert.Len(t, routes, 2)
routeTree, err := makeRouteTree(l, routes, true) routeTree, err := makeRouteTree(l, routes, true)
assert.NoError(t, err) require.NoError(t, err)
ip, err := netip.ParseAddr("1.0.0.2") ip, err := netip.ParseAddr("1.0.0.2")
assert.NoError(t, err) require.NoError(t, err)
r, ok := routeTree.Lookup(ip) r, ok := routeTree.Lookup(ip)
assert.True(t, ok) assert.True(t, ok)
nip, err := netip.ParseAddr("192.168.0.1") nip, err := netip.ParseAddr("192.168.0.1")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, nip, r) assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.0.0.1") ip, err = netip.ParseAddr("1.0.0.1")
assert.NoError(t, err) require.NoError(t, err)
r, ok = routeTree.Lookup(ip) r, ok = routeTree.Lookup(ip)
assert.True(t, ok) assert.True(t, ok)
nip, err = netip.ParseAddr("192.168.0.2") nip, err = netip.ParseAddr("192.168.0.2")
assert.NoError(t, err) require.NoError(t, err)
assert.Equal(t, nip, r) assert.Equal(t, nip, r)
ip, err = netip.ParseAddr("1.1.0.1") ip, err = netip.ParseAddr("1.1.0.1")
assert.NoError(t, err) require.NoError(t, err)
r, ok = routeTree.Lookup(ip) r, ok = routeTree.Lookup(ip)
assert.False(t, ok) assert.False(t, ok)
} }

View file

@ -7,6 +7,7 @@ import (
"github.com/slackhq/nebula/config" "github.com/slackhq/nebula/config"
"github.com/slackhq/nebula/test" "github.com/slackhq/nebula/test"
"github.com/stretchr/testify/assert" "github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
) )
func TestNewPunchyFromConfig(t *testing.T) { func TestNewPunchyFromConfig(t *testing.T) {
@ -56,7 +57,7 @@ func TestPunchy_reload(t *testing.T) {
l := test.NewLogger() l := test.NewLogger()
c := config.NewC(l) c := config.NewC(l)
delay, _ := time.ParseDuration("1m") delay, _ := time.ParseDuration("1m")
assert.NoError(t, c.LoadString(` require.NoError(t, c.LoadString(`
punchy: punchy:
delay: 1m delay: 1m
respond: false respond: false
@ -66,7 +67,7 @@ punchy:
assert.False(t, p.GetRespond()) assert.False(t, p.GetRespond())
newDelay, _ := time.ParseDuration("10m") newDelay, _ := time.ParseDuration("10m")
assert.NoError(t, c.ReloadConfigString(` require.NoError(t, c.ReloadConfigString(`
punchy: punchy:
delay: 10m delay: 10m
respond: true respond: true