package db

import (
	"net/netip"
	"testing"

	"github.com/google/go-cmp/cmp"
	"github.com/juanfont/headscale/hscontrol/policy"
	"github.com/juanfont/headscale/hscontrol/types"
	"github.com/juanfont/headscale/hscontrol/util"
	"github.com/stretchr/testify/assert"
	"gopkg.in/check.v1"
	"tailscale.com/envknob"
	"tailscale.com/tailcfg"
)

// TODO(kradalby):
// Convert these tests to being non-database dependent and table driven. They are
// very verbose, and dont really need the database.

func (s *Suite) TestSshRules(c *check.C) {
	envknob.Setenv("HEADSCALE_EXPERIMENTAL_FEATURE_SSH", "1")

	user, err := db.CreateUser("user1")
	c.Assert(err, check.IsNil)

	pak, err := db.CreatePreAuthKey(user.Name, false, false, nil, nil)
	c.Assert(err, check.IsNil)

	_, err = db.GetMachine("user1", "testmachine")
	c.Assert(err, check.NotNil)
	hostInfo := tailcfg.Hostinfo{
		OS:          "centos",
		Hostname:    "testmachine",
		RequestTags: []string{"tag:test"},
	}

	machine := types.Machine{
		ID:             0,
		MachineKey:     "foo",
		NodeKey:        "bar",
		DiscoKey:       "faa",
		Hostname:       "testmachine",
		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
		UserID:         user.ID,
		RegisterMethod: util.RegisterMethodAuthKey,
		AuthKeyID:      uint(pak.ID),
		HostInfo:       types.HostInfo(hostInfo),
	}
	err = db.MachineSave(&machine)
	c.Assert(err, check.IsNil)

	aclPolicy := &policy.ACLPolicy{
		Groups: policy.Groups{
			"group:test": []string{"user1"},
		},
		Hosts: policy.Hosts{
			"client": netip.PrefixFrom(netip.MustParseAddr("100.64.99.42"), 32),
		},
		ACLs: []policy.ACL{
			{
				Action:       "accept",
				Sources:      []string{"*"},
				Destinations: []string{"*:*"},
			},
		},
		SSHs: []policy.SSH{
			{
				Action:       "accept",
				Sources:      []string{"group:test"},
				Destinations: []string{"client"},
				Users:        []string{"autogroup:nonroot"},
			},
			{
				Action:       "accept",
				Sources:      []string{"*"},
				Destinations: []string{"client"},
				Users:        []string{"autogroup:nonroot"},
			},
		},
	}

	_, sshPolicy, err := policy.GenerateFilterRules(aclPolicy, &machine, types.Machines{}, false)

	c.Assert(err, check.IsNil)
	c.Assert(sshPolicy, check.NotNil)
	c.Assert(sshPolicy.Rules, check.HasLen, 2)
	c.Assert(sshPolicy.Rules[0].SSHUsers, check.HasLen, 1)
	c.Assert(sshPolicy.Rules[0].Principals, check.HasLen, 1)
	c.Assert(sshPolicy.Rules[0].Principals[0].UserLogin, check.Matches, "user1")

	c.Assert(sshPolicy.Rules[1].SSHUsers, check.HasLen, 1)
	c.Assert(sshPolicy.Rules[1].Principals, check.HasLen, 1)
	c.Assert(sshPolicy.Rules[1].Principals[0].NodeIP, check.Matches, "*")
}

// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Sources section.
func TestValidExpandTagOwnersInSources(t *testing.T) {
	hostInfo := tailcfg.Hostinfo{
		OS:          "centos",
		Hostname:    "testmachine",
		RequestTags: []string{"tag:test"},
	}

	machine := types.Machine{
		ID:          0,
		MachineKey:  "foo",
		NodeKey:     "bar",
		DiscoKey:    "faa",
		Hostname:    "testmachine",
		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
		UserID:      0,
		User: types.User{
			Name: "user1",
		},
		RegisterMethod: util.RegisterMethodAuthKey,
		HostInfo:       types.HostInfo(hostInfo),
	}

	pol := &policy.ACLPolicy{
		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}},
		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
		ACLs: []policy.ACL{
			{
				Action:       "accept",
				Sources:      []string{"tag:test"},
				Destinations: []string{"*:*"},
			},
		},
	}

	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
	assert.NoError(t, err)

	want := []tailcfg.FilterRule{
		{
			SrcIPs: []string{"100.64.0.1/32"},
			DstPorts: []tailcfg.NetPortRange{
				{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
				{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
			},
		},
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("TestValidExpandTagOwnersInSources() unexpected result (-want +got):\n%s", diff)
	}
}

// need a test with:
// tag on a host that isn't owned by a tag owners. So the user
// of the host should be valid.
func TestInvalidTagValidUser(t *testing.T) {
	hostInfo := tailcfg.Hostinfo{
		OS:          "centos",
		Hostname:    "testmachine",
		RequestTags: []string{"tag:foo"},
	}

	machine := types.Machine{
		ID:          1,
		MachineKey:  "12345",
		NodeKey:     "bar",
		DiscoKey:    "faa",
		Hostname:    "testmachine",
		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
		UserID:      1,
		User: types.User{
			Name: "user1",
		},
		RegisterMethod: util.RegisterMethodAuthKey,
		HostInfo:       types.HostInfo(hostInfo),
	}

	pol := &policy.ACLPolicy{
		TagOwners: policy.TagOwners{"tag:test": []string{"user1"}},
		ACLs: []policy.ACL{
			{
				Action:       "accept",
				Sources:      []string{"user1"},
				Destinations: []string{"*:*"},
			},
		},
	}

	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
	assert.NoError(t, err)

	want := []tailcfg.FilterRule{
		{
			SrcIPs: []string{"100.64.0.1/32"},
			DstPorts: []tailcfg.NetPortRange{
				{IP: "0.0.0.0/0", Ports: tailcfg.PortRange{Last: 65535}},
				{IP: "::/0", Ports: tailcfg.PortRange{Last: 65535}},
			},
		},
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("TestInvalidTagValidUser() unexpected result (-want +got):\n%s", diff)
	}
}

func TestPortGroup(t *testing.T) {
	machine := types.Machine{
		ID:         0,
		MachineKey: "foo",
		NodeKey:    "bar",
		DiscoKey:   "faa",
		Hostname:   "testmachine",
		UserID:     0,
		User: types.User{
			Name: "testuser",
		},
		RegisterMethod: util.RegisterMethodAuthKey,
		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.5")},
	}

	acl := []byte(`
{
	"groups": {
		"group:example": [
			"testuser",
		],
	},

	"hosts": {
		"host-1": "100.100.100.100",
		"subnet-1": "100.100.101.100/24",
	},

	"acls": [
		{
			"action": "accept",
			"src": [
				"group:example",
			],
			"dst": [
				"host-1:*",
			],
		},
	],
}
	`)
	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
	assert.NoError(t, err)

	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
	assert.NoError(t, err)

	want := []tailcfg.FilterRule{
		{
			SrcIPs: []string{"100.64.0.5/32"},
			DstPorts: []tailcfg.NetPortRange{
				{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
			},
		},
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("TestPortGroup() unexpected result (-want +got):\n%s", diff)
	}
}

func TestPortUser(t *testing.T) {
	machine := types.Machine{
		ID:         0,
		MachineKey: "12345",
		NodeKey:    "bar",
		DiscoKey:   "faa",
		Hostname:   "testmachine",
		UserID:     0,
		User: types.User{
			Name: "testuser",
		},
		RegisterMethod: util.RegisterMethodAuthKey,
		IPAddresses:    types.MachineAddresses{netip.MustParseAddr("100.64.0.9")},
	}

	acl := []byte(`
{
	"hosts": {
		"host-1": "100.100.100.100",
		"subnet-1": "100.100.101.100/24",
	},

	"acls": [
		{
			"action": "accept",
			"src": [
				"testuser",
			],
			"dst": [
				"host-1:*",
			],
		},
	],
}
	`)
	pol, err := policy.LoadACLPolicyFromBytes(acl, "hujson")
	assert.NoError(t, err)

	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
	assert.NoError(t, err)

	want := []tailcfg.FilterRule{
		{
			SrcIPs: []string{"100.64.0.9/32"},
			DstPorts: []tailcfg.NetPortRange{
				{IP: "100.100.100.100/32", Ports: tailcfg.PortRange{Last: 65535}},
			},
		},
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("TestPortUser() unexpected result (-want +got):\n%s", diff)
	}
}

// this test should validate that we can expand a group in a TagOWner section and
// match properly the IP's of the related hosts. The owner is valid and the tag is also valid.
// the tag is matched in the Destinations section.
func TestValidExpandTagOwnersInDestinations(t *testing.T) {
	hostInfo := tailcfg.Hostinfo{
		OS:          "centos",
		Hostname:    "testmachine",
		RequestTags: []string{"tag:test"},
	}

	machine := types.Machine{
		ID:          1,
		MachineKey:  "12345",
		NodeKey:     "bar",
		DiscoKey:    "faa",
		Hostname:    "testmachine",
		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
		UserID:      1,
		User: types.User{
			Name: "user1",
		},
		RegisterMethod: util.RegisterMethodAuthKey,
		HostInfo:       types.HostInfo(hostInfo),
	}

	pol := &policy.ACLPolicy{
		Groups:    policy.Groups{"group:test": []string{"user1", "user2"}},
		TagOwners: policy.TagOwners{"tag:test": []string{"user3", "group:test"}},
		ACLs: []policy.ACL{
			{
				Action:       "accept",
				Sources:      []string{"*"},
				Destinations: []string{"tag:test:*"},
			},
		},
	}

	// rules, _, err := policy.GenerateFilterRules(pol, &machine, peers, false)
	// c.Assert(err, check.IsNil)
	//
	// c.Assert(rules, check.HasLen, 1)
	// c.Assert(rules[0].DstPorts, check.HasLen, 1)
	// c.Assert(rules[0].DstPorts[0].IP, check.Equals, "100.64.0.1/32")

	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{}, false)
	assert.NoError(t, err)

	want := []tailcfg.FilterRule{
		{
			SrcIPs: []string{"0.0.0.0/0", "::/0"},
			DstPorts: []tailcfg.NetPortRange{
				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{Last: 65535}},
			},
		},
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("TestValidExpandTagOwnersInDestinations() unexpected result (-want +got):\n%s", diff)
	}
}

// tag on a host is owned by a tag owner, the tag is valid.
// an ACL rule is matching the tag to a user. It should not be valid since the
// host should be tied to the tag now.
func TestValidTagInvalidUser(t *testing.T) {
	hostInfo := tailcfg.Hostinfo{
		OS:          "centos",
		Hostname:    "webserver",
		RequestTags: []string{"tag:webapp"},
	}

	machine := types.Machine{
		ID:          1,
		MachineKey:  "12345",
		NodeKey:     "bar",
		DiscoKey:    "faa",
		Hostname:    "webserver",
		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.1")},
		UserID:      1,
		User: types.User{
			Name: "user1",
		},
		RegisterMethod: util.RegisterMethodAuthKey,
		HostInfo:       types.HostInfo(hostInfo),
	}

	hostInfo2 := tailcfg.Hostinfo{
		OS:       "debian",
		Hostname: "Hostname",
	}

	machine2 := types.Machine{
		ID:          2,
		MachineKey:  "56789",
		NodeKey:     "bar2",
		DiscoKey:    "faab",
		Hostname:    "user",
		IPAddresses: types.MachineAddresses{netip.MustParseAddr("100.64.0.2")},
		UserID:      1,
		User: types.User{
			Name: "user1",
		},
		RegisterMethod: util.RegisterMethodAuthKey,
		HostInfo:       types.HostInfo(hostInfo2),
	}

	pol := &policy.ACLPolicy{
		TagOwners: policy.TagOwners{"tag:webapp": []string{"user1"}},
		ACLs: []policy.ACL{
			{
				Action:       "accept",
				Sources:      []string{"user1"},
				Destinations: []string{"tag:webapp:80,443"},
			},
		},
	}

	got, _, err := policy.GenerateFilterRules(pol, &machine, types.Machines{machine2}, false)
	assert.NoError(t, err)

	want := []tailcfg.FilterRule{
		{
			SrcIPs: []string{"100.64.0.2/32"},
			DstPorts: []tailcfg.NetPortRange{
				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 80, Last: 80}},
				{IP: "100.64.0.1/32", Ports: tailcfg.PortRange{First: 443, Last: 443}},
			},
		},
	}

	if diff := cmp.Diff(want, got); diff != "" {
		t.Errorf("TestValidTagInvalidUser() unexpected result (-want +got):\n%s", diff)
	}
}