diff --git a/CHANGELOG.md b/CHANGELOG.md index ce3e10e772..a57e2ea159 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -173,6 +173,7 @@ This will also affect the way you - Add `dns.extra_records_path` configuration option [#2262](https://github.com/juanfont/headscale/issues/2262) - Support client verify for DERP [#2046](https://github.com/juanfont/headscale/pull/2046) - Add PKCE Verifier for OIDC [#2314](https://github.com/juanfont/headscale/pull/2314) +- Add support for (nextDNS) node attributes [#2329](https://github.com/juanfont/headscale/pull/2329) ## 0.23.0 (2024-09-18) diff --git a/hscontrol/mapper/mapper.go b/hscontrol/mapper/mapper.go index e18276ad6f..b608a28e81 100644 --- a/hscontrol/mapper/mapper.go +++ b/hscontrol/mapper/mapper.go @@ -115,6 +115,7 @@ func generateUserProfiles( func generateDNSConfig( cfg *types.Config, node *types.Node, + nodeAttrs []string, ) *tailcfg.DNSConfig { if cfg.TailcfgDNSConfig == nil { return nil @@ -122,7 +123,7 @@ func generateDNSConfig( dnsConfig := cfg.TailcfgDNSConfig.Clone() - addNextDNSMetadata(dnsConfig.Resolvers, node) + addNextDNSMetadata(dnsConfig.Resolvers, node, nodeAttrs) return dnsConfig } @@ -134,12 +135,27 @@ func generateDNSConfig( // // This will produce a resolver like: // `https://dns.nextdns.io/?device_name=node-name&device_model=linux&device_ip=100.64.0.1` -func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) { +func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node, nodeAttrs []string) { for _, resolver := range resolvers { if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) { + + idx := slices.IndexFunc(nodeAttrs, func(item string) bool { return strings.HasPrefix(item, "nextdns:") && item != "nextdns:no-device-info" }) + + if idx != -1 { + nextDNSProfile := strings.Split(nodeAttrs[idx], ":")[1] + resolver.Addr = fmt.Sprintf("%s/%s", nextDNSDoHPrefix, nextDNSProfile) + } + + if slices.Contains(nodeAttrs, "nextdns:no-device-info") { + continue + } + attrs := url.Values{ - "device_name": []string{node.Hostname}, - "device_model": []string{node.Hostinfo.OS}, + "device_name": []string{node.Hostname}, + } + + if node.Hostinfo != nil { + attrs.Add("device_model", node.Hostinfo.OS) } if len(node.IPs()) > 0 { @@ -158,7 +174,13 @@ func (m *Mapper) fullMapResponse( peers types.Nodes, capVer tailcfg.CapabilityVersion, ) (*tailcfg.MapResponse, error) { - resp, err := m.baseWithConfigMapResponse(node, capVer) + + nodeAttrs, err := m.polMan.NodeAttributes(node) + if err != nil { + return nil, err + } + + resp, err := m.baseWithConfigMapResponse(node, capVer, nodeAttrs) if err != nil { return nil, err } @@ -171,6 +193,7 @@ func (m *Mapper) fullMapResponse( capVer, peers, m.cfg, + nodeAttrs, ) if err != nil { return nil, err @@ -206,7 +229,13 @@ func (m *Mapper) ReadOnlyMapResponse( node *types.Node, messages ...string, ) ([]byte, error) { - resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version) + + nodeAttrs, err := m.polMan.NodeAttributes(node) + if err != nil { + return nil, err + } + + resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version, nodeAttrs) if err != nil { return nil, err } @@ -268,6 +297,11 @@ func (m *Mapper) PeerChangedResponse( } } + nodeAttrs, err := m.polMan.NodeAttributes(node) + if err != nil { + return nil, err + } + err = appendPeerChanges( &resp, false, // partial change @@ -276,6 +310,7 @@ func (m *Mapper) PeerChangedResponse( mapRequest.Version, changedNodes, m.cfg, + nodeAttrs, ) if err != nil { return nil, err @@ -300,7 +335,7 @@ func (m *Mapper) PeerChangedResponse( // Add the node itself, it might have changed, and particularly // if there are no patches or changes, this is a self update. - tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg) + tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg, nodeAttrs) if err != nil { return nil, err } @@ -444,10 +479,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse { func (m *Mapper) baseWithConfigMapResponse( node *types.Node, capVer tailcfg.CapabilityVersion, + nodeAttrs []string, ) (*tailcfg.MapResponse, error) { resp := m.baseMapResponse() - tailnode, err := tailNode(node, capVer, m.polMan, m.cfg) + tailnode, err := tailNode(node, capVer, m.polMan, m.cfg, nodeAttrs) if err != nil { return nil, err } @@ -505,6 +541,7 @@ func appendPeerChanges( capVer tailcfg.CapabilityVersion, changed types.Nodes, cfg *types.Config, + attrs []string, ) error { filter := polMan.Filter() @@ -521,7 +558,7 @@ func appendPeerChanges( profiles := generateUserProfiles(node, changed) - dnsConfig := generateDNSConfig(cfg, node) + dnsConfig := generateDNSConfig(cfg, node, attrs) tailPeers, err := tailNodes(changed, capVer, polMan, cfg) if err != nil { diff --git a/hscontrol/mapper/mapper_test.go b/hscontrol/mapper/mapper_test.go index 55ab2ccbf7..4e502b3ac2 100644 --- a/hscontrol/mapper/mapper_test.go +++ b/hscontrol/mapper/mapper_test.go @@ -120,6 +120,7 @@ func TestDNSConfigMapResponse(t *testing.T) { TailcfgDNSConfig: &dnsConfigOrig, }, nodeInShared1, + []string{}, ) if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" { @@ -129,6 +130,114 @@ func TestDNSConfigMapResponse(t *testing.T) { } } +func TestAddNextDNSMetadata(t *testing.T) { + tests := []struct { + name string + attrs []string + in []*dnstype.Resolver + want []*dnstype.Resolver + }{ + { + name: "With NextDNS resolver, without nodeattrs", + attrs: []string{}, + in: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/abcdef", + }, + }, + want: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/abcdef?device_ip=100.64.0.1&device_model=linux&device_name=testnode", + }, + }, + }, + { + name: "With NextDNS resolver, with nodeattrs [nextdns:fedcba]", + attrs: []string{ + "nextdns:fedcba", + }, + in: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/abcdef", + }, + }, + want: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/fedcba?device_ip=100.64.0.1&device_model=linux&device_name=testnode", + }, + }, + }, + { + name: "With NextDNS resolver, with nodeattrs [nextdns:no-device-info]", + attrs: []string{ + "nextdns:no-device-info", + }, + in: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/abcdef", + }, + }, + want: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/abcdef", + }, + }, + }, + { + name: "With NextDNS resolver, with nodeattrs: [nextdns:fedcba, nextdns:no-device-info]", + attrs: []string{ + "nextdns:fedcba", + "nextdns:no-device-info", + }, + in: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/abcdef", + }, + }, + want: []*dnstype.Resolver{ + { + Addr: "https://dns.nextdns.io/fedcba", + }, + }, + }, + { + name: "No NextDNS resolver, with nodeattrs: [nextdns:fedcba, nextdns:no-device-info]", + attrs: []string{ + "nextdns:fedcba", + "nextdns:no-device-info", + }, + in: []*dnstype.Resolver{ + { + Addr: "1.1.1.1", + }, + }, + want: []*dnstype.Resolver{ + { + Addr: "1.1.1.1", + }, + }, + }, + } + + node := &types.Node{ + Hostname: "testnode", + Hostinfo: &tailcfg.Hostinfo{ + OS: "linux", + }, + IPv4: iap("100.64.0.1"), + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + addNextDNSMetadata(tt.in, node, tt.attrs) + + if diff := cmp.Diff(tt.want, tt.in, cmpopts.EquateEmpty()); diff != "" { + t.Errorf("addNextDNSMetadata() unexpected result (-want +got):\n%s", diff) + } + }) + } +} + func Test_fullMapResponse(t *testing.T) { mustNK := func(str string) key.NodePublic { var k key.NodePublic diff --git a/hscontrol/mapper/tail.go b/hscontrol/mapper/tail.go index 4082df2b45..da48a439bd 100644 --- a/hscontrol/mapper/tail.go +++ b/hscontrol/mapper/tail.go @@ -20,11 +20,18 @@ func tailNodes( tNodes := make([]*tailcfg.Node, len(nodes)) for index, node := range nodes { + + nodeAttrs, err := polMan.NodeAttributes(node) + if err != nil { + return nil, err + } + node, err := tailNode( node, capVer, polMan, cfg, + nodeAttrs, ) if err != nil { return nil, err @@ -42,6 +49,7 @@ func tailNode( capVer tailcfg.CapabilityVersion, polMan policy.PolicyManager, cfg *types.Config, + nodeAttrs []string, ) (*tailcfg.Node, error) { addrs := node.Prefixes() @@ -124,6 +132,10 @@ func tailNode( tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{} } + for _, nodeAttr := range nodeAttrs { + tNode.CapMap[tailcfg.NodeCapability(nodeAttr)] = []tailcfg.RawMessage{} + } + if node.IsOnline == nil || !*node.IsOnline { // LastSeen is only set when node is // not connected to the control server. diff --git a/hscontrol/mapper/tail_test.go b/hscontrol/mapper/tail_test.go index 96c008ab12..633bce26ff 100644 --- a/hscontrol/mapper/tail_test.go +++ b/hscontrol/mapper/tail_test.go @@ -195,6 +195,7 @@ func TestTailNode(t *testing.T) { 0, polMan, cfg, + []string{}, ) if (err != nil) != tt.wantErr { @@ -248,6 +249,7 @@ func TestNodeExpiry(t *testing.T) { 0, &policy.PolicyManagerV1{}, &types.Config{}, + []string{}, ) if err != nil { t.Fatalf("nodeExpiry() error = %v", err) diff --git a/hscontrol/policy/acls.go b/hscontrol/policy/acls.go index 3d7a6f4a0a..c7adff3ef3 100644 --- a/hscontrol/policy/acls.go +++ b/hscontrol/policy/acls.go @@ -442,6 +442,42 @@ func (pol *ACLPolicy) CompileSSHPolicy( }, nil } +func (pol *ACLPolicy) GetAttributesForNode( + node *types.Node, + users []types.User, + peers types.Nodes, +) ([]string, error) { + if pol == nil { + return nil, nil + } + + var attributes []string + + for _, nodeAttr := range pol.NodeAttributes { + var dest netipx.IPSetBuilder + for _, target := range nodeAttr.Targets { + expanded, err := pol.ExpandAlias(append(peers, node), users, target) + if err != nil { + return nil, err + } + dest.AddSet(expanded) + } + + destSet, err := dest.IPSet() + if err != nil { + return nil, err + } + + if !node.InIPSet(destSet) { + continue + } + + attributes = append(attributes, nodeAttr.Attributes...) + } + + return attributes, nil +} + // ipSetAll returns a function that iterates over all the IPs in the IPSet. func ipSetAll(ipSet *netipx.IPSet) iter.Seq[netip.Addr] { return func(yield func(netip.Addr) bool) { diff --git a/hscontrol/policy/acls_test.go b/hscontrol/policy/acls_test.go index ae8898bfd3..737d6fdf7e 100644 --- a/hscontrol/policy/acls_test.go +++ b/hscontrol/policy/acls_test.go @@ -411,6 +411,195 @@ func TestParsing(t *testing.T) { } } +func TestGetAttributesForNode(t *testing.T) { + tests := []struct { + name string + format string + useTag bool + acl string + want []string + wantErr bool + }{ + { + name: "invalid-hujson", + format: "hujson", + useTag: false, + acl: ` +{ + `, + want: []string{}, + wantErr: true, + }, + { + name: "attributes-for-node-using-user", + format: "hujson", + useTag: false, + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "nodeAttrs": [ + { + "target": ["testuser"], + "attr": [ "test" ] + }, + { + "target": ["testuser1"], + "attr": [ "test1" ] + } + ], +} + `, + want: []string{"test"}, + wantErr: false, + }, + { + name: "attributes-for-node-using-wildcard", + format: "hujson", + useTag: false, + acl: ` +{ + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "nodeAttrs": [ + { + "target": ["*"], + "attr": [ "test" ] + } + ], +} + `, + want: []string{"test"}, + wantErr: false, + }, + { + name: "attributes-for-node-using-group", + format: "hujson", + useTag: false, + acl: ` +{ + "groups": { + "group:example": [ + "testuser", + ], + "group:example1": [ + ], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "nodeAttrs": [ + { + "target": ["group:example"], + "attr": [ "test" ] + }, + { + "target": ["group:example1"], + "attr": [ "test" ] + } + ], +} + `, + want: []string{"test"}, + wantErr: false, + }, + { + name: "attributes-for-node-using-tag", + format: "hujson", + useTag: true, + acl: ` +{ + "tagOwners": { + "tag:example": ["testuser"], + }, + + "hosts": { + "host-1": "100.100.100.100", + "subnet-1": "100.100.101.100/24", + }, + + "nodeAttrs": [ + { + "target": ["tag:example"], + "attr": [ "test" ] + } + ], +} + `, + want: []string{"test"}, + wantErr: false, + }, + } + + for _, tt := range tests { + t.Run(tt.name, func(t *testing.T) { + pol, err := LoadACLPolicyFromBytes([]byte(tt.acl)) + + if tt.wantErr && err == nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } else if !tt.wantErr && err != nil { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if err != nil { + return + } + + user := types.User{ + Model: gorm.Model{ID: 1}, + Name: "testuser", + } + + node := types.Node{ + IPv4: iap("100.100.100.100"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + } + + if tt.useTag { + node.Hostinfo.RequestTags = []string{"tag:example"} + } + + rules, err := pol.GetAttributesForNode( + &node, + []types.User{ + user, + }, + types.Nodes{ + &node, + &types.Node{ + IPv4: iap("200.200.200.200"), + User: user, + Hostinfo: &tailcfg.Hostinfo{}, + }, + }) + + if (err != nil) != tt.wantErr { + t.Errorf("parsing() error = %v, wantErr %v", err, tt.wantErr) + + return + } + + if diff := cmp.Diff(tt.want, rules); diff != "" { + t.Errorf("parsing() unexpected result (-want +got):\n%s", diff) + } + }) + } +} + func (s *Suite) TestRuleInvalidGeneration(c *check.C) { acl := []byte(` { diff --git a/hscontrol/policy/acls_types.go b/hscontrol/policy/acls_types.go index 5b5d183829..075d285989 100644 --- a/hscontrol/policy/acls_types.go +++ b/hscontrol/policy/acls_types.go @@ -10,13 +10,14 @@ import ( // ACLPolicy represents a Tailscale ACL Policy. type ACLPolicy struct { - Groups Groups `json:"groups"` - Hosts Hosts `json:"hosts"` - TagOwners TagOwners `json:"tagOwners"` - ACLs []ACL `json:"acls"` - Tests []ACLTest `json:"tests"` - AutoApprovers AutoApprovers `json:"autoApprovers"` - SSHs []SSH `json:"ssh"` + Groups Groups `json:"groups"` + Hosts Hosts `json:"hosts"` + TagOwners TagOwners `json:"tagOwners"` + ACLs []ACL `json:"acls"` + Tests []ACLTest `json:"tests"` + AutoApprovers AutoApprovers `json:"autoApprovers"` + SSHs []SSH `json:"ssh"` + NodeAttributes []NodeAttributes `json:"nodeAttrs"` } // ACL is a basic rule for the ACL Policy. @@ -50,6 +51,12 @@ type AutoApprovers struct { ExitNode []string `json:"exitNode"` } +// NodeAttributes is for applying additional attributes to specific devices +type NodeAttributes struct { + Targets []string `json:"target"` + Attributes []string `json:"attr"` +} + // SSH controls who can ssh into which machines. type SSH struct { Action string `json:"action"` diff --git a/hscontrol/policy/pm.go b/hscontrol/policy/pm.go index 4e10003ed9..b62163ea20 100644 --- a/hscontrol/policy/pm.go +++ b/hscontrol/policy/pm.go @@ -17,6 +17,7 @@ import ( type PolicyManager interface { Filter() []tailcfg.FilterRule SSHPolicy(*types.Node) (*tailcfg.SSHPolicy, error) + NodeAttributes(node *types.Node) ([]string, error) Tags(*types.Node) []string ApproversForRoute(netip.Prefix) []string ExpandAlias(string) (*netipx.IPSet, error) @@ -140,6 +141,13 @@ func (pm *PolicyManagerV1) SetPolicy(polB []byte) (bool, error) { return pm.updateLocked() } +func (pm *PolicyManagerV1) NodeAttributes(node *types.Node) ([]string, error) { + pm.mu.Lock() + defer pm.mu.Unlock() + + return pm.pol.GetAttributesForNode(node, pm.users, pm.nodes) +} + // SetUsers updates the users in the policy manager and updates the filter rules. func (pm *PolicyManagerV1) SetUsers(users []types.User) (bool, error) { pm.mu.Lock() diff --git a/integration/nodeAttrs_test.go b/integration/nodeAttrs_test.go new file mode 100644 index 0000000000..e1b45807bf --- /dev/null +++ b/integration/nodeAttrs_test.go @@ -0,0 +1,148 @@ +package integration + +import ( + "regexp" + "testing" + + "github.com/juanfont/headscale/hscontrol/policy" + "github.com/juanfont/headscale/integration/hsic" + "github.com/juanfont/headscale/integration/tsic" + "github.com/stretchr/testify/require" +) + +func TestNodeAttrsNextDNS(t *testing.T) { + IntegrationSkip(t) + + tests := []struct { + name string + policy policy.ACLPolicy + wantedResolverRegex map[string]string + }{ + { + name: "NextDNS attribute for all", + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + NodeAttributes: []policy.NodeAttributes{ + { + Targets: []string{"*"}, + Attributes: []string{"nextdns:fedcba"}, + }, + }, + }, + wantedResolverRegex: map[string]string{ + "user1": "https://dns\\.nextdns\\.io/fedcba\\?device_ip=.*?\\&device_model=.*?&device_name=.*", + "user2": "https://dns\\.nextdns\\.io/fedcba\\?device_ip=.*?\\&device_model=.*?&device_name=.*", + }, + }, + { + name: "NextDNS attribute for user 1", + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + NodeAttributes: []policy.NodeAttributes{ + { + Targets: []string{"user1"}, + Attributes: []string{"nextdns:fedcba"}, + }, + }, + }, + wantedResolverRegex: map[string]string{ + "user1": "https://dns\\.nextdns\\.io/fedcba\\?device_ip=.*?\\&device_model=.*?&device_name=.*", + "user2": "https://dns\\.nextdns\\.io/abcdef\\?device_ip=.*?\\&device_model=.*?&device_name=.*", + }, + }, + { + name: "NextDNS attribute for no deviceInfo", + policy: policy.ACLPolicy{ + ACLs: []policy.ACL{ + { + Action: "accept", + Sources: []string{"*"}, + Destinations: []string{"*:*"}, + }, + }, + NodeAttributes: []policy.NodeAttributes{ + { + Targets: []string{"*"}, + Attributes: []string{"nextdns:no-device-info"}, + }, + }, + }, + wantedResolverRegex: map[string]string{ + "user1": "https://dns\\.nextdns\\.io/abcdef", + "user2": "https://dns\\.nextdns\\.io/abcdef", + }, + }, + } + + spec := map[string]int{ + "user1": 2, + "user2": 2, + } + + for _, testcase := range tests { + t.Run(testcase.name, func(t *testing.T) { + scenario, err := NewScenario(dockertestMaxWait()) + require.NoError(t, err) + + scenario.CreateHeadscaleEnv(spec, + []tsic.Option{ + tsic.WithSSH(), + + // Alpine containers dont have ip6tables set up, which causes + // tailscaled to stop configuring the wgengine, causing it + // to not configure DNS. + tsic.WithNetfilter("off"), + tsic.WithDockerEntrypoint([]string{ + "/bin/sh", + "-c", + "/bin/sleep 3 ; apk add openssh ; adduser ssh-it-user ; update-ca-certificates ; tailscaled --tun=tsdev", + }), + tsic.WithDockerWorkdir("/"), + }, + hsic.WithACLPolicy(&testcase.policy), + hsic.WithConfigEnv(map[string]string{ + "HEADSCALE_DNS_NAMESERVERS_GLOBAL": "https://dns.nextdns.io/abcdef", + }), + ) + + require.NoError(t, err) + defer scenario.ShutdownAssertNoPanics(t) + + for user, expectedResolver := range testcase.wantedResolverRegex { + + expr, err := regexp.Compile(expectedResolver) + require.NoError(t, err) + + clients, err := scenario.ListTailscaleClients(user) + require.NoError(t, err) + + for _, client := range clients { + + output, _, err := client.Execute([]string{ + "tailscale", + "dns", + "status", + }) + + require.NoError(t, err) + + if !expr.MatchString(output) { + t.Logf("unexpected resolver expected: '%s', actual: '%s'", expectedResolver, output) + } + } + } + }) + } +}