diff --git a/.github/workflows/test.yml b/.github/workflows/test.yml index 4f3f2ed..b8a4f03 100644 --- a/.github/workflows/test.yml +++ b/.github/workflows/test.yml @@ -31,6 +31,11 @@ jobs: - name: Vet run: make vet + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.64 + - name: Test run: make test @@ -109,6 +114,11 @@ jobs: - name: Vet run: make vet + - name: golangci-lint + uses: golangci/golangci-lint-action@v6 + with: + version: v1.64 + - name: Test run: make test diff --git a/.golangci.yaml b/.golangci.yaml new file mode 100644 index 0000000..f792069 --- /dev/null +++ b/.golangci.yaml @@ -0,0 +1,9 @@ +# yaml-language-server: $schema=https://golangci-lint.run/jsonschema/golangci.jsonschema.json +linters: + # Disable all linters. + # Default: false + disable-all: true + # Enable specific linter + # https://golangci-lint.run/usage/linters/#enabled-by-default + enable: + - testifylint diff --git a/allow_list_test.go b/allow_list_test.go index 6d5e76b..d7d2c9a 100644 --- a/allow_list_test.go +++ b/allow_list_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewAllowListFromConfig(t *testing.T) { @@ -18,21 +19,21 @@ func TestNewAllowListFromConfig(t *testing.T) { "192.168.0.0": true, } r, err := newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") + require.EqualError(t, err, "config `allowlist` has invalid CIDR: 192.168.0.0. netip.ParsePrefix(\"192.168.0.0\"): no '/'") assert.Nil(t, r) c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": "abc", } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") + require.EqualError(t, err, "config `allowlist` has invalid value (type string): abc") c.Settings["allowlist"] = map[interface{}]interface{}{ "192.168.0.0/16": true, "10.0.0.0/8": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") + require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for 0.0.0.0/0") c.Settings["allowlist"] = map[interface{}]interface{}{ "0.0.0.0/0": true, @@ -42,7 +43,7 @@ func TestNewAllowListFromConfig(t *testing.T) { "fd00:fd00::/16": false, } r, err = newAllowListFromConfig(c, "allowlist", nil) - assert.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") + require.EqualError(t, err, "config `allowlist` contains both true and false rules, but no default set for ::/0") c.Settings["allowlist"] = map[interface{}]interface{}{ "0.0.0.0/0": true, @@ -75,7 +76,7 @@ func TestNewAllowListFromConfig(t *testing.T) { }, } lr, err := NewLocalAllowListFromConfig(c, "allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") + require.EqualError(t, err, "config `allowlist.interfaces` has invalid value (type string): foo") c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ @@ -84,7 +85,7 @@ func TestNewAllowListFromConfig(t *testing.T) { }, } lr, err = NewLocalAllowListFromConfig(c, "allowlist") - assert.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") + require.EqualError(t, err, "config `allowlist.interfaces` values must all be the same true/false value") c.Settings["allowlist"] = map[interface{}]interface{}{ "interfaces": map[interface{}]interface{}{ diff --git a/calculated_remote_test.go b/calculated_remote_test.go index 066213e..6df893c 100644 --- a/calculated_remote_test.go +++ b/calculated_remote_test.go @@ -15,10 +15,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err := netip.ParseAddr("10.0.10.182") - assert.NoError(t, err) + require.NoError(t, err) expected, err := netip.ParseAddr("192.168.1.182") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV4AddrPort(expected, 4242), c.ApplyV4(input)) @@ -28,10 +28,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) @@ -41,10 +41,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:ffff:ffff:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) @@ -54,10 +54,10 @@ func TestCalculatedRemoteApply(t *testing.T) { require.NoError(t, err) input, err = netip.ParseAddr("beef:beef:beef:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) expected, err = netip.ParseAddr("ffff:ffff:ffff:beef:beef:beef:beef:beef") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netAddrToProtoV6AddrPort(expected, 4242), c.ApplyV6(input)) } diff --git a/cert/ca_pool_test.go b/cert/ca_pool_test.go index 2f9255f..b0fdd5f 100644 --- a/cert/ca_pool_test.go +++ b/cert/ca_pool_test.go @@ -6,6 +6,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewCAPoolFromBytes(t *testing.T) { @@ -82,12 +83,12 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe } p, err := NewCAPoolFromPEM([]byte(noNewLines)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, p.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, p.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) pp, err := NewCAPoolFromPEM([]byte(withNewLines)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, pp.CAs["ce4e6c7a596996eb0d82a8875f0f0137a4b53ce22d2421c9fd7150e7a26f6300"].Certificate.Name(), rootCA.details.name) assert.Equal(t, pp.CAs["04c585fcd9a49b276df956a22b7ebea3bf23f1fca5a17c0b56ce2e626631969e"].Certificate.Name(), rootCA01.details.name) @@ -105,7 +106,7 @@ k+coOv04r+zh33ISyhbsafnYduN17p2eD7CmHvHuerguXD9f32gcxo/KsFCKEjMe assert.Len(t, pppp.CAs, 3) ppppp, err := NewCAPoolFromPEM([]byte(p256)) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, ppppp.CAs["552bf7d99bec1fc775a0e4c324bf6d8f789b3078f1919c7960d2e5e0c351ee97"].Certificate.Name(), rootCAP256.details.name) assert.Len(t, ppppp.CAs, 1) } @@ -115,21 +116,21 @@ func TestCertificateV1_Verify(t *testing.T) { c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) @@ -138,11 +139,11 @@ func TestCertificateV1_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -150,9 +151,9 @@ func TestCertificateV1_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_VerifyP256(t *testing.T) { @@ -160,21 +161,21 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) @@ -183,11 +184,11 @@ func TestCertificateV1_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version1, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -196,7 +197,7 @@ func TestCertificateV1_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_Verify_IPs(t *testing.T) { @@ -205,11 +206,11 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -245,25 +246,25 @@ func TestCertificateV1_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV1_Verify_Subnets(t *testing.T) { @@ -272,11 +273,11 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -311,27 +312,27 @@ func TestCertificateV1_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify(t *testing.T) { @@ -339,21 +340,21 @@ func TestCertificateV2_Verify(t *testing.T) { c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test cert2", time.Time{}, time.Time{}, nil, nil, nil) @@ -362,11 +363,11 @@ func TestCertificateV2_Verify(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -374,9 +375,9 @@ func TestCertificateV2_Verify(t *testing.T) { }) c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test2", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_VerifyP256(t *testing.T) { @@ -384,21 +385,21 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, nil) caPool := NewCAPool() - assert.NoError(t, caPool.AddCA(ca)) + require.NoError(t, caPool.AddCA(ca)) f, err := c.Fingerprint() - assert.NoError(t, err) + require.NoError(t, err) caPool.BlocklistFingerprint(f) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.EqualError(t, err, "certificate is in the block list") + require.EqualError(t, err, "certificate is in the block list") caPool.ResetCertBlocklist() _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now().Add(time.Hour*1000), c) - assert.EqualError(t, err, "root certificate is expired") + require.EqualError(t, err, "root certificate is expired") assert.PanicsWithError(t, "certificate is valid before the signing certificate", func() { NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) @@ -407,11 +408,11 @@ func TestCertificateV2_VerifyP256(t *testing.T) { // Test group assertion ca, _, caKey, _ = NewTestCaCert(Version2, Curve_P256, time.Now(), time.Now().Add(10*time.Minute), nil, nil, []string{"test1", "test2"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool = NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.PanicsWithError(t, "certificate contained a group not present on the signing ca: bad", func() { @@ -420,7 +421,7 @@ func TestCertificateV2_VerifyP256(t *testing.T) { c, _, _, _ = NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, nil, []string{"test1"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify_IPs(t *testing.T) { @@ -429,11 +430,11 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -469,25 +470,25 @@ func TestCertificateV2_Verify_IPs(t *testing.T) { cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{cIp1, cIp2}, nil, []string{"test"}) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1, caIp2}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp2, caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), []netip.Prefix{caIp1}, nil, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } func TestCertificateV2_Verify_Subnets(t *testing.T) { @@ -496,11 +497,11 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Now(), time.Now().Add(10*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) caPem, err := ca.MarshalPEM() - assert.NoError(t, err) + require.NoError(t, err) caPool := NewCAPool() b, err := caPool.AddCAFromPEM(caPem) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) // ip is outside the network @@ -535,25 +536,25 @@ func TestCertificateV2_Verify_Subnets(t *testing.T) { cIp1 = mustParsePrefixUnmapped("10.0.1.0/16") cIp2 = mustParsePrefixUnmapped("192.168.0.1/25") c, _, _, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{cIp1, cIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1, caIp2}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp2, caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) // Exact matches reversed with just 1 c, _, _, _ = NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Now(), time.Now().Add(5*time.Minute), nil, []netip.Prefix{caIp1}, []string{"test"}) - assert.NoError(t, err) + require.NoError(t, err) _, err = caPool.VerifyCertificate(time.Now(), c) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/cert/cert_v1_test.go b/cert/cert_v1_test.go index ea98b08..c687172 100644 --- a/cert/cert_v1_test.go +++ b/cert/cert_v1_test.go @@ -39,11 +39,11 @@ func TestCertificateV1_Marshal(t *testing.T) { } b, err := nc.Marshal() - assert.NoError(t, err) + require.NoError(t, err) //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Version1, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) @@ -99,7 +99,7 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq( t, "{\"details\":{\"curve\":\"CURVE25519\",\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedfghij1234567890ab\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"publicKey\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"3944c53d4267a229295b56cb2d27d459164c010ac97d655063ba421e0670f4ba\",\"signature\":\"313233343536373839306162636564666768696a313233343536373839306162\",\"version\":1}", @@ -110,47 +110,47 @@ func TestCertificateV1_MarshalJSON(t *testing.T) { func TestCertificateV1_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.Error(t, err) + require.Error(t, err) } func TestCertificateV1_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version1, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version1, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.Error(t, err) + require.Error(t, err) } // Ensure that upgrading the protobuf library does not change how certificates @@ -186,7 +186,7 @@ func TestMarshalingCertificateV1Consistency(t *testing.T) { assert.Equal(t, "0a8e010a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf1220313233343536373839306162636564666768696a313233343536373839306162", fmt.Sprintf("%x", b)) b, err = proto.Marshal(nc.getRawDetails()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "0a0774657374696e671212828284508080fcff0f8182845080feffff0f1a12838284488080fcff0f8282844880feffff0f220b746573742d67726f757031220b746573742d67726f757032220b746573742d67726f75703328cd1c30cdb8ccf0af073a20313233343536373839306162636564666768696a3132333435363738393061624a081234567890abcedf", fmt.Sprintf("%x", b)) } @@ -201,7 +201,7 @@ func TestUnmarshalCertificateV1(t *testing.T) { // Test that we don't panic with an invalid certificate (#332) data := []byte("\x98\x00\x00") _, err := unmarshalCertificateV1(data, nil) - assert.EqualError(t, err, "encoded Details was nil") + require.EqualError(t, err, "encoded Details was nil") } func appendByteSlices(b ...[]byte) []byte { diff --git a/cert/cert_v2_test.go b/cert/cert_v2_test.go index 6d55750..c84f8c9 100644 --- a/cert/cert_v2_test.go +++ b/cert/cert_v2_test.go @@ -49,7 +49,7 @@ func TestCertificateV2_Marshal(t *testing.T) { //t.Log("Cert size:", len(b)) nc2, err := unmarshalCertificateV2(b, nil, Curve_CURVE25519) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Version2, nc.Version()) assert.Equal(t, Curve_CURVE25519, nc.Curve()) @@ -114,14 +114,14 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { } b, err := nc.MarshalJSON() - assert.ErrorIs(t, err, ErrMissingDetails) + require.ErrorIs(t, err, ErrMissingDetails) rd, err := nc.details.Marshal() - assert.NoError(t, err) + require.NoError(t, err) nc.rawDetails = rd b, err = nc.MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.JSONEq( t, "{\"curve\":\"CURVE25519\",\"details\":{\"groups\":[\"test-group1\",\"test-group2\",\"test-group3\"],\"isCa\":false,\"issuer\":\"1234567890abcedf1234567890abcedf\",\"name\":\"testing\",\"networks\":[\"10.1.1.1/24\",\"10.1.1.2/16\"],\"notAfter\":\"0000-11-30T02:00:00Z\",\"notBefore\":\"0000-11-30T01:00:00Z\",\"unsafeNetworks\":[\"9.1.1.2/24\",\"9.1.1.3/16\"]},\"fingerprint\":\"152d9a7400c1e001cb76cffd035215ebb351f69eeb797f7f847dd086e15e56dd\",\"publicKey\":\"3132333435363738393061626365646631323334353637383930616263656466\",\"signature\":\"31323334353637383930616263656466313233343536373839306162636564663132333435363738393061626365646631323334353637383930616263656466\",\"version\":2}", @@ -132,85 +132,85 @@ func TestCertificateV2_MarshalJSON(t *testing.T) { func TestCertificateV2_VerifyPrivateKey(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_CURVE25519, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) _, caKey2, err := ed25519.GenerateKey(rand.Reader) require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey2) - assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) c, _, priv, _ := NewTestCert(Version2, Curve_CURVE25519, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_CURVE25519, curve) err = c.VerifyPrivateKey(Curve_CURVE25519, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := X25519Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) + require.ErrorIs(t, err, ErrPublicPrivateCurveMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2) - assert.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) + require.ErrorIs(t, err, ErrPublicPrivateKeyMismatch) err = c.VerifyPrivateKey(Curve_CURVE25519, priv2[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) ac, ok := c.(*certificateV2) require.True(t, ok) ac.curve = Curve(99) err = c.VerifyPrivateKey(Curve(99), priv2) - assert.EqualError(t, err, "invalid curve: 99") + require.EqualError(t, err, "invalid curve: 99") ca2, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err = ca.VerifyPrivateKey(Curve_CURVE25519, caKey) - assert.NoError(t, err) + require.NoError(t, err) err = ca2.VerifyPrivateKey(Curve_P256, caKey2[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) c, _, priv, _ = NewTestCert(Version2, Curve_P256, ca2, caKey2, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err = UnmarshalPrivateKeyFromPEM(priv) err = c.VerifyPrivateKey(Curve_P256, priv[:16]) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) err = c.VerifyPrivateKey(Curve_P256, priv) - assert.ErrorIs(t, err, ErrInvalidPrivateKey) + require.ErrorIs(t, err, ErrInvalidPrivateKey) aCa, ok := ca2.(*certificateV2) require.True(t, ok) aCa.curve = Curve(99) err = aCa.VerifyPrivateKey(Curve(99), priv2) - assert.EqualError(t, err, "invalid curve: 99") + require.EqualError(t, err, "invalid curve: 99") } func TestCertificateV2_VerifyPrivateKeyP256(t *testing.T) { ca, _, caKey, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) err := ca.VerifyPrivateKey(Curve_P256, caKey) - assert.NoError(t, err) + require.NoError(t, err) _, _, caKey2, _ := NewTestCaCert(Version2, Curve_P256, time.Time{}, time.Time{}, nil, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) err = ca.VerifyPrivateKey(Curve_P256, caKey2) - assert.Error(t, err) + require.Error(t, err) c, _, priv, _ := NewTestCert(Version2, Curve_P256, ca, caKey, "test", time.Time{}, time.Time{}, nil, nil, nil) rawPriv, b, curve, err := UnmarshalPrivateKeyFromPEM(priv) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Equal(t, Curve_P256, curve) err = c.VerifyPrivateKey(Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) _, priv2 := P256Keypair() err = c.VerifyPrivateKey(Curve_P256, priv2) - assert.Error(t, err) + require.Error(t, err) } func TestCertificateV2_Copy(t *testing.T) { @@ -223,7 +223,7 @@ func TestCertificateV2_Copy(t *testing.T) { func TestUnmarshalCertificateV2(t *testing.T) { data := []byte("\x98\x00\x00") _, err := unmarshalCertificateV2(data, nil, Curve_CURVE25519) - assert.EqualError(t, err, "bad wire format") + require.EqualError(t, err, "bad wire format") } func TestCertificateV2_marshalForSigningStability(t *testing.T) { diff --git a/cert/crypto_test.go b/cert/crypto_test.go index c43eed7..ee671c0 100644 --- a/cert/crypto_test.go +++ b/cert/crypto_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/argon2" ) @@ -61,33 +62,33 @@ qrlJ69wer3ZUHFXA // Success test case curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, keyBundle) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, Curve_CURVE25519, curve) assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") + require.EqualError(t, err, "key was not 64 bytes, is invalid ed25519 private key") assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) // Fail due to invalid banner curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") + require.EqualError(t, err, "bytes did not contain a proper nebula encrypted Ed25519/ECDSA private key banner") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey(passphrase, rest) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") assert.Nil(t, k) assert.Equal(t, rest, invalidPem) // Fail due to invalid passphrase curve, k, rest, err = DecryptAndUnmarshalSigningPrivateKey([]byte("invalid passphrase"), privKey) - assert.EqualError(t, err, "invalid passphrase or corrupt private key") + require.EqualError(t, err, "invalid passphrase or corrupt private key") assert.Nil(t, k) assert.Equal(t, []byte{}, rest) } @@ -99,14 +100,14 @@ func TestEncryptAndMarshalSigningPrivateKey(t *testing.T) { bytes := []byte("AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA") kdfParams := NewArgon2Parameters(64*1024, 4, 3) key, err := EncryptAndMarshalSigningPrivateKey(Curve_CURVE25519, bytes, passphrase, kdfParams) - assert.NoError(t, err) + require.NoError(t, err) // Verify the "key" can be decrypted successfully curve, k, rest, err := DecryptAndUnmarshalSigningPrivateKey(passphrase, key) assert.Len(t, k, 64) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, []byte{}, rest) - assert.NoError(t, err) + require.NoError(t, err) // EncryptAndMarshalEd25519PrivateKey does not create any errors itself } diff --git a/cert/pem_test.go b/cert/pem_test.go index 9ad8a69..6e49249 100644 --- a/cert/pem_test.go +++ b/cert/pem_test.go @@ -4,6 +4,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestUnmarshalCertificateFromPEM(t *testing.T) { @@ -35,20 +36,20 @@ bzBEr00kERQxxTzTsH8cpYEgRoipvmExvg8WP8NdAJEYJosB cert, rest, err := UnmarshalCertificateFromPEM(certBundle) assert.NotNil(t, cert) assert.Equal(t, rest, append(badBanner, invalidPem...)) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to invalid banner. cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper certificate banner") + require.EqualError(t, err, "bytes did not contain a proper certificate banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. cert, rest, err = UnmarshalCertificateFromPEM(rest) assert.Nil(t, cert) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalSigningPrivateKeyFromPEM(t *testing.T) { @@ -84,33 +85,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA assert.Len(t, k, 64) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") + require.EqualError(t, err, "key was not 64 bytes, is invalid Ed25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") + require.EqualError(t, err, "bytes did not contain a proper Ed25519/ECDSA private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalSigningPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPrivateKeyFromPEM(t *testing.T) { @@ -146,33 +147,33 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(privP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) // Success test case k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Len(t, k, 32) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) - assert.NoError(t, err) + require.NoError(t, err) // Fail due to short key k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 private key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "bytes did not contain a proper private key banner") + require.EqualError(t, err, "bytes did not contain a proper private key banner") // Fail due to ivalid PEM format, because // it's missing the requisite pre-encapsulation boundary. k, rest, curve, err = UnmarshalPrivateKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalPublicKeyFromPEM(t *testing.T) { @@ -202,7 +203,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) assert.Equal(t, Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) // Fail due to short key @@ -210,13 +211,13 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) - assert.EqualError(t, err, "bytes did not contain a proper public key banner") + require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because @@ -225,7 +226,7 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= assert.Nil(t, k) assert.Equal(t, Curve_CURVE25519, curve) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } func TestUnmarshalX25519PublicKey(t *testing.T) { @@ -260,14 +261,14 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= // Success test case k, rest, curve, err := UnmarshalPublicKeyFromPEM(keyBundle) assert.Len(t, k, 32) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(pubP256Key, shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_CURVE25519, curve) // Success test case k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Len(t, k, 65) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, rest, appendByteSlices(shortKey, invalidBanner, invalidPem)) assert.Equal(t, Curve_P256, curve) @@ -275,12 +276,12 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, appendByteSlices(invalidBanner, invalidPem)) - assert.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") + require.EqualError(t, err, "key was not 32 bytes, is invalid CURVE25519 public key") // Fail due to invalid banner k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) - assert.EqualError(t, err, "bytes did not contain a proper public key banner") + require.EqualError(t, err, "bytes did not contain a proper public key banner") assert.Equal(t, rest, invalidPem) // Fail due to ivalid PEM format, because @@ -288,5 +289,5 @@ AAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAAA= k, rest, curve, err = UnmarshalPublicKeyFromPEM(rest) assert.Nil(t, k) assert.Equal(t, rest, invalidPem) - assert.EqualError(t, err, "input did not contain a valid PEM encoded block") + require.EqualError(t, err, "input did not contain a valid PEM encoded block") } diff --git a/cert/sign_test.go b/cert/sign_test.go index 30d8480..e6f43cd 100644 --- a/cert/sign_test.go +++ b/cert/sign_test.go @@ -10,6 +10,7 @@ import ( "time" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestCertificateV1_Sign(t *testing.T) { @@ -37,14 +38,14 @@ func TestCertificateV1_Sign(t *testing.T) { pub, priv, err := ed25519.GenerateKey(rand.Reader) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_CURVE25519, priv) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.NoError(t, err) + require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, uc) } @@ -73,18 +74,18 @@ func TestCertificateV1_SignP256(t *testing.T) { } priv, err := ecdsa.GenerateKey(elliptic.P256(), rand.Reader) - assert.NoError(t, err) + require.NoError(t, err) pub := elliptic.Marshal(elliptic.P256(), priv.PublicKey.X, priv.PublicKey.Y) rawPriv := priv.D.FillBytes(make([]byte, 32)) c, err := tbs.Sign(&certificateV1{details: detailsV1{notBefore: before, notAfter: after}}, Curve_P256, rawPriv) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, c) assert.True(t, c.CheckSignature(pub)) b, err := c.Marshal() - assert.NoError(t, err) + require.NoError(t, err) uc, err := unmarshalCertificateV1(b, nil) - assert.NoError(t, err) + require.NoError(t, err) assert.NotNil(t, uc) } diff --git a/cmd/nebula-cert/ca_test.go b/cmd/nebula-cert/ca_test.go index 71b69be..189fc02 100644 --- a/cmd/nebula-cert/ca_test.go +++ b/cmd/nebula-cert/ca_test.go @@ -14,6 +14,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_caSummary(t *testing.T) { @@ -106,34 +107,34 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) - assert.NoError(t, os.Remove(keyF.Name())) + require.NoError(t, err) + require.NoError(t, os.Remove(keyF.Name())) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + require.EqualError(t, ca(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.NoError(t, err) - assert.NoError(t, os.Remove(crtF.Name())) - assert.NoError(t, os.Remove(keyF.Name())) + require.NoError(t, err) + require.NoError(t, os.Remove(crtF.Name())) + require.NoError(t, os.Remove(keyF.Name())) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, nopw)) + require.NoError(t, ca(args, ob, eb, nopw)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -142,13 +143,13 @@ func Test_ca(t *testing.T) { lKey, b, c, err := cert.UnmarshalSigningPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, c) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 64) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Empty(t, lCrt.Networks()) @@ -166,7 +167,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, testpw)) + require.NoError(t, ca(args, ob, eb, testpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -174,7 +175,7 @@ func Test_ca(t *testing.T) { rb, _ = os.ReadFile(keyF.Name()) k, _ := pem.Decode(rb) ned, err := cert.UnmarshalNebulaEncryptedData(k.Bytes) - assert.NoError(t, err) + require.NoError(t, err) // we won't know salt in advance, so just check start of string assert.Equal(t, uint32(2*1024*1024), ned.EncryptionMetadata.Argon2Parameters.Memory) assert.Equal(t, uint8(4), ned.EncryptionMetadata.Argon2Parameters.Parallelism) @@ -184,7 +185,7 @@ func Test_ca(t *testing.T) { var curve cert.Curve curve, lKey, b, err = cert.DecryptAndUnmarshalSigningPrivateKey(passphrase, rb) assert.Equal(t, cert.Curve_CURVE25519, curve) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, b) assert.Len(t, lKey, 64) @@ -194,7 +195,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.Error(t, ca(args, ob, eb, errpw)) + require.Error(t, ca(args, ob, eb, errpw)) assert.Equal(t, pwPromptOb, ob.String()) assert.Equal(t, "", eb.String()) @@ -204,7 +205,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-encrypt", "-name", "test", "-duration", "100m", "-groups", "1,2,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") + require.EqualError(t, ca(args, ob, eb, nopw), "no passphrase specified, remove -encrypt flag to write out-key in plaintext") assert.Equal(t, strings.Repeat(pwPromptOb, 5), ob.String()) // prompts 5 times before giving up assert.Equal(t, "", eb.String()) @@ -214,13 +215,13 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, ca(args, ob, eb, nopw)) + require.NoError(t, ca(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA key: "+keyF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -229,7 +230,7 @@ func Test_ca(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-name", "test", "-duration", "100m", "-groups", "1,, 2 , ,,,3,4,5", "-out-crt", crtF.Name(), "-out-key", keyF.Name()} - assert.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) + require.EqualError(t, ca(args, ob, eb, nopw), "refusing to overwrite existing CA cert: "+crtF.Name()) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) os.Remove(keyF.Name()) diff --git a/cmd/nebula-cert/keygen_test.go b/cmd/nebula-cert/keygen_test.go index 3427254..7eed5d2 100644 --- a/cmd/nebula-cert/keygen_test.go +++ b/cmd/nebula-cert/keygen_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_keygenSummary(t *testing.T) { @@ -47,33 +48,33 @@ func Test_keygen(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", "/do/not/write/pleasekey"} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(keyF.Name()) // failed pub write ob.Reset() eb.Reset() args = []string{"-out-pub", "/do/not/write/pleasepub", "-out-key", keyF.Name()} - assert.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) + require.EqualError(t, keygen(args, ob, eb), "error while writing out-pub: open /do/not/write/pleasepub: "+NoSuchDirError) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) // create temp pub file pubF, err := os.CreateTemp("", "test.pub") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(pubF.Name()) // test proper keygen ob.Reset() eb.Reset() args = []string{"-out-pub", pubF.Name(), "-out-key", keyF.Name()} - assert.NoError(t, keygen(args, ob, eb)) + require.NoError(t, keygen(args, ob, eb)) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) @@ -82,13 +83,13 @@ func Test_keygen(t *testing.T) { lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(pubF.Name()) lPub, b, curve, err := cert.UnmarshalPublicKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lPub, 32) } diff --git a/cmd/nebula-cert/main_test.go b/cmd/nebula-cert/main_test.go index f332895..2e92e7e 100644 --- a/cmd/nebula-cert/main_test.go +++ b/cmd/nebula-cert/main_test.go @@ -9,6 +9,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_help(t *testing.T) { @@ -79,7 +80,7 @@ func assertHelpError(t *testing.T, err error, msg string) { t.Fatal(fmt.Sprintf("err was not a helpError: %q, expected %q", err, msg)) } - assert.EqualError(t, err, msg) + require.EqualError(t, err, msg) } func optionalPkcs11String(msg string) string { diff --git a/cmd/nebula-cert/print_test.go b/cmd/nebula-cert/print_test.go index 77e98e6..061e472 100644 --- a/cmd/nebula-cert/print_test.go +++ b/cmd/nebula-cert/print_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_printSummary(t *testing.T) { @@ -52,20 +53,20 @@ func Test_printCert(t *testing.T) { err = printCert([]string{"-path", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "unable to read cert; open does_not_exist: "+NoSuchFileError) // invalid cert at path ob.Reset() eb.Reset() tf, err := os.CreateTemp("", "print-cert") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(tf.Name()) tf.WriteString("-----BEGIN NOPE-----") err = printCert([]string{"-path", tf.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while unmarshaling cert: input did not contain a valid PEM encoded block") // test multiple certs ob.Reset() @@ -84,7 +85,7 @@ func Test_printCert(t *testing.T) { fp, _ := c.Fingerprint() pk := hex.EncodeToString(c.PublicKey()) sig := hex.EncodeToString(c.Signature()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, //"NebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\nNebulaCertificate {\n\tDetails {\n\t\tName: test\n\t\tIps: []\n\t\tSubnets: []\n\t\tGroups: [\n\t\t\t\"hi\"\n\t\t]\n\t\tNot before: 0001-01-01 00:00:00 +0000 UTC\n\t\tNot After: 0001-01-01 00:00:00 +0000 UTC\n\t\tIs CA: false\n\t\tIssuer: "+c.Issuer()+"\n\t\tPublic key: "+pk+"\n\t\tCurve: CURVE25519\n\t}\n\tFingerprint: "+fp+"\n\tSignature: "+sig+"\n}\n", @@ -169,7 +170,7 @@ func Test_printCert(t *testing.T) { fp, _ = c.Fingerprint() pk = hex.EncodeToString(c.PublicKey()) sig = hex.EncodeToString(c.Signature()) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, `[{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1},{"details":{"curve":"CURVE25519","groups":["hi"],"isCa":false,"issuer":"`+c.Issuer()+`","name":"test","networks":["10.0.0.123/8"],"notAfter":"0001-01-01T00:00:00Z","notBefore":"0001-01-01T00:00:00Z","publicKey":"`+pk+`","unsafeNetworks":[]},"fingerprint":"`+fp+`","signature":"`+sig+`","version":1}] diff --git a/cmd/nebula-cert/sign_test.go b/cmd/nebula-cert/sign_test.go index 4b242a4..b2bba76 100644 --- a/cmd/nebula-cert/sign_test.go +++ b/cmd/nebula-cert/sign_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) @@ -103,17 +104,17 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args := []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", "./nope", "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-key: open ./nope: "+NoSuchFileError) // failed to unmarshal key ob.Reset() eb.Reset() caKeyF, err := os.CreateTemp("", "sign-cert.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-key: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -125,7 +126,7 @@ func Test_signCert(t *testing.T) { // failed to read cert args = []string{"-version", "1", "-ca-crt", "./nope", "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading ca-crt: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -133,11 +134,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() caCrtF, err := os.CreateTemp("", "sign-cert.crt") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing ca-crt: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -148,7 +149,7 @@ func Test_signCert(t *testing.T) { // failed to read pub args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", "./nope", "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while reading in-pub: open ./nope: "+NoSuchFileError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -156,11 +157,11 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() inPubF, err := os.CreateTemp("", "in.pub") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(inPubF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-in-pub", inPubF.Name(), "-duration", "100m"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while parsing in-pub: input did not contain a valid PEM encoded block") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -210,14 +211,14 @@ func Test_signCert(t *testing.T) { // mismatched ca key _, caPriv2, _ := ed25519.GenerateKey(rand.Reader) caKeyF2, err := os.CreateTemp("", "sign-cert-2.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF2.Name()) caKeyF2.Write(cert.MarshalSigningPrivateKeyToPEM(cert.Curve_CURVE25519, caPriv2)) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF2.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "nope", "-out-key", "nope", "-duration", "100m", "-subnets", "a"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to sign, root certificate does not match private key") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -225,34 +226,34 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", "/do/not/write/pleasekey", "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-key: open /do/not/write/pleasekey: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) // create temp key file keyF, err := os.CreateTemp("", "test.key") - assert.NoError(t, err) + require.NoError(t, err) os.Remove(keyF.Name()) // failed cert write ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", "/do/not/write/pleasecrt", "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) + require.EqualError(t, signCert(args, ob, eb, nopw), "error while writing out-crt: open /do/not/write/pleasecrt: "+NoSuchDirError) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) os.Remove(keyF.Name()) // create temp cert file crtF, err := os.CreateTemp("", "test.crt") - assert.NoError(t, err) + require.NoError(t, err) os.Remove(crtF.Name()) // test proper cert with removed empty groups and subnets ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -261,13 +262,13 @@ func Test_signCert(t *testing.T) { lKey, b, curve, err := cert.UnmarshalPrivateKeyFromPEM(rb) assert.Equal(t, cert.Curve_CURVE25519, curve) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, lKey, 32) rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err := cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "test", lCrt.Name()) assert.Equal(t, "1.1.1.1/24", lCrt.Networks()[0].String()) @@ -295,7 +296,7 @@ func Test_signCert(t *testing.T) { ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-in-pub", inPubF.Name(), "-duration", "100m", "-groups", "1"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -303,7 +304,7 @@ func Test_signCert(t *testing.T) { rb, _ = os.ReadFile(crtF.Name()) lCrt, b, err = cert.UnmarshalCertificateFromPEM(rb) assert.Empty(t, b) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, lCrt.PublicKey(), inPub) // test refuse to sign cert with duration beyond root @@ -312,7 +313,7 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "1000m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") + require.EqualError(t, signCert(args, ob, eb, nopw), "error while signing: certificate expires after signing certificate") assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -320,14 +321,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing key file os.Remove(crtF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing key: "+keyF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -335,14 +336,14 @@ func Test_signCert(t *testing.T) { os.Remove(keyF.Name()) os.Remove(crtF.Name()) args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, nopw)) + require.NoError(t, signCert(args, ob, eb, nopw)) // test that we won't overwrite existing certificate file os.Remove(keyF.Name()) ob.Reset() eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) + require.EqualError(t, signCert(args, ob, eb, nopw), "refusing to overwrite existing cert: "+crtF.Name()) assert.Empty(t, ob.String()) assert.Empty(t, eb.String()) @@ -355,11 +356,11 @@ func Test_signCert(t *testing.T) { eb.Reset() caKeyF, err = os.CreateTemp("", "sign-cert.key") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caKeyF.Name()) caCrtF, err = os.CreateTemp("", "sign-cert.crt") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caCrtF.Name()) // generate the encrypted key @@ -374,7 +375,7 @@ func Test_signCert(t *testing.T) { // test with the proper password args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.NoError(t, signCert(args, ob, eb, testpw)) + require.NoError(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -384,7 +385,7 @@ func Test_signCert(t *testing.T) { testpw.password = []byte("invalid password") args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, testpw)) + require.Error(t, signCert(args, ob, eb, testpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -393,7 +394,7 @@ func Test_signCert(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, nopw)) + require.Error(t, signCert(args, ob, eb, nopw)) // normally the user hitting enter on the prompt would add newlines between these assert.Equal(t, "Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) @@ -403,7 +404,7 @@ func Test_signCert(t *testing.T) { eb.Reset() args = []string{"-version", "1", "-ca-crt", caCrtF.Name(), "-ca-key", caKeyF.Name(), "-name", "test", "-ip", "1.1.1.1/24", "-out-crt", crtF.Name(), "-out-key", keyF.Name(), "-duration", "100m", "-subnets", "10.1.1.1/32, , 10.2.2.2/32 , , ,, 10.5.5.5/32", "-groups", "1,, 2 , ,,,3,4,5"} - assert.Error(t, signCert(args, ob, eb, errpw)) + require.Error(t, signCert(args, ob, eb, errpw)) assert.Equal(t, "Enter passphrase: ", ob.String()) assert.Empty(t, eb.String()) } diff --git a/cmd/nebula-cert/verify_test.go b/cmd/nebula-cert/verify_test.go index c2a9f55..acc9cca 100644 --- a/cmd/nebula-cert/verify_test.go +++ b/cmd/nebula-cert/verify_test.go @@ -9,6 +9,7 @@ import ( "github.com/slackhq/nebula/cert" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/crypto/ed25519" ) @@ -50,20 +51,20 @@ func Test_verify(t *testing.T) { err := verify([]string{"-ca", "does_not_exist", "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "error while reading ca: open does_not_exist: "+NoSuchFileError) // invalid ca at path ob.Reset() eb.Reset() caFile, err := os.CreateTemp("", "verify-ca") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(caFile.Name()) caFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while adding ca cert to pool: input did not contain a valid PEM encoded block") // make a ca for later caPub, caPriv, _ := ed25519.GenerateKey(rand.Reader) @@ -77,20 +78,20 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", "does_not_exist"}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) + require.EqualError(t, err, "unable to read crt: open does_not_exist: "+NoSuchFileError) // invalid crt at path ob.Reset() eb.Reset() certFile, err := os.CreateTemp("", "verify-cert") - assert.NoError(t, err) + require.NoError(t, err) defer os.Remove(certFile.Name()) certFile.WriteString("-----BEGIN NOPE-----") err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") + require.EqualError(t, err, "error while parsing crt: input did not contain a valid PEM encoded block") // unverifiable cert at path crt, _ := NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -107,7 +108,7 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.ErrorIs(t, err, cert.ErrSignatureMismatch) + require.ErrorIs(t, err, cert.ErrSignatureMismatch) // verified cert at path crt, _ = NewTestCert(ca, caPriv, "test-cert", time.Now().Add(time.Hour*-1), time.Now().Add(time.Hour), nil, nil, nil) @@ -119,5 +120,5 @@ func Test_verify(t *testing.T) { err = verify([]string{"-ca", caFile.Name(), "-crt", certFile.Name()}, ob, eb) assert.Equal(t, "", ob.String()) assert.Equal(t, "", eb.String()) - assert.NoError(t, err) + require.NoError(t, err) } diff --git a/config/config_test.go b/config/config_test.go index 39301f9..468c642 100644 --- a/config/config_test.go +++ b/config/config_test.go @@ -19,18 +19,18 @@ func TestConfig_Load(t *testing.T) { // invalid yaml c := NewC(l) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte(" invalid yaml"), 0644) - assert.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") + require.EqualError(t, c.Load(dir), "yaml: unmarshal errors:\n line 1: cannot unmarshal !!str `invalid...` into map[interface {}]interface {}") // simple multi config merge c = NewC(l) os.RemoveAll(dir) os.Mkdir(dir, 0755) - assert.NoError(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) os.WriteFile(filepath.Join(dir, "02.yml"), []byte("outer:\n inner: override\nnew: hi"), 0644) - assert.NoError(t, c.Load(dir)) + require.NoError(t, c.Load(dir)) expected := map[interface{}]interface{}{ "outer": map[interface{}]interface{}{ "inner": "override", @@ -117,11 +117,11 @@ func TestConfig_ReloadConfig(t *testing.T) { l := test.NewLogger() done := make(chan bool, 1) dir, err := os.MkdirTemp("", "config-test") - assert.NoError(t, err) + require.NoError(t, err) os.WriteFile(filepath.Join(dir, "01.yaml"), []byte("outer:\n inner: hi"), 0644) c := NewC(l) - assert.NoError(t, c.Load(dir)) + require.NoError(t, c.Load(dir)) assert.False(t, c.HasChanged("outer.inner")) assert.False(t, c.HasChanged("outer")) diff --git a/connection_manager_test.go b/connection_manager_test.go index 8e2ef15..2c9baa1 100644 --- a/connection_manager_test.go +++ b/connection_manager_test.go @@ -14,6 +14,7 @@ import ( "github.com/slackhq/nebula/test" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func newTestLighthouse() *LightHouse { @@ -223,9 +224,9 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { } caCert, err := tbs.Sign(nil, cert.Curve_CURVE25519, privCA) - assert.NoError(t, err) + require.NoError(t, err) ncp := cert.NewCAPool() - assert.NoError(t, ncp.AddCA(caCert)) + require.NoError(t, ncp.AddCA(caCert)) pubCrt, _, _ := ed25519.GenerateKey(rand.Reader) tbs = &cert.TBSCertificate{ @@ -237,7 +238,7 @@ func Test_NewConnectionManagerTest_DisconnectInvalid(t *testing.T) { PublicKey: pubCrt, } peerCert, err := tbs.Sign(caCert, cert.Curve_CURVE25519, privCA) - assert.NoError(t, err) + require.NoError(t, err) cachedPeerCert, err := ncp.VerifyCertificate(now.Add(time.Second), peerCert) diff --git a/e2e/handshakes_test.go b/e2e/handshakes_test.go index 2e7e6e4..06f2a21 100644 --- a/e2e/handshakes_test.go +++ b/e2e/handshakes_test.go @@ -19,6 +19,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/udp" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -771,7 +772,7 @@ func TestRehandshakingRelays(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { @@ -875,7 +876,7 @@ func TestRehandshakingRelaysPrimary(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(relayConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) relayConfig.ReloadConfigString(string(rc)) for { @@ -970,7 +971,7 @@ func TestRehandshaking(t *testing.T) { "key": string(myNextPrivKey), } rc, err := yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) for { @@ -987,9 +988,9 @@ func TestRehandshaking(t *testing.T) { r.Log("Got the new cert") // Flip their firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var theirNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) + require.NoError(t, yaml.Unmarshal(rc, &theirNewConfig)) theirFirewall := theirNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall["inbound"] = []m{{ "proto": "any", @@ -997,7 +998,7 @@ func TestRehandshaking(t *testing.T) { "group": "new group", }} rc, err = yaml.Marshal(theirNewConfig) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") @@ -1067,7 +1068,7 @@ func TestRehandshakingLoser(t *testing.T) { "key": string(theirNextPrivKey), } rc, err := yaml.Marshal(theirConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) theirConfig.ReloadConfigString(string(rc)) for { @@ -1083,9 +1084,9 @@ func TestRehandshakingLoser(t *testing.T) { // Flip my firewall to only allowing the new group to catch the tunnels reverting incorrectly rc, err = yaml.Marshal(myConfig.Settings) - assert.NoError(t, err) + require.NoError(t, err) var myNewConfig m - assert.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) + require.NoError(t, yaml.Unmarshal(rc, &myNewConfig)) theirFirewall := myNewConfig["firewall"].(map[interface{}]interface{}) theirFirewall["inbound"] = []m{{ "proto": "any", @@ -1093,7 +1094,7 @@ func TestRehandshakingLoser(t *testing.T) { "group": "their new group", }} rc, err = yaml.Marshal(myNewConfig) - assert.NoError(t, err) + require.NoError(t, err) myConfig.ReloadConfigString(string(rc)) r.Log("Spin until there is only 1 tunnel") diff --git a/firewall_test.go b/firewall_test.go index 92914af..8c2eeb0 100644 --- a/firewall_test.go +++ b/firewall_test.go @@ -66,61 +66,61 @@ func TestFirewall_AddRule(t *testing.T) { assert.NotNil(t, fw.OutRules) ti, err := netip.ParsePrefix("1.2.3.4/32") - assert.NoError(t, err) + require.NoError(t, err) - assert.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoTCP, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) // An empty rule is any assert.True(t, fw.InRules.TCP[1].Any.Any.Any) assert.Empty(t, fw.InRules.TCP[1].Any.Groups) assert.Empty(t, fw.InRules.TCP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.UDP[1].Any.Any) assert.Contains(t, fw.InRules.UDP[1].Any.Groups[0].Groups, "g1") assert.Empty(t, fw.InRules.UDP[1].Any.Hosts) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoICMP, 1, 1, []string{}, "h1", netip.Prefix{}, netip.Prefix{}, "", "")) assert.Nil(t, fw.InRules.ICMP[1].Any.Any) assert.Empty(t, fw.InRules.ICMP[1].Any.Groups) assert.Contains(t, fw.InRules.ICMP[1].Any.Hosts, "h1") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", ti, netip.Prefix{}, "", "")) assert.Nil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok := fw.OutRules.AnyProto[1].Any.CIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, ti, "", "")) assert.NotNil(t, fw.OutRules.AnyProto[1].Any.Any) _, ok = fw.OutRules.AnyProto[1].Any.Any.LocalCIDR.Get(ti) assert.True(t, ok) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "ca-name", "")) assert.Contains(t, fw.InRules.UDP[1].CANames, "ca-name") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoUDP, 1, 1, []string{"g1"}, "", netip.Prefix{}, netip.Prefix{}, "", "ca-sha")) assert.Contains(t, fw.InRules.UDP[1].CAShas, "ca-sha") fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "any", netip.Prefix{}, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) anyIp, err := netip.ParsePrefix("0.0.0.0/0") - assert.NoError(t, err) + require.NoError(t, err) - assert.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(false, firewall.ProtoAny, 0, 0, []string{}, "", anyIp, netip.Prefix{}, "", "")) assert.True(t, fw.OutRules.AnyProto[0].Any.Any.Any) // Test error conditions fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c) - assert.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, math.MaxUint8, 0, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.Error(t, fw.AddRule(true, firewall.ProtoAny, 10, 0, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "")) } func TestFirewall_Drop(t *testing.T) { @@ -155,16 +155,16 @@ func TestFirewall_Drop(t *testing.T) { h.buildNetworks(c.networks, c.unsafeNetworks) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, ErrNoMatchingRule, fw.Drop(p, false, &h, cp, nil)) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) // test remote mismatch oldRemote := p.RemoteAddr @@ -174,29 +174,29 @@ func TestFirewall_Drop(t *testing.T) { // ensure signer doesn't get in the way of group checks fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caSha doesn't drop on match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum-bad")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-shasum")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // ensure ca name doesn't get in the way of group checks cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) assert.Equal(t, fw.Drop(p, true, &h, cp, nil), ErrNoMatchingRule) // test caName doesn't drop on match cp.CAs["signer-shasum"] = &cert.CachedCertificate{Certificate: &dummyCert{name: "ca-good"}} fw = NewFirewall(l, time.Second, time.Minute, time.Hour, &c) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"nope"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good-bad", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group"}, "", netip.Prefix{}, netip.Prefix{}, "ca-good", "")) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func BenchmarkFirewallTable_match(b *testing.B) { @@ -350,14 +350,14 @@ func TestFirewall_Drop2(t *testing.T) { h1.buildNetworks(c1.Certificate.Networks(), c1.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"default-group", "test-group"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // h1/c1 lacks the proper groups - assert.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) + require.ErrorIs(t, fw.Drop(p, true, &h1, cp, nil), ErrNoMatchingRule) // c has the proper groups resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) } func TestFirewall_Drop3(t *testing.T) { @@ -428,23 +428,23 @@ func TestFirewall_Drop3(t *testing.T) { h3.buildNetworks(c3.Certificate.Networks(), c3.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "host1", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.Prefix{}, netip.Prefix{}, "", "signer-sha")) cp := cert.NewCAPool() // c1 should pass because host match - assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) // c2 should pass because ca sha match resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h2, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h2, cp, nil)) // c3 should fail because no match resetConntrack(fw) assert.Equal(t, fw.Drop(p, true, &h3, cp, nil), ErrNoMatchingRule) // Test a remote address match fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) - assert.NoError(t, fw.Drop(p, true, &h1, cp, nil)) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 1, 1, []string{}, "", netip.MustParsePrefix("1.2.3.4/24"), netip.Prefix{}, "", "")) + require.NoError(t, fw.Drop(p, true, &h1, cp, nil)) } func TestFirewall_DropConntrackReload(t *testing.T) { @@ -480,29 +480,29 @@ func TestFirewall_DropConntrackReload(t *testing.T) { h.buildNetworks(c.Certificate.Networks(), c.Certificate.UnsafeNetworks()) fw := NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 0, 0, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) cp := cert.NewCAPool() // Drop outbound assert.Equal(t, fw.Drop(p, false, &h, cp, nil), ErrNoMatchingRule) // Allow inbound resetConntrack(fw) - assert.NoError(t, fw.Drop(p, true, &h, cp, nil)) + require.NoError(t, fw.Drop(p, true, &h, cp, nil)) // Allow outbound because conntrack - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw := fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 10, 10, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 // Allow outbound because conntrack and new rules allow port 10 - assert.NoError(t, fw.Drop(p, false, &h, cp, nil)) + require.NoError(t, fw.Drop(p, false, &h, cp, nil)) oldFw = fw fw = NewFirewall(l, time.Second, time.Minute, time.Hour, c.Certificate) - assert.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) + require.NoError(t, fw.AddRule(true, firewall.ProtoAny, 11, 11, []string{"any"}, "", netip.Prefix{}, netip.Prefix{}, "", "")) fw.Conntrack = oldFw.Conntrack fw.rulesVersion = oldFw.rulesVersion + 1 @@ -585,42 +585,42 @@ func BenchmarkLookup(b *testing.B) { func Test_parsePort(t *testing.T) { _, _, err := parsePort("") - assert.EqualError(t, err, "was not a number; ``") + require.EqualError(t, err, "was not a number; ``") _, _, err = parsePort(" ") - assert.EqualError(t, err, "was not a number; ` `") + require.EqualError(t, err, "was not a number; ` `") _, _, err = parsePort("-") - assert.EqualError(t, err, "appears to be a range but could not be parsed; `-`") + require.EqualError(t, err, "appears to be a range but could not be parsed; `-`") _, _, err = parsePort(" - ") - assert.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") + require.EqualError(t, err, "appears to be a range but could not be parsed; ` - `") _, _, err = parsePort("a-b") - assert.EqualError(t, err, "beginning range was not a number; `a`") + require.EqualError(t, err, "beginning range was not a number; `a`") _, _, err = parsePort("1-b") - assert.EqualError(t, err, "ending range was not a number; `b`") + require.EqualError(t, err, "ending range was not a number; `b`") s, e, err := parsePort(" 1 - 2 ") assert.Equal(t, int32(1), s) assert.Equal(t, int32(2), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("0-1") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("9919") assert.Equal(t, int32(9919), s) assert.Equal(t, int32(9919), e) - assert.NoError(t, err) + require.NoError(t, err) s, e, err = parsePort("any") assert.Equal(t, int32(0), s) assert.Equal(t, int32(0), e) - assert.NoError(t, err) + require.NoError(t, err) } func TestNewFirewallFromConfig(t *testing.T) { @@ -633,53 +633,53 @@ func TestNewFirewallFromConfig(t *testing.T) { conf := config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": "asdf"} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") + require.EqualError(t, err, "firewall.outbound failed to parse, should be an array of rules") // Test both port and code conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "code": "2"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") + require.EqualError(t, err, "firewall.outbound rule #0; only one of port or code should be provided") // Test missing host, group, cidr, ca_name and ca_sha conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") + require.EqualError(t, err, "firewall.outbound rule #0; at least one of host, group, cidr, local_cidr, ca_name, or ca_sha must be provided") // Test code/port error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") + require.EqualError(t, err, "firewall.outbound rule #0; code was not a number; `a`") conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "a", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") + require.EqualError(t, err, "firewall.outbound rule #0; port was not a number; `a`") // Test proto error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "host": "testh"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") + require.EqualError(t, err, "firewall.outbound rule #0; proto was not understood; ``") // Test cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") + require.EqualError(t, err, "firewall.outbound rule #0; cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test local_cidr parse error conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"code": "1", "local_cidr": "testh", "proto": "any"}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") + require.EqualError(t, err, "firewall.outbound rule #0; local_cidr did not parse; netip.ParsePrefix(\"testh\"): no '/'") // Test both group and groups conf = config.NewC(l) conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a", "groups": []string{"b", "c"}}}} _, err = NewFirewallFromConfig(l, cs, conf) - assert.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") + require.EqualError(t, err, "firewall.inbound rule #0; only one of group or groups should be defined, both provided") } func TestAddFirewallRulesFromConfig(t *testing.T) { @@ -688,28 +688,28 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf := config.NewC(l) mf := &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "tcp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoTCP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding udp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "udp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoUDP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding icmp rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"outbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "icmp", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, false, conf, mf)) assert.Equal(t, addRuleCall{incoming: false, proto: firewall.ProtoICMP, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding any rule conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, host: "a", ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with cidr @@ -717,49 +717,49 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "cidr": cidr.String()}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: cidr, localIp: netip.Prefix{}}, mf.lastCall) // Test adding rule with local_cidr conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "local_cidr": cidr.String()}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: cidr}, mf.lastCall) // Test adding rule with ca_sha conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_sha": "12312313123"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caSha: "12312313123"}, mf.lastCall) // Test adding rule with ca_name conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "ca_name": "root01"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: nil, ip: netip.Prefix{}, localIp: netip.Prefix{}, caName: "root01"}, mf.lastCall) // Test single group conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "group": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test single groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": "a"}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test multiple AND groups conf = config.NewC(l) mf = &mockFirewall{} conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "groups": []string{"a", "b"}}}} - assert.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) + require.NoError(t, AddFirewallRulesFromConfig(l, true, conf, mf)) assert.Equal(t, addRuleCall{incoming: true, proto: firewall.ProtoAny, startPort: 1, endPort: 1, groups: []string{"a", "b"}, ip: netip.Prefix{}, localIp: netip.Prefix{}}, mf.lastCall) // Test Add error @@ -767,7 +767,7 @@ func TestAddFirewallRulesFromConfig(t *testing.T) { mf = &mockFirewall{} mf.nextCallReturn = errors.New("test error") conf.Settings["firewall"] = map[interface{}]interface{}{"inbound": []interface{}{map[interface{}]interface{}{"port": "1", "proto": "any", "host": "a"}}} - assert.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") + require.EqualError(t, AddFirewallRulesFromConfig(l, true, conf, mf), "firewall.inbound rule #0; `test error`") } func TestFirewall_convertRule(t *testing.T) { @@ -782,7 +782,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err := convertRule(l, c, "test", 1) assert.Contains(t, ob.String(), "test rule #1; group was an array with a single value, converting to simple value") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) // Ensure group array of > 1 is errord @@ -793,7 +793,7 @@ func TestFirewall_convertRule(t *testing.T) { r, err = convertRule(l, c, "test", 1) assert.Equal(t, "", ob.String()) - assert.Error(t, err, "group should contain a single value, an array with more than one entry was provided") + require.Error(t, err, "group should contain a single value, an array with more than one entry was provided") // Make sure a well formed group is alright ob.Reset() @@ -802,7 +802,7 @@ func TestFirewall_convertRule(t *testing.T) { } r, err = convertRule(l, c, "test", 1) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, "group1", r.Group) } diff --git a/header/header_test.go b/header/header_test.go index 1836a75..a7e5374 100644 --- a/header/header_test.go +++ b/header/header_test.go @@ -5,6 +5,7 @@ import ( "testing" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) type headerTest struct { @@ -111,7 +112,7 @@ func TestHeader_String(t *testing.T) { func TestHeader_MarshalJSON(t *testing.T) { b, err := (&H{100, Test, TestRequest, 99, 98, 97}).MarshalJSON() - assert.NoError(t, err) + require.NoError(t, err) assert.Equal( t, "{\"messageCounter\":97,\"remoteIndex\":98,\"reserved\":99,\"subType\":\"testRequest\",\"type\":\"test\",\"version\":100}", diff --git a/lighthouse_test.go b/lighthouse_test.go index 9e9ad53..3b1295a 100644 --- a/lighthouse_test.go +++ b/lighthouse_test.go @@ -13,6 +13,7 @@ import ( "github.com/slackhq/nebula/header" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "gopkg.in/yaml.v2" ) @@ -21,7 +22,7 @@ func TestOldIPv4Only(t *testing.T) { b := []byte{8, 129, 130, 132, 80, 16, 10} var m V4AddrPort err := m.Unmarshal(b) - assert.NoError(t, err) + require.NoError(t, err) ip := netip.MustParseAddr("10.1.1.1") bp := ip.As4() assert.Equal(t, binary.BigEndian.Uint32(bp[:]), m.GetAddr()) @@ -42,14 +43,14 @@ func Test_lhStaticMapping(t *testing.T) { c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} _, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) lh2 := "10.128.0.3" c = config.NewC(l) c.Settings["lighthouse"] = map[interface{}]interface{}{"hosts": []interface{}{lh1, lh2}} c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"100.1.1.1:4242"}} _, err = NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") + require.EqualError(t, err, "lighthouse 10.128.0.3 does not have a static_host_map entry") } func TestReloadLighthouseInterval(t *testing.T) { @@ -71,19 +72,19 @@ func TestReloadLighthouseInterval(t *testing.T) { c.Settings["static_host_map"] = map[interface{}]interface{}{lh1: []interface{}{"1.1.1.1:4242"}} lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) lh.ifce = &mockEncWriter{} // The first one routine is kicked off by main.go currently, lets make sure that one dies - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 5")) assert.Equal(t, int64(5), lh.interval.Load()) // Subsequent calls are killed off by the LightHouse.Reload function - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 10")) assert.Equal(t, int64(10), lh.interval.Load()) // If this completes then nothing is stealing our reload routine - assert.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) + require.NoError(t, c.ReloadConfigString("lighthouse:\n interval: 11")) assert.Equal(t, int64(11), lh.interval.Load()) } @@ -99,9 +100,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { c := config.NewC(l) lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - if !assert.NoError(b, err) { - b.Fatal() - } + require.NoError(b, err) hAddr := netip.MustParseAddrPort("4.5.6.7:12345") hAddr2 := netip.MustParseAddrPort("4.5.6.7:12346") @@ -145,7 +144,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) } @@ -160,7 +159,7 @@ func BenchmarkLighthouseHandleRequest(b *testing.B) { }, } p, err := req.Marshal() - assert.NoError(b, err) + require.NoError(b, err) for n := 0; n < b.N; n++ { lhh.HandleRequest(rAddr, hi, p, mw) @@ -205,7 +204,7 @@ func TestLighthouse_Memory(t *testing.T) { } lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) lh.ifce = &mockEncWriter{} - assert.NoError(t, err) + require.NoError(t, err) lhh := lh.NewRequestHandler() // Test that my first update responds with just that @@ -290,7 +289,7 @@ func TestLighthouse_reload(t *testing.T) { } lh, err := NewLightHouseFromConfig(context.Background(), l, c, cs, nil, nil) - assert.NoError(t, err) + require.NoError(t, err) nc := map[interface{}]interface{}{ "static_host_map": map[interface{}]interface{}{ @@ -298,11 +297,11 @@ func TestLighthouse_reload(t *testing.T) { }, } rc, err := yaml.Marshal(nc) - assert.NoError(t, err) + require.NoError(t, err) c.ReloadConfigString(string(rc)) err = lh.reload(c, false) - assert.NoError(t, err) + require.NoError(t, err) } func newLHHostRequest(fromAddr netip.AddrPort, myVpnIp, queryVpnIp netip.Addr, lhh *LightHouseHandler) testLhReply { diff --git a/outside_test.go b/outside_test.go index 944bf16..c63e57d 100644 --- a/outside_test.go +++ b/outside_test.go @@ -12,6 +12,7 @@ import ( "github.com/slackhq/nebula/firewall" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" "golang.org/x/net/ipv4" ) @@ -20,13 +21,13 @@ func Test_newPacket(t *testing.T) { // length fails err := newPacket([]byte{}, true, p) - assert.ErrorIs(t, err, ErrPacketTooShort) + require.ErrorIs(t, err, ErrPacketTooShort) err = newPacket([]byte{0x40}, true, p) - assert.ErrorIs(t, err, ErrIPv4PacketTooShort) + require.ErrorIs(t, err, ErrIPv4PacketTooShort) err = newPacket([]byte{0x60}, true, p) - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // length fail with ip options h := ipv4.Header{ @@ -39,15 +40,15 @@ func Test_newPacket(t *testing.T) { b, _ := h.Marshal() err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // not an ipv4 packet err = newPacket([]byte{0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.ErrorIs(t, err, ErrUnknownIPVersion) + require.ErrorIs(t, err, ErrUnknownIPVersion) // invalid ihl err = newPacket([]byte{4<<4 | (8 >> 2 & 0x0f), 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0, 0}, true, p) - assert.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) + require.ErrorIs(t, err, ErrIPv4InvalidHeaderLength) // account for variable ip header length - incoming h = ipv4.Header{ @@ -63,7 +64,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 3, 0, 4}...) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.RemoteAddr) @@ -85,7 +86,7 @@ func Test_newPacket(t *testing.T) { b = append(b, []byte{0, 5, 0, 6}...) err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(2), p.Protocol) assert.Equal(t, netip.MustParseAddr("10.0.0.1"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("10.0.0.2"), p.RemoteAddr) @@ -111,10 +112,10 @@ func Test_newPacket_v6(t *testing.T) { FixLengths: false, } err := gopacket.SerializeLayers(buffer, opt, &ip) - assert.NoError(t, err) + require.NoError(t, err) err = newPacket(buffer.Bytes(), true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A good ICMP packet ip = layers.IPv6{ @@ -134,7 +135,7 @@ func Test_newPacket_v6(t *testing.T) { } err = newPacket(buffer.Bytes(), true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolICMPv6), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -146,7 +147,7 @@ func Test_newPacket_v6(t *testing.T) { b := buffer.Bytes() b[6] = byte(layers.IPProtocolESP) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolESP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -158,7 +159,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = byte(layers.IPProtocolNoNextHeader) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(layers.IPProtocolNoNextHeader), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -170,7 +171,7 @@ func Test_newPacket_v6(t *testing.T) { b = buffer.Bytes() b[6] = 255 // 255 is a reserved protocol number err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) // A good UDP packet ip = layers.IPv6{ @@ -186,7 +187,7 @@ func Test_newPacket_v6(t *testing.T) { DstPort: layers.UDPPort(22), } err = udp.SetNetworkLayerForChecksum(&ip) - assert.NoError(t, err) + require.NoError(t, err) buffer.Clear() err = gopacket.SerializeLayers(buffer, opt, &ip, &udp, gopacket.Payload([]byte{0xde, 0xad, 0xbe, 0xef})) @@ -197,7 +198,7 @@ func Test_newPacket_v6(t *testing.T) { // incoming err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -207,7 +208,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -217,14 +218,14 @@ func Test_newPacket_v6(t *testing.T) { // Too short UDP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good TCP packet b[6] = byte(layers.IPProtocolTCP) // incoming err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -234,7 +235,7 @@ func Test_newPacket_v6(t *testing.T) { // outgoing err = newPacket(b, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoTCP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) @@ -244,7 +245,7 @@ func Test_newPacket_v6(t *testing.T) { // Too short TCP packet err = newPacket(b[:len(b)-10], false, p) // pull off the last 10 bytes - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) // A good UDP packet with an AH header ip = layers.IPv6{ @@ -279,7 +280,7 @@ func Test_newPacket_v6(t *testing.T) { b = append(b, udpHeader...) err = newPacket(b, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, uint8(firewall.ProtoUDP), p.Protocol) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) @@ -290,7 +291,7 @@ func Test_newPacket_v6(t *testing.T) { // Invalid AH header b = buffer.Bytes() err = newPacket(b, true, p) - assert.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) + require.ErrorIs(t, err, ErrIPv6CouldNotFindPayload) } func Test_newPacket_ipv6Fragment(t *testing.T) { @@ -338,7 +339,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test first fragment incoming err = newPacket(firstFrag, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -348,7 +349,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test first fragment outgoing err = newPacket(firstFrag, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -377,7 +378,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test second fragment incoming err = newPacket(secondFrag, true, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.RemoteAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.LocalAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -387,7 +388,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Test second fragment outgoing err = newPacket(secondFrag, false, p) - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, netip.MustParseAddr("ff02::2"), p.LocalAddr) assert.Equal(t, netip.MustParseAddr("ff02::1"), p.RemoteAddr) assert.Equal(t, uint8(layers.IPProtocolUDP), p.Protocol) @@ -397,7 +398,7 @@ func Test_newPacket_ipv6Fragment(t *testing.T) { // Too short of a fragment packet err = newPacket(secondFrag[:len(secondFrag)-10], false, p) - assert.ErrorIs(t, err, ErrIPv6PacketTooShort) + require.ErrorIs(t, err, ErrIPv6PacketTooShort) } func BenchmarkParseV6(b *testing.B) { diff --git a/overlay/route_test.go b/overlay/route_test.go index 4fa30af..8f2c094 100644 --- a/overlay/route_test.go +++ b/overlay/route_test.go @@ -8,84 +8,85 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func Test_parseRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) // test no routes config routes, err := parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"routes": "hi"} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.routes is not an array") + require.EqualError(t, err, "tun.routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{}} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{"asdf"}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.routes is invalid") + require.EqualError(t, err, "entry 1 in tun.routes is invalid") // no mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not present") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not present") // bad mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "499"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.routes is below 500: 499") // missing route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not present") + require.EqualError(t, err, "entry 1.route in tun.routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "nope"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") + require.EqualError(t, err, "entry 1.route in tun.routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // below network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "1.0.0.0/8"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 1.0.0.0/8, networks: [10.0.0.0/24]") // above network range c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "10.0.1.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 10.0.1.0/24, networks: [10.0.0.0/24]") // Not in multiple ranges c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "route": "192.0.0.0/24"}}} routes, err = parseRoutes(c, []netip.Prefix{n, netip.MustParsePrefix("192.1.0.0/24")}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") + require.EqualError(t, err, "entry 1.route in tun.routes is not contained within the configured vpn networks; route: 192.0.0.0/24, networks: [10.0.0.0/24 192.1.0.0/24]") // happy case c.Settings["tun"] = map[interface{}]interface{}{"routes": []interface{}{ @@ -93,7 +94,7 @@ func Test_parseRoutes(t *testing.T) { map[interface{}]interface{}{"mtu": "8000", "route": "10.0.0.1/32"}, }} routes, err = parseRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 2) tested := 0 @@ -119,36 +120,36 @@ func Test_parseUnsafeRoutes(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) // test no routes config routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // not an array c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": "hi"} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "tun.unsafe_routes is not an array") + require.EqualError(t, err, "tun.unsafe_routes is not an array") // no routes c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Empty(t, routes) // weird route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{"asdf"}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") + require.EqualError(t, err, "entry 1 in tun.unsafe_routes is invalid") // no via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes is not present") // invalid via for _, invalidValue := range []interface{}{ @@ -157,44 +158,44 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": invalidValue}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) + require.EqualError(t, err, fmt.Sprintf("entry 1.via in tun.unsafe_routes is not a string: found %T", invalidValue)) } // unparsable via c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"mtu": "500", "via": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") + require.EqualError(t, err, "entry 1.via in tun.unsafe_routes failed to parse address: ParseAddr(\"nope\"): unable to parse IP") // missing route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is not present") // unparsable route c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "500", "route": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes failed to parse: netip.ParsePrefix(\"nope\"): no '/'") // within network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.0.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") + require.EqualError(t, err, "entry 1.route in tun.unsafe_routes is contained within the configured vpn networks; route: 10.0.0.0/24, network: 10.0.0.0/24") // below network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.NoError(t, err) + require.NoError(t, err) // above network range c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "10.0.1.0/24"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Len(t, routes, 1) - assert.NoError(t, err) + require.NoError(t, err) // no mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "route": "1.0.0.0/8"}}} @@ -206,19 +207,19 @@ func Test_parseUnsafeRoutes(t *testing.T) { c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is not an integer: strconv.Atoi: parsing \"nope\": invalid syntax") // low mtu c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "499"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") + require.EqualError(t, err, "entry 1.mtu in tun.unsafe_routes is below 500: 499") // bad install c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "9000", "route": "1.0.0.0/29", "install": "nope"}}} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) assert.Nil(t, routes) - assert.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") + require.EqualError(t, err, "entry 1.install in tun.unsafe_routes is not a boolean: strconv.ParseBool: parsing \"nope\": invalid syntax") // happy case c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ @@ -228,7 +229,7 @@ func Test_parseUnsafeRoutes(t *testing.T) { map[interface{}]interface{}{"via": "127.0.0.1", "mtu": "1500", "metric": 1234, "route": "1.0.0.2/32"}, }} routes, err = parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 4) tested := 0 @@ -260,38 +261,38 @@ func Test_makeRouteTree(t *testing.T) { l := test.NewLogger() c := config.NewC(l) n, err := netip.ParsePrefix("10.0.0.0/24") - assert.NoError(t, err) + require.NoError(t, err) c.Settings["tun"] = map[interface{}]interface{}{"unsafe_routes": []interface{}{ map[interface{}]interface{}{"via": "192.168.0.1", "route": "1.0.0.0/28"}, map[interface{}]interface{}{"via": "192.168.0.2", "route": "1.0.0.1/32"}, }} routes, err := parseUnsafeRoutes(c, []netip.Prefix{n}) - assert.NoError(t, err) + require.NoError(t, err) assert.Len(t, routes, 2) routeTree, err := makeRouteTree(l, routes, true) - assert.NoError(t, err) + require.NoError(t, err) ip, err := netip.ParseAddr("1.0.0.2") - assert.NoError(t, err) + require.NoError(t, err) r, ok := routeTree.Lookup(ip) assert.True(t, ok) nip, err := netip.ParseAddr("192.168.0.1") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, nip, r) ip, err = netip.ParseAddr("1.0.0.1") - assert.NoError(t, err) + require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.True(t, ok) nip, err = netip.ParseAddr("192.168.0.2") - assert.NoError(t, err) + require.NoError(t, err) assert.Equal(t, nip, r) ip, err = netip.ParseAddr("1.1.0.1") - assert.NoError(t, err) + require.NoError(t, err) r, ok = routeTree.Lookup(ip) assert.False(t, ok) } diff --git a/punchy_test.go b/punchy_test.go index 7918449..99d703d 100644 --- a/punchy_test.go +++ b/punchy_test.go @@ -7,6 +7,7 @@ import ( "github.com/slackhq/nebula/config" "github.com/slackhq/nebula/test" "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" ) func TestNewPunchyFromConfig(t *testing.T) { @@ -56,7 +57,7 @@ func TestPunchy_reload(t *testing.T) { l := test.NewLogger() c := config.NewC(l) delay, _ := time.ParseDuration("1m") - assert.NoError(t, c.LoadString(` + require.NoError(t, c.LoadString(` punchy: delay: 1m respond: false @@ -66,7 +67,7 @@ punchy: assert.False(t, p.GetRespond()) newDelay, _ := time.ParseDuration("10m") - assert.NoError(t, c.ReloadConfigString(` + require.NoError(t, c.ReloadConfigString(` punchy: delay: 10m respond: true