Skip to content

[sql39]: Actions schemas, queries and SQL store impl #1072

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 4 commits into from
Jun 2, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
22 changes: 9 additions & 13 deletions config_dev.go
Original file line number Diff line number Diff line change
Expand Up @@ -151,23 +151,19 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {

stores.sessions = sessionStore
stores.closeFns["bbolt-sessions"] = sessionStore.Close
}

firewallBoltDB, err := firewalldb.NewBoltDB(
networkDir, firewalldb.DBFilename, stores.sessions,
stores.accounts, clock,
)
if err != nil {
return stores, fmt.Errorf("error creating firewall BoltDB: %v",
err)
}
firewallBoltDB, err := firewalldb.NewBoltDB(
networkDir, firewalldb.DBFilename, stores.sessions,
stores.accounts, clock,
)
if err != nil {
return stores, fmt.Errorf("error creating firewall "+
"BoltDB: %v", err)
}

if stores.firewall == nil {
stores.firewall = firewalldb.NewDB(firewallBoltDB)
stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close
}

stores.firewallBolt = firewallBoltDB
stores.closeFns["bbolt-firewalldb"] = firewallBoltDB.Close

return stores, nil
}
1 change: 0 additions & 1 deletion config_prod.go
Original file line number Diff line number Diff line change
Expand Up @@ -62,7 +62,6 @@ func NewStores(cfg *Config, clock clock.Clock) (*stores, error) {
if err != nil {
return stores, fmt.Errorf("error creating firewall DB: %v", err)
}
stores.firewallBolt = firewallDB
stores.firewall = firewalldb.NewDB(firewallDB)
stores.closeFns["firewall"] = firewallDB.Close

Expand Down
7 changes: 4 additions & 3 deletions db/interfaces.go
Original file line number Diff line number Diff line change
Expand Up @@ -87,12 +87,13 @@ type BatchedQuerier interface {
// create a batched version of the normal methods they need.
sqlc.Querier

// CustomQueries is the set of custom queries that we have manually
// defined in addition to the ones generated by sqlc.
sqlc.CustomQueries

// BeginTx creates a new database transaction given the set of
// transaction options.
BeginTx(ctx context.Context, options TxOptions) (*sql.Tx, error)

// Backend returns the type of the database backend used.
Backend() sqlc.BackendType
}

// txExecutorOptions is a struct that holds the options for the transaction
Expand Down
2 changes: 1 addition & 1 deletion db/migrations.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,7 +22,7 @@ const (
// daemon.
//
// NOTE: This MUST be updated when a new migration is added.
LatestMigrationVersion = 4
LatestMigrationVersion = 5
)

// MigrationTarget is a functional option that can be passed to applyMigrations
Expand Down
78 changes: 78 additions & 0 deletions db/sqlc/actions.sql.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

210 changes: 210 additions & 0 deletions db/sqlc/actions_custom.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,210 @@
package sqlc
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice approach with the addition of this file 🎉🔥!


import (
"context"
"database/sql"
"strconv"
"strings"
)

// ActionQueryParams defines the parameters for querying actions.
type ActionQueryParams struct {
SessionID sql.NullInt64
AccountID sql.NullInt64
FeatureName sql.NullString
ActorName sql.NullString
RpcMethod sql.NullString
State sql.NullInt16
EndTime sql.NullTime
StartTime sql.NullTime
GroupID sql.NullInt64
Comment on lines +12 to +20
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

something that could be cool would be to add a unit test that uses reflection and field tagging to check that field names and types are a subset of and consistent with models.Action, such that any changes in the schema could be detected. one could tag EndTime, StartTime, and GroupID as not being part of Action

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

existing unit tests would break no?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

I'm not sure every field is covered in unit tests, sqlc does all the renaming automatically if a field's name changes, but we would have to manually change the representations in the query and would only fail once we add a unit test. Was just an idea, but not a blocker :)

}

// ListActionsParams defines the parameters for listing actions, including
// the ActionQueryParams for filtering and a Pagination struct for
// pagination. The Reversed field indicates whether the results should be
// returned in reverse order based on the created_at timestamp.
type ListActionsParams struct {
ActionQueryParams
Reversed bool
*Pagination
}

// Pagination defines the pagination parameters for listing actions.
type Pagination struct {
NumOffset int32
NumLimit int32
}

// ListActions retrieves a list of actions based on the provided
// ListActionsParams.
func (q *Queries) ListActions(ctx context.Context,
arg ListActionsParams) ([]Action, error) {

query, args := buildListActionsQuery(arg)
rows, err := q.db.QueryContext(ctx, fillPlaceHolders(query), args...)
if err != nil {
return nil, err
}
defer rows.Close()
var items []Action
for rows.Next() {
var i Action
if err := rows.Scan(
&i.ID,
&i.SessionID,
&i.AccountID,
&i.MacaroonIdentifier,
&i.ActorName,
&i.FeatureName,
&i.ActionTrigger,
&i.Intent,
&i.StructuredJsonData,
&i.RpcMethod,
&i.RpcParamsJson,
&i.CreatedAt,
&i.ActionState,
&i.ErrorReason,
); err != nil {
return nil, err
}
items = append(items, i)
}
if err := rows.Close(); err != nil {
return nil, err
}
if err := rows.Err(); err != nil {
return nil, err
}
return items, nil
}

// CountActions returns the number of actions that match the provided
// ActionQueryParams.
func (q *Queries) CountActions(ctx context.Context,
arg ActionQueryParams) (int64, error) {

query, args := buildActionsQuery(arg, true)
row := q.db.QueryRowContext(ctx, query, args...)

var count int64
err := row.Scan(&count)

return count, err
}

// buildActionsQuery constructs a SQL query to retrieve actions based on the
// provided parameters. We do this manually so that if, for example, we have
// a sessionID we are filtering by, then this appears in the query as:
// `WHERE a.session_id = ?` which will properly make use of the underlying
// index. If we were instead to use a single SQLC query, it would include many
// WHERE clauses like:
// "WHERE a.session_id = COALESCE(sqlc.narg('session_id'), a.session_id)".
// This would use the index if run against postres but not when run against
// sqlite.
//
// The 'count' param indicates whether the query should return a count of
// actions that match the criteria or the actions themselves.
func buildActionsQuery(params ActionQueryParams, count bool) (string, []any) {
var (
conditions []string
args []any
)

if params.SessionID.Valid {
conditions = append(conditions, "a.session_id = ?")
args = append(args, params.SessionID.Int64)
}
if params.AccountID.Valid {
conditions = append(conditions, "a.account_id = ?")
args = append(args, params.AccountID.Int64)
}
if params.FeatureName.Valid {
conditions = append(conditions, "a.feature_name = ?")
args = append(args, params.FeatureName.String)
}
if params.ActorName.Valid {
conditions = append(conditions, "a.actor_name = ?")
args = append(args, params.ActorName.String)
}
if params.RpcMethod.Valid {
conditions = append(conditions, "a.rpc_method = ?")
args = append(args, params.RpcMethod.String)
}
if params.State.Valid {
conditions = append(conditions, "a.action_state = ?")
args = append(args, params.State.Int16)
}
if params.EndTime.Valid {
conditions = append(conditions, "a.created_at <= ?")
args = append(args, params.EndTime.Time)
}
if params.StartTime.Valid {
conditions = append(conditions, "a.created_at >= ?")
args = append(args, params.StartTime.Time)
}
if params.GroupID.Valid {
conditions = append(conditions, `
EXISTS (
SELECT 1
FROM sessions s
WHERE s.id = a.session_id AND s.group_id = ?
)`)
args = append(args, params.GroupID.Int64)
}

query := "SELECT a.* FROM actions a"
if count {
query = "SELECT COUNT(*) FROM actions a"
}
if len(conditions) > 0 {
query += " WHERE " + strings.Join(conditions, " AND ")
}

return query, args
}

// buildListActionsQuery constructs a SQL query to retrieve a list of actions
// based on the provided parameters. It builds upon the `buildActionsQuery`
// function, adding pagination and ordering based on the reversed parameter.
func buildListActionsQuery(params ListActionsParams) (string, []interface{}) {
query, args := buildActionsQuery(params.ActionQueryParams, false)

// Determine order direction.
order := "ASC"
if params.Reversed {
order = "DESC"
}
query += " ORDER BY a.created_at " + order

// Maybe paginate.
if params.Pagination != nil {
query += " LIMIT ? OFFSET ?"
args = append(args, params.NumLimit, params.NumOffset)
}

return query, args
}

// fillPlaceHolders replaces all '?' placeholders in the SQL query with
// positional placeholders like $1, $2, etc. This is necessary for
// compatibility with Postgres.
func fillPlaceHolders(query string) string {
var (
sb strings.Builder
argNum = 1
)

for i := range len(query) {
if query[i] != '?' {
sb.WriteByte(query[i])
continue
}

sb.WriteString("$")
sb.WriteString(strconv.Itoa(argNum))
argNum++
}

return sb.String()
}
20 changes: 20 additions & 0 deletions db/sqlc/db_custom.go
Original file line number Diff line number Diff line change
@@ -1,5 +1,9 @@
package sqlc

import (
"context"
)

// BackendType is an enum that represents the type of database backend we're
// using.
type BackendType uint8
Expand Down Expand Up @@ -44,3 +48,19 @@ func NewSqlite(db DBTX) *Queries {
func NewPostgres(db DBTX) *Queries {
return &Queries{db: &wrappedTX{db, BackendTypePostgres}}
}

// CustomQueries defines a set of custom queries that we define in addition
// to the ones generated by sqlc.
type CustomQueries interface {
// CountActions returns the number of actions that match the provided
// ActionQueryParams.
CountActions(ctx context.Context, arg ActionQueryParams) (int64, error)

// ListActions retrieves a list of actions based on the provided
// ListActionsParams.
ListActions(ctx context.Context,
arg ListActionsParams) ([]Action, error)

// Backend returns the type of the database backend used.
Backend() BackendType
}
5 changes: 5 additions & 0 deletions db/sqlc/migrations/000005_actions.down.sql
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
DROP INDEX IF NOT EXISTS actions_state_idx;
DROP INDEX IF NOT EXISTS actions_session_id_idx;
DROP INDEX IF NOT EXISTS actions_feature_name_idx;
DROP INDEX IF NOT EXISTS actions_created_at_idx;
DROP TABLE IF EXISTS actions;
Loading
Loading