Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

feat: Support for (nextDNS) node attributes (nodeattrs) #2329

Draft
wants to merge 5 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
1 change: 1 addition & 0 deletions CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,7 @@ This will also affect the way you
- Add `dns.extra_records_path` configuration option [#2262](https://github.com/juanfont/headscale/issues/2262)
- Support client verify for DERP [#2046](https://github.com/juanfont/headscale/pull/2046)
- Add PKCE Verifier for OIDC [#2314](https://github.com/juanfont/headscale/pull/2314)
- Add support for (nextDNS) node attributes [#2329](https://github.com/juanfont/headscale/pull/2329)

## 0.23.0 (2024-09-18)

Expand Down
55 changes: 46 additions & 9 deletions hscontrol/mapper/mapper.go
Original file line number Diff line number Diff line change
Expand Up @@ -115,14 +115,15 @@ func generateUserProfiles(
func generateDNSConfig(
cfg *types.Config,
node *types.Node,
nodeAttrs []string,
) *tailcfg.DNSConfig {
if cfg.TailcfgDNSConfig == nil {
return nil
}

dnsConfig := cfg.TailcfgDNSConfig.Clone()

addNextDNSMetadata(dnsConfig.Resolvers, node)
addNextDNSMetadata(dnsConfig.Resolvers, node, nodeAttrs)

return dnsConfig
}
Expand All @@ -134,12 +135,27 @@ func generateDNSConfig(
//
// This will produce a resolver like:
// `https://dns.nextdns.io/<nextdns-id>?device_name=node-name&device_model=linux&device_ip=100.64.0.1`
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node) {
func addNextDNSMetadata(resolvers []*dnstype.Resolver, node *types.Node, nodeAttrs []string) {
for _, resolver := range resolvers {
if strings.HasPrefix(resolver.Addr, nextDNSDoHPrefix) {

idx := slices.IndexFunc(nodeAttrs, func(item string) bool { return strings.HasPrefix(item, "nextdns:") && item != "nextdns:no-device-info" })

if idx != -1 {
nextDNSProfile := strings.Split(nodeAttrs[idx], ":")[1]
resolver.Addr = fmt.Sprintf("%s/%s", nextDNSDoHPrefix, nextDNSProfile)
}

if slices.Contains(nodeAttrs, "nextdns:no-device-info") {
continue
}

attrs := url.Values{
"device_name": []string{node.Hostname},
"device_model": []string{node.Hostinfo.OS},
"device_name": []string{node.Hostname},
}

if node.Hostinfo != nil {
attrs.Add("device_model", node.Hostinfo.OS)
}

if len(node.IPs()) > 0 {
Expand All @@ -158,7 +174,13 @@ func (m *Mapper) fullMapResponse(
peers types.Nodes,
capVer tailcfg.CapabilityVersion,
) (*tailcfg.MapResponse, error) {
resp, err := m.baseWithConfigMapResponse(node, capVer)

nodeAttrs, err := m.polMan.NodeAttributes(node)
if err != nil {
return nil, err
}

resp, err := m.baseWithConfigMapResponse(node, capVer, nodeAttrs)
if err != nil {
return nil, err
}
Expand All @@ -171,6 +193,7 @@ func (m *Mapper) fullMapResponse(
capVer,
peers,
m.cfg,
nodeAttrs,
)
if err != nil {
return nil, err
Expand Down Expand Up @@ -206,7 +229,13 @@ func (m *Mapper) ReadOnlyMapResponse(
node *types.Node,
messages ...string,
) ([]byte, error) {
resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version)

nodeAttrs, err := m.polMan.NodeAttributes(node)
if err != nil {
return nil, err
}

resp, err := m.baseWithConfigMapResponse(node, mapRequest.Version, nodeAttrs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -268,6 +297,11 @@ func (m *Mapper) PeerChangedResponse(
}
}

nodeAttrs, err := m.polMan.NodeAttributes(node)
if err != nil {
return nil, err
}

err = appendPeerChanges(
&resp,
false, // partial change
Expand All @@ -276,6 +310,7 @@ func (m *Mapper) PeerChangedResponse(
mapRequest.Version,
changedNodes,
m.cfg,
nodeAttrs,
)
if err != nil {
return nil, err
Expand All @@ -300,7 +335,7 @@ func (m *Mapper) PeerChangedResponse(

// Add the node itself, it might have changed, and particularly
// if there are no patches or changes, this is a self update.
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg)
tailnode, err := tailNode(node, mapRequest.Version, m.polMan, m.cfg, nodeAttrs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -444,10 +479,11 @@ func (m *Mapper) baseMapResponse() tailcfg.MapResponse {
func (m *Mapper) baseWithConfigMapResponse(
node *types.Node,
capVer tailcfg.CapabilityVersion,
nodeAttrs []string,
) (*tailcfg.MapResponse, error) {
resp := m.baseMapResponse()

tailnode, err := tailNode(node, capVer, m.polMan, m.cfg)
tailnode, err := tailNode(node, capVer, m.polMan, m.cfg, nodeAttrs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -505,6 +541,7 @@ func appendPeerChanges(
capVer tailcfg.CapabilityVersion,
changed types.Nodes,
cfg *types.Config,
attrs []string,
) error {
filter := polMan.Filter()

Expand All @@ -521,7 +558,7 @@ func appendPeerChanges(

profiles := generateUserProfiles(node, changed)

dnsConfig := generateDNSConfig(cfg, node)
dnsConfig := generateDNSConfig(cfg, node, attrs)

tailPeers, err := tailNodes(changed, capVer, polMan, cfg)
if err != nil {
Expand Down
109 changes: 109 additions & 0 deletions hscontrol/mapper/mapper_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -120,6 +120,7 @@ func TestDNSConfigMapResponse(t *testing.T) {
TailcfgDNSConfig: &dnsConfigOrig,
},
nodeInShared1,
[]string{},
)

if diff := cmp.Diff(tt.want, got, cmpopts.EquateEmpty()); diff != "" {
Expand All @@ -129,6 +130,114 @@ func TestDNSConfigMapResponse(t *testing.T) {
}
}

func TestAddNextDNSMetadata(t *testing.T) {
tests := []struct {
name string
attrs []string
in []*dnstype.Resolver
want []*dnstype.Resolver
}{
{
name: "With NextDNS resolver, without nodeattrs",
attrs: []string{},
in: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/abcdef",
},
},
want: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/abcdef?device_ip=100.64.0.1&device_model=linux&device_name=testnode",
},
},
},
{
name: "With NextDNS resolver, with nodeattrs [nextdns:fedcba]",
attrs: []string{
"nextdns:fedcba",
},
in: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/abcdef",
},
},
want: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/fedcba?device_ip=100.64.0.1&device_model=linux&device_name=testnode",
},
},
},
{
name: "With NextDNS resolver, with nodeattrs [nextdns:no-device-info]",
attrs: []string{
"nextdns:no-device-info",
},
in: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/abcdef",
},
},
want: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/abcdef",
},
},
},
{
name: "With NextDNS resolver, with nodeattrs: [nextdns:fedcba, nextdns:no-device-info]",
attrs: []string{
"nextdns:fedcba",
"nextdns:no-device-info",
},
in: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/abcdef",
},
},
want: []*dnstype.Resolver{
{
Addr: "https://dns.nextdns.io/fedcba",
},
},
},
{
name: "No NextDNS resolver, with nodeattrs: [nextdns:fedcba, nextdns:no-device-info]",
attrs: []string{
"nextdns:fedcba",
"nextdns:no-device-info",
},
in: []*dnstype.Resolver{
{
Addr: "1.1.1.1",
},
},
want: []*dnstype.Resolver{
{
Addr: "1.1.1.1",
},
},
},
}

node := &types.Node{
Hostname: "testnode",
Hostinfo: &tailcfg.Hostinfo{
OS: "linux",
},
IPv4: iap("100.64.0.1"),
}

for _, tt := range tests {
t.Run(tt.name, func(t *testing.T) {
addNextDNSMetadata(tt.in, node, tt.attrs)

if diff := cmp.Diff(tt.want, tt.in, cmpopts.EquateEmpty()); diff != "" {
t.Errorf("addNextDNSMetadata() unexpected result (-want +got):\n%s", diff)
}
})
}
}

func Test_fullMapResponse(t *testing.T) {
mustNK := func(str string) key.NodePublic {
var k key.NodePublic
Expand Down
12 changes: 12 additions & 0 deletions hscontrol/mapper/tail.go
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,18 @@ func tailNodes(
tNodes := make([]*tailcfg.Node, len(nodes))

for index, node := range nodes {

nodeAttrs, err := polMan.NodeAttributes(node)
if err != nil {
return nil, err
}

node, err := tailNode(
node,
capVer,
polMan,
cfg,
nodeAttrs,
)
if err != nil {
return nil, err
Expand All @@ -42,6 +49,7 @@ func tailNode(
capVer tailcfg.CapabilityVersion,
polMan policy.PolicyManager,
cfg *types.Config,
nodeAttrs []string,
) (*tailcfg.Node, error) {
addrs := node.Prefixes()

Expand Down Expand Up @@ -124,6 +132,10 @@ func tailNode(
tNode.CapMap[tailcfg.NodeAttrRandomizeClientPort] = []tailcfg.RawMessage{}
}

for _, nodeAttr := range nodeAttrs {
tNode.CapMap[tailcfg.NodeCapability(nodeAttr)] = []tailcfg.RawMessage{}
}

if node.IsOnline == nil || !*node.IsOnline {
// LastSeen is only set when node is
// not connected to the control server.
Expand Down
2 changes: 2 additions & 0 deletions hscontrol/mapper/tail_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -195,6 +195,7 @@ func TestTailNode(t *testing.T) {
0,
polMan,
cfg,
[]string{},
)

if (err != nil) != tt.wantErr {
Expand Down Expand Up @@ -248,6 +249,7 @@ func TestNodeExpiry(t *testing.T) {
0,
&policy.PolicyManagerV1{},
&types.Config{},
[]string{},
)
if err != nil {
t.Fatalf("nodeExpiry() error = %v", err)
Expand Down
36 changes: 36 additions & 0 deletions hscontrol/policy/acls.go
Original file line number Diff line number Diff line change
Expand Up @@ -442,6 +442,42 @@ func (pol *ACLPolicy) CompileSSHPolicy(
}, nil
}

func (pol *ACLPolicy) GetAttributesForNode(
node *types.Node,
users []types.User,
peers types.Nodes,
) ([]string, error) {
if pol == nil {
return nil, nil
}

var attributes []string

for _, nodeAttr := range pol.NodeAttributes {
var dest netipx.IPSetBuilder
for _, target := range nodeAttr.Targets {
expanded, err := pol.ExpandAlias(append(peers, node), users, target)
if err != nil {
return nil, err
}
dest.AddSet(expanded)
}

destSet, err := dest.IPSet()
if err != nil {
return nil, err
}

if !node.InIPSet(destSet) {
continue
}

attributes = append(attributes, nodeAttr.Attributes...)
}

return attributes, nil
}

// ipSetAll returns a function that iterates over all the IPs in the IPSet.
func ipSetAll(ipSet *netipx.IPSet) iter.Seq[netip.Addr] {
return func(yield func(netip.Addr) bool) {
Expand Down
Loading