Skip to content

Commit dbd2dd6

Browse files
authored
refactor(misconf): pass options to Rego scanner as is (#7529)
Signed-off-by: nikpivkin <[email protected]>
1 parent aeb7039 commit dbd2dd6

33 files changed

+570
-1095
lines changed

pkg/iac/rego/load.go

+9-9
Original file line numberDiff line numberDiff line change
@@ -75,7 +75,7 @@ func (s *Scanner) loadEmbedded() error {
7575
return nil
7676
}
7777

78-
func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies bool, srcFS fs.FS, paths []string, readers []io.Reader) error {
78+
func (s *Scanner) LoadPolicies(srcFS fs.FS) error {
7979

8080
if s.policies == nil {
8181
s.policies = make(map[string]*ast.Module)
@@ -90,28 +90,28 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
9090
return err
9191
}
9292

93-
if enableEmbeddedPolicies {
93+
if s.includeEmbeddedPolicies {
9494
s.policies = lo.Assign(s.policies, s.embeddedChecks)
9595
}
9696

97-
if enableEmbeddedLibraries {
97+
if s.includeEmbeddedLibraries {
9898
s.policies = lo.Assign(s.policies, s.embeddedLibs)
9999
}
100100

101101
var err error
102-
if len(paths) > 0 {
103-
loaded, err := LoadPoliciesFromDirs(srcFS, paths...)
102+
if len(s.policyDirs) > 0 {
103+
loaded, err := LoadPoliciesFromDirs(srcFS, s.policyDirs...)
104104
if err != nil {
105-
return fmt.Errorf("failed to load rego checks from %s: %w", paths, err)
105+
return fmt.Errorf("failed to load rego checks from %s: %w", s.policyDirs, err)
106106
}
107107
for name, policy := range loaded {
108108
s.policies[name] = policy
109109
}
110110
s.logger.Debug("Checks from disk are loaded", log.Int("count", len(loaded)))
111111
}
112112

113-
if len(readers) > 0 {
114-
loaded, err := s.loadPoliciesFromReaders(readers)
113+
if len(s.policyReaders) > 0 {
114+
loaded, err := s.loadPoliciesFromReaders(s.policyReaders)
115115
if err != nil {
116116
return fmt.Errorf("failed to load rego checks from reader(s): %w", err)
117117
}
@@ -143,7 +143,7 @@ func (s *Scanner) LoadPolicies(enableEmbeddedLibraries, enableEmbeddedPolicies b
143143
}
144144
s.store = store
145145

146-
return s.compilePolicies(srcFS, paths)
146+
return s.compilePolicies(srcFS, s.policyDirs)
147147
}
148148

149149
func (s *Scanner) fallbackChecks(compiler *ast.Compiler) {

pkg/iac/rego/load_test.go

+31-18
Original file line numberDiff line numberDiff line change
@@ -4,7 +4,6 @@ import (
44
"bytes"
55
"embed"
66
"fmt"
7-
"io"
87
"log/slog"
98
"strings"
109
"testing"
@@ -16,7 +15,6 @@ import (
1615

1716
checks "github.com/aquasecurity/trivy-checks"
1817
"github.com/aquasecurity/trivy/pkg/iac/rego"
19-
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
2018
"github.com/aquasecurity/trivy/pkg/iac/types"
2119
"github.com/aquasecurity/trivy/pkg/log"
2220
)
@@ -33,10 +31,11 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
3331
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
3432
scanner := rego.NewScanner(
3533
types.SourceDockerfile,
36-
options.ScannerWithRegoErrorLimits(0),
34+
rego.WithRegoErrorLimits(0),
35+
rego.WithPolicyDirs("."),
3736
)
3837

39-
err := scanner.LoadPolicies(false, false, testEmbedFS, []string{"."}, nil)
38+
err := scanner.LoadPolicies(testEmbedFS)
4039
require.ErrorContains(t, err, `want (one of): ["Cmd" "EndLine" "Flags" "JSON" "Original" "Path" "Stage" "StartLine" "SubCmd" "Value"]`)
4140
assert.Contains(t, debugBuf.String(), "Error(s) occurred while loading checks")
4241
})
@@ -46,10 +45,11 @@ func Test_RegoScanning_WithSomeInvalidPolicies(t *testing.T) {
4645
slog.SetDefault(log.New(log.NewHandler(&debugBuf, nil)))
4746
scanner := rego.NewScanner(
4847
types.SourceDockerfile,
49-
options.ScannerWithRegoErrorLimits(1),
48+
rego.WithRegoErrorLimits(1),
49+
rego.WithPolicyDirs("."),
5050
)
5151

52-
err := scanner.LoadPolicies(false, false, testEmbedFS, []string{"."}, nil)
52+
err := scanner.LoadPolicies(testEmbedFS)
5353
require.NoError(t, err)
5454

5555
assert.Contains(t, debugBuf.String(), "Error occurred while parsing\tfile_path=\"testdata/policies/invalid.rego\" err=\"testdata/policies/invalid.rego:7")
@@ -64,9 +64,13 @@ package mypackage
6464
deny {
6565
input.evil == "foo bar"
6666
}`
67-
scanner := rego.NewScanner(types.SourceJSON)
67+
scanner := rego.NewScanner(
68+
types.SourceJSON,
69+
rego.WithPolicyDirs("."),
70+
rego.WithPolicyReader(strings.NewReader(check)),
71+
)
6872

69-
err := scanner.LoadPolicies(false, false, fstest.MapFS{}, []string{"."}, []io.Reader{strings.NewReader(check)})
73+
err := scanner.LoadPolicies(fstest.MapFS{})
7074
assert.ErrorContains(t, err, "could not find schema \"fooschema\"")
7175
})
7276

@@ -79,15 +83,19 @@ package mypackage
7983
deny {
8084
input.evil == "foo bar"
8185
}`
82-
scanner := rego.NewScanner(types.SourceJSON)
86+
scanner := rego.NewScanner(
87+
types.SourceJSON,
88+
rego.WithPolicyDirs("."),
89+
rego.WithPolicyReader(strings.NewReader(check)),
90+
)
8391

8492
fsys := fstest.MapFS{
8593
"schemas/fooschema.json": &fstest.MapFile{
8694
Data: []byte("bad json"),
8795
},
8896
}
8997

90-
err := scanner.LoadPolicies(false, false, fsys, []string{"."}, []io.Reader{strings.NewReader(check)})
98+
err := scanner.LoadPolicies(fsys)
9199
assert.ErrorContains(t, err, "could not parse schema \"fooschema\"")
92100
})
93101

@@ -97,8 +105,12 @@ deny {
97105
deny {
98106
input.evil == "foo bar"
99107
}`
100-
scanner := rego.NewScanner(types.SourceJSON)
101-
err := scanner.LoadPolicies(false, false, fstest.MapFS{}, []string{"."}, []io.Reader{strings.NewReader(check)})
108+
scanner := rego.NewScanner(
109+
types.SourceJSON,
110+
rego.WithPolicyDirs("."),
111+
rego.WithPolicyReader(strings.NewReader(check)),
112+
)
113+
err := scanner.LoadPolicies(fstest.MapFS{})
102114
require.NoError(t, err)
103115
})
104116

@@ -184,8 +196,9 @@ deny {
184196
t.Run(tt.name, func(t *testing.T) {
185197
scanner := rego.NewScanner(
186198
types.SourceDockerfile,
187-
options.ScannerWithRegoErrorLimits(0),
188-
options.ScannerWithEmbeddedPolicies(false),
199+
rego.WithRegoErrorLimits(0),
200+
rego.WithEmbeddedPolicies(false),
201+
rego.WithPolicyDirs("."),
189202
)
190203

191204
tt.files["schemas/fooschema.json"] = &fstest.MapFile{
@@ -200,9 +213,8 @@ deny {
200213
}`),
201214
}
202215

203-
fsys := fstest.MapFS(tt.files)
204216
checks.EmbeddedPolicyFileSystem = embeddedChecksFS
205-
err := scanner.LoadPolicies(false, false, fsys, []string{"."}, nil)
217+
err := scanner.LoadPolicies(fstest.MapFS(tt.files))
206218

207219
if tt.expectedErr != "" {
208220
assert.ErrorContains(t, err, tt.expectedErr)
@@ -244,8 +256,9 @@ deny {
244256

245257
scanner := rego.NewScanner(
246258
types.SourceDockerfile,
247-
options.ScannerWithEmbeddedPolicies(false),
259+
rego.WithEmbeddedPolicies(false),
260+
rego.WithPolicyDirs("."),
248261
)
249-
err := scanner.LoadPolicies(false, false, fsys, []string{"."}, nil)
262+
err := scanner.LoadPolicies(fsys)
250263
require.Error(t, err)
251264
}

pkg/iac/rego/options.go

+108
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,108 @@
1+
package rego
2+
3+
import (
4+
"io"
5+
"io/fs"
6+
7+
"github.com/aquasecurity/trivy/pkg/iac/scanners/options"
8+
)
9+
10+
func WithPolicyReader(readers ...io.Reader) options.ScannerOption {
11+
return func(s options.ConfigurableScanner) {
12+
if ss, ok := s.(*Scanner); ok {
13+
ss.policyReaders = readers
14+
}
15+
}
16+
}
17+
18+
func WithEmbeddedPolicies(include bool) options.ScannerOption {
19+
return func(s options.ConfigurableScanner) {
20+
if ss, ok := s.(*Scanner); ok {
21+
ss.includeEmbeddedPolicies = include
22+
}
23+
}
24+
}
25+
26+
func WithEmbeddedLibraries(include bool) options.ScannerOption {
27+
return func(s options.ConfigurableScanner) {
28+
if ss, ok := s.(*Scanner); ok {
29+
ss.includeEmbeddedLibraries = include
30+
}
31+
}
32+
}
33+
34+
// WithTrace specifies an io.Writer for trace logs (mainly rego tracing) - if not set, they are discarded
35+
func WithTrace(w io.Writer) options.ScannerOption {
36+
return func(s options.ConfigurableScanner) {
37+
if ss, ok := s.(*Scanner); ok {
38+
ss.traceWriter = w
39+
}
40+
}
41+
}
42+
43+
func WithPerResultTracing(enabled bool) options.ScannerOption {
44+
return func(s options.ConfigurableScanner) {
45+
if ss, ok := s.(*Scanner); ok {
46+
ss.tracePerResult = enabled
47+
}
48+
}
49+
}
50+
51+
func WithPolicyDirs(paths ...string) options.ScannerOption {
52+
return func(s options.ConfigurableScanner) {
53+
if ss, ok := s.(*Scanner); ok {
54+
ss.policyDirs = paths
55+
}
56+
}
57+
}
58+
59+
func WithDataDirs(paths ...string) options.ScannerOption {
60+
return func(s options.ConfigurableScanner) {
61+
if ss, ok := s.(*Scanner); ok {
62+
ss.dataDirs = paths
63+
}
64+
}
65+
}
66+
67+
// WithPolicyNamespaces - namespaces which indicate rego policies containing enforced rules
68+
func WithPolicyNamespaces(namespaces ...string) options.ScannerOption {
69+
return func(s options.ConfigurableScanner) {
70+
if ss, ok := s.(*Scanner); ok {
71+
for _, namespace := range namespaces {
72+
ss.ruleNamespaces[namespace] = struct{}{}
73+
}
74+
}
75+
}
76+
}
77+
78+
func WithPolicyFilesystem(fsys fs.FS) options.ScannerOption {
79+
return func(s options.ConfigurableScanner) {
80+
if ss, ok := s.(*Scanner); ok {
81+
ss.policyFS = fsys
82+
}
83+
}
84+
}
85+
86+
func WithDataFilesystem(fsys fs.FS) options.ScannerOption {
87+
return func(s options.ConfigurableScanner) {
88+
if ss, ok := s.(*Scanner); ok {
89+
ss.dataFS = fsys
90+
}
91+
}
92+
}
93+
94+
func WithRegoErrorLimits(limit int) options.ScannerOption {
95+
return func(s options.ConfigurableScanner) {
96+
if ss, ok := s.(*Scanner); ok {
97+
ss.regoErrorLimit = limit
98+
}
99+
}
100+
}
101+
102+
func WithCustomSchemas(schemas map[string][]byte) options.ScannerOption {
103+
return func(s options.ConfigurableScanner) {
104+
if ss, ok := s.(*Scanner); ok {
105+
ss.customSchemas = schemas
106+
}
107+
}
108+
}

0 commit comments

Comments
 (0)