Skip to content

Commit 714fc0d

Browse files
committed
refactor(firewall): extract shared code from enable/disable commands
Extract common filter flag definitions, matching logic, and rule modification code into rule_modify_shared.go to eliminate duplication between enable and disable commands. Benefits: - Reduces code from ~200 lines each to ~45 lines per command - Single source of truth for filter logic and validation - Easier to maintain and extend with new features - All tests pass, no behavior changes The shared functions: - addRuleFilterFlags: Defines all filter flags - configureRuleFilterFlagsPostAdd: Sets up mutual exclusivity and completion - findMatchingRules: Filters rules based on criteria - modifyFirewallRules: Main modification logic with confirmation
1 parent 2058d84 commit 714fc0d

3 files changed

Lines changed: 224 additions & 330 deletions

File tree

Lines changed: 11 additions & 165 deletions
Original file line numberDiff line numberDiff line change
@@ -1,28 +1,17 @@
11
package serverfirewall
22

33
import (
4-
"fmt"
5-
"strings"
6-
74
"github.com/UpCloudLtd/upcloud-cli/v3/internal/commands"
85
"github.com/UpCloudLtd/upcloud-cli/v3/internal/completion"
96
"github.com/UpCloudLtd/upcloud-cli/v3/internal/output"
107
"github.com/UpCloudLtd/upcloud-cli/v3/internal/resolver"
118
"github.com/UpCloudLtd/upcloud-go-api/v8/upcloud"
12-
"github.com/UpCloudLtd/upcloud-go-api/v8/upcloud/request"
13-
"github.com/spf13/cobra"
149
"github.com/spf13/pflag"
1510
)
1611

1712
type ruleDisableCommand struct {
1813
*commands.BaseCommand
19-
rulePosition int
20-
ruleComment string
21-
ruleDirection string
22-
ruleProtocol string
23-
ruleDestPort string
24-
ruleSrcAddress string
25-
skipConfirmation int
14+
params ruleModifyParams
2615
completion.Server
2716
resolver.CachingServer
2817
}
@@ -33,169 +22,26 @@ func RuleDisableCommand() commands.Command {
3322
BaseCommand: commands.New(
3423
"disable",
3524
"Disable firewall rules by changing their action to drop",
36-
"upctl server firewall rule disable 00038afc-d526-4148-af0e-d2f1eeaded9b --comment \"Dev ports\"",
37-
"upctl server firewall rule disable 00038afc-d526-4148-af0e-d2f1eeaded9b --direction in --protocol tcp --dest-port 8080",
38-
"upctl server firewall rule disable 00038afc-d526-4148-af0e-d2f1eeaded9b --comment \"Test\" --direction in --skip-confirmation 10",
39-
"upctl server firewall rule disable 00038afc-d526-4148-af0e-d2f1eeaded9b --position 5",
25+
"upctl server firewall rule disable myserver --dest-port 80",
26+
"upctl server firewall rule disable myserver --comment \"Dev ports\"",
27+
"upctl server firewall rule disable myserver --direction out --protocol udp --dest-port 53",
28+
"upctl server firewall rule disable myserver --position 5",
4029
),
41-
skipConfirmation: 1,
30+
params: ruleModifyParams{
31+
skipConfirmation: 1,
32+
},
4233
}
4334
}
4435

4536
// InitCommand implements Command.InitCommand
4637
func (s *ruleDisableCommand) InitCommand() {
47-
directions := []string{upcloud.FirewallRuleDirectionIn, upcloud.FirewallRuleDirectionOut}
48-
protocols := []string{upcloud.FirewallRuleProtocolTCP, upcloud.FirewallRuleProtocolUDP, upcloud.FirewallRuleProtocolICMP}
49-
5038
flagSet := &pflag.FlagSet{}
51-
flagSet.IntVar(&s.rulePosition, "position", 0, "Rule position. Available: 1-1000")
52-
flagSet.StringVar(&s.ruleComment, "comment", "", "Filter by comment (partial match, case-insensitive)")
53-
flagSet.StringVar(&s.ruleDirection, "direction", "", "Filter by direction. Available: "+strings.Join(directions, ", "))
54-
flagSet.StringVar(&s.ruleProtocol, "protocol", "", "Filter by protocol. Available: "+strings.Join(protocols, ", "))
55-
flagSet.StringVar(&s.ruleDestPort, "dest-port", "", "Filter by destination port (matches both start and end)")
56-
flagSet.StringVar(&s.ruleSrcAddress, "src-address", "", "Filter by source address (partial match)")
57-
flagSet.IntVar(&s.skipConfirmation, "skip-confirmation", 1, "Maximum rules to modify without confirmation. Use 0 to always require confirmation, even for a single rule.")
39+
addRuleFilterFlags(flagSet, &s.params, s.Cobra())
5840
s.AddFlags(flagSet)
59-
60-
s.Cobra().MarkFlagsMutuallyExclusive("position", "comment")
61-
s.Cobra().MarkFlagsMutuallyExclusive("position", "direction")
62-
s.Cobra().MarkFlagsMutuallyExclusive("position", "protocol")
63-
s.Cobra().MarkFlagsMutuallyExclusive("position", "dest-port")
64-
s.Cobra().MarkFlagsMutuallyExclusive("position", "src-address")
65-
commands.Must(s.Cobra().RegisterFlagCompletionFunc("position", cobra.NoFileCompletions))
66-
commands.Must(s.Cobra().RegisterFlagCompletionFunc("comment", cobra.NoFileCompletions))
67-
commands.Must(s.Cobra().RegisterFlagCompletionFunc("direction", cobra.FixedCompletions(directions, cobra.ShellCompDirectiveNoFileComp)))
68-
commands.Must(s.Cobra().RegisterFlagCompletionFunc("protocol", cobra.FixedCompletions(protocols, cobra.ShellCompDirectiveNoFileComp)))
69-
commands.Must(s.Cobra().RegisterFlagCompletionFunc("dest-port", cobra.NoFileCompletions))
70-
commands.Must(s.Cobra().RegisterFlagCompletionFunc("src-address", cobra.NoFileCompletions))
71-
commands.Must(s.Cobra().RegisterFlagCompletionFunc("skip-confirmation", cobra.NoFileCompletions))
41+
configureRuleFilterFlagsPostAdd(s.Cobra())
7242
}
7343

7444
// Execute implements commands.MultipleArgumentCommand
7545
func (s *ruleDisableCommand) Execute(exec commands.Executor, arg string) (output.Output, error) {
76-
// Validation
77-
hasFilters := s.rulePosition != 0 || s.ruleComment != "" || s.ruleDirection != "" ||
78-
s.ruleProtocol != "" || s.ruleDestPort != "" || s.ruleSrcAddress != ""
79-
if !hasFilters {
80-
return nil, fmt.Errorf("at least one filter must be specified (--comment, --direction, --protocol, --dest-port, --src-address, or --position)")
81-
}
82-
if s.rulePosition != 0 && (s.rulePosition < 1 || s.rulePosition > 1000) {
83-
return nil, fmt.Errorf("invalid position (1-1000 allowed)")
84-
}
85-
86-
msg := fmt.Sprintf("Disabling firewall rules on server %v", arg)
87-
exec.PushProgressStarted(msg)
88-
89-
// Fetch current firewall rules
90-
currentRules, err := exec.Firewall().GetFirewallRules(exec.Context(), &request.GetFirewallRulesRequest{
91-
ServerUUID: arg,
92-
})
93-
if err != nil {
94-
return commands.HandleError(exec, msg, err)
95-
}
96-
97-
// Find matching rules
98-
var matchedIndices []int
99-
for i := range currentRules.FirewallRules {
100-
rule := &currentRules.FirewallRules[i]
101-
102-
// Position-based filter (exact match, exclusive)
103-
if s.rulePosition != 0 {
104-
if rule.Position == s.rulePosition {
105-
matchedIndices = append(matchedIndices, i)
106-
}
107-
continue
108-
}
109-
110-
// Apply all specified filters (AND logic)
111-
match := true
112-
113-
if s.ruleComment != "" {
114-
if !strings.Contains(strings.ToLower(rule.Comment), strings.ToLower(s.ruleComment)) {
115-
match = false
116-
}
117-
}
118-
119-
if s.ruleDirection != "" {
120-
if !strings.EqualFold(rule.Direction, s.ruleDirection) {
121-
match = false
122-
}
123-
}
124-
125-
if s.ruleProtocol != "" {
126-
if !strings.EqualFold(rule.Protocol, s.ruleProtocol) {
127-
match = false
128-
}
129-
}
130-
131-
if s.ruleDestPort != "" {
132-
// Match if either start or end matches the specified port
133-
if rule.DestinationPortStart != s.ruleDestPort && rule.DestinationPortEnd != s.ruleDestPort {
134-
match = false
135-
}
136-
}
137-
138-
if s.ruleSrcAddress != "" {
139-
// Partial match on either start or end address
140-
addrLower := strings.ToLower(s.ruleSrcAddress)
141-
if !strings.Contains(strings.ToLower(rule.SourceAddressStart), addrLower) &&
142-
!strings.Contains(strings.ToLower(rule.SourceAddressEnd), addrLower) {
143-
match = false
144-
}
145-
}
146-
147-
if match {
148-
matchedIndices = append(matchedIndices, i)
149-
}
150-
}
151-
152-
if len(matchedIndices) == 0 {
153-
if s.rulePosition != 0 {
154-
return nil, fmt.Errorf("firewall rule at position %d not found on server %s", s.rulePosition, arg)
155-
}
156-
return nil, fmt.Errorf("no firewall rules matching the specified filters found on server %s", arg)
157-
}
158-
159-
// Confirmation check
160-
if len(matchedIndices) > s.skipConfirmation {
161-
var ruleDescriptions []string
162-
for _, idx := range matchedIndices {
163-
rule := currentRules.FirewallRules[idx]
164-
desc := fmt.Sprintf(" - Position %d: %s %s", rule.Position, rule.Direction, rule.Protocol)
165-
if rule.Comment != "" {
166-
desc += fmt.Sprintf(" (comment: %q)", rule.Comment)
167-
}
168-
ruleDescriptions = append(ruleDescriptions, desc)
169-
}
170-
171-
return nil, fmt.Errorf("would disable %d rules (exceeds skip-confirmation=%d). Matching rules:\n%s\n\nIncrease --skip-confirmation to proceed",
172-
len(matchedIndices), s.skipConfirmation, strings.Join(ruleDescriptions, "\n"))
173-
}
174-
175-
// Modify matched rules
176-
modifiedCount := 0
177-
for _, idx := range matchedIndices {
178-
if currentRules.FirewallRules[idx].Action != upcloud.FirewallRuleActionDrop {
179-
currentRules.FirewallRules[idx].Action = upcloud.FirewallRuleActionDrop
180-
modifiedCount++
181-
}
182-
}
183-
184-
if modifiedCount == 0 {
185-
return nil, fmt.Errorf("all %d matching rules already disabled", len(matchedIndices))
186-
}
187-
188-
// Replace entire ruleset atomically
189-
err = exec.Firewall().CreateFirewallRules(exec.Context(), &request.CreateFirewallRulesRequest{
190-
ServerUUID: arg,
191-
FirewallRules: currentRules.FirewallRules,
192-
})
193-
if err != nil {
194-
return commands.HandleError(exec, msg, err)
195-
}
196-
197-
msg = fmt.Sprintf("Disabled %d firewall rule(s) on server %v", modifiedCount, arg)
198-
exec.PushProgressSuccess(msg)
199-
200-
return output.None{}, nil
46+
return modifyFirewallRules(exec, arg, &s.params, upcloud.FirewallRuleActionDrop, "disable")
20147
}

0 commit comments

Comments
 (0)