package policy

import (
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/stretchr/testify/assert"
	"github.com/stretchr/testify/require"
	"gorm.io/gorm"
	"tailscale.com/tailcfg"
)

func TestPolicySetChange(t *testing.T) {
	users := []types.User{
		{
			Model: gorm.Model{ID: 1},
			Name:  "testuser",
		},
	}
	tests := []struct {
		name             string
		users            []types.User
		nodes            types.Nodes
		policy           []byte
		wantUsersChange  bool
		wantNodesChange  bool
		wantPolicyChange bool
		wantFilter       []tailcfg.FilterRule
	}{
		{
			name: "set-nodes",
			nodes: types.Nodes{
				{
					IPv4: iap("100.64.0.2"),
					User: users[0],
				},
			},
			wantNodesChange: false,
			wantFilter: []tailcfg.FilterRule{
				{
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
				},
			},
		},
		{
			name:            "set-users",
			users:           users,
			wantUsersChange: false,
			wantFilter: []tailcfg.FilterRule{
				{
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
				},
			},
		},
		{
			name:  "set-users-and-node",
			users: users,
			nodes: types.Nodes{
				{
					IPv4: iap("100.64.0.2"),
					User: users[0],
				},
			},
			wantUsersChange: false,
			wantNodesChange: true,
			wantFilter: []tailcfg.FilterRule{
				{
					SrcIPs:   []string{"100.64.0.2/32"},
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.1/32", Ports: tailcfg.PortRangeAny}},
				},
			},
		},
		{
			name: "set-policy",
			policy: []byte(`
{
"acls": [
		{
			"action": "accept",
			"src": [
				"100.64.0.61",
			],
			"dst": [
				"100.64.0.62:*",
			],
		},
		],
}
				`),
			wantPolicyChange: true,
			wantFilter: []tailcfg.FilterRule{
				{
					SrcIPs:   []string{"100.64.0.61/32"},
					DstPorts: []tailcfg.NetPortRange{{IP: "100.64.0.62/32", Ports: tailcfg.PortRangeAny}},
				},
			},
		},
	}

	for _, tt := range tests {
		t.Run(tt.name, func(t *testing.T) {
			pol := `
{
	"groups": {
		"group:example": [
			"testuser",
		],
	},

	"hosts": {
		"host-1": "100.64.0.1",
		"subnet-1": "100.100.101.100/24",
	},

	"acls": [
		{
			"action": "accept",
			"src": [
				"group:example",
			],
			"dst": [
				"host-1:*",
			],
		},
	],
}
`
			pm, err := NewPolicyManager([]byte(pol), []types.User{}, types.Nodes{})
			require.NoError(t, err)

			if tt.policy != nil {
				change, err := pm.SetPolicy(tt.policy)
				require.NoError(t, err)

				assert.Equal(t, tt.wantPolicyChange, change)
			}

			if tt.users != nil {
				change, err := pm.SetUsers(tt.users)
				require.NoError(t, err)

				assert.Equal(t, tt.wantUsersChange, change)
			}

			if tt.nodes != nil {
				change, err := pm.SetNodes(tt.nodes)
				require.NoError(t, err)

				assert.Equal(t, tt.wantNodesChange, change)
			}

			if diff := cmp.Diff(tt.wantFilter, pm.Filter()); diff != "" {
				t.Errorf("TestPolicySetChange() unexpected result (-want +got):\n%s", diff)
			}
		})
	}
}