Skip to content

Commit 765e641

Browse files
committed
fix: handle DROP TABLE CASCADE for dependent views
Signed-off-by: wucm667 <stevenwucongmin@gmail.com>
1 parent 977ac6d commit 765e641

5 files changed

Lines changed: 195 additions & 2 deletions

File tree

internal/engine/postgresql/catalog_test.go

Lines changed: 118 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -6,6 +6,8 @@ import (
66
"strings"
77
"testing"
88

9+
"github.com/sqlc-dev/sqlc/internal/sql/ast"
10+
"github.com/sqlc-dev/sqlc/internal/sql/catalog"
911
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
1012

1113
"github.com/google/go-cmp/cmp"
@@ -168,3 +170,119 @@ func TestUpdateErrors(t *testing.T) {
168170
})
169171
}
170172
}
173+
174+
func TestDropTableCascadeViewRecreate(t *testing.T) {
175+
// Regression test for https://github.com/sqlc-dev/sqlc/issues/4416
176+
// DROP TABLE CASCADE should remove dependent views from the catalog,
177+
// allowing a subsequent CREATE VIEW with the same name to succeed.
178+
p := NewParser()
179+
180+
// First: create the table
181+
stmts1, err := p.Parse(strings.NewReader(`
182+
CREATE TABLE reference_rates (id BIGSERIAL PRIMARY KEY);
183+
`))
184+
if err != nil {
185+
t.Fatalf("parse error: %v", err)
186+
}
187+
188+
c := NewCatalog()
189+
if err := c.Build(stmts1); err != nil {
190+
t.Fatalf("create table error: %v", err)
191+
}
192+
193+
// Manually add a view that depends on reference_rates to the catalog
194+
var schema *catalog.Schema
195+
for _, s := range c.Schemas {
196+
if s.Name == "public" {
197+
schema = s
198+
}
199+
}
200+
schema.Tables = append(schema.Tables, &catalog.Table{
201+
Rel: &ast.TableName{Schema: "public", Name: "vw_reference_rates"},
202+
Columns: []*catalog.Column{{Name: "id"}},
203+
DependsOnTables: []*ast.TableName{
204+
{Schema: "public", Name: "reference_rates"},
205+
},
206+
})
207+
208+
// Verify the view exists
209+
if !viewExists(schema, "vw_reference_rates") {
210+
t.Fatal("view not found in catalog before drop")
211+
}
212+
213+
// Second: DROP TABLE CASCADE
214+
stmts2, err := p.Parse(strings.NewReader(`
215+
DROP TABLE reference_rates CASCADE;
216+
`))
217+
if err != nil {
218+
t.Fatalf("parse error: %v", err)
219+
}
220+
if err := c.Build(stmts2); err != nil {
221+
t.Fatalf("DROP TABLE CASCADE error: %v", err)
222+
}
223+
224+
// Verify the view was removed
225+
if viewExists(schema, "vw_reference_rates") {
226+
t.Fatal("expected view to be removed by CASCADE, but it still exists")
227+
}
228+
}
229+
230+
func TestDropTableCascadeWithoutCascadeFails(t *testing.T) {
231+
// Without CASCADE, dropping a table that has a dependent view leaves the view
232+
// in the catalog (matching current sqlc behavior, though real PostgreSQL would
233+
// reject DROP TABLE without CASCADE when views depend on it).
234+
p := NewParser()
235+
236+
// Create the table
237+
stmts1, err := p.Parse(strings.NewReader(`
238+
CREATE TABLE reference_rates (id BIGSERIAL PRIMARY KEY);
239+
`))
240+
if err != nil {
241+
t.Fatalf("parse error: %v", err)
242+
}
243+
244+
c := NewCatalog()
245+
if err := c.Build(stmts1); err != nil {
246+
t.Fatalf("create table error: %v", err)
247+
}
248+
249+
// Manually add a view that depends on reference_rates
250+
schema := c.Schemas[0]
251+
for _, s := range c.Schemas {
252+
if s.Name == "public" {
253+
schema = s
254+
}
255+
}
256+
schema.Tables = append(schema.Tables, &catalog.Table{
257+
Rel: &ast.TableName{Schema: "public", Name: "vw_reference_rates"},
258+
Columns: []*catalog.Column{{Name: "id"}},
259+
DependsOnTables: []*ast.TableName{
260+
{Schema: "public", Name: "reference_rates"},
261+
},
262+
})
263+
264+
// DROP TABLE without CASCADE
265+
stmts2, err := p.Parse(strings.NewReader(`
266+
DROP TABLE reference_rates;
267+
`))
268+
if err != nil {
269+
t.Fatalf("parse error: %v", err)
270+
}
271+
if err := c.Build(stmts2); err != nil {
272+
t.Fatalf("DROP TABLE error: %v", err)
273+
}
274+
275+
// Without CASCADE, the view should still exist in the catalog
276+
if !viewExists(schema, "vw_reference_rates") {
277+
t.Fatal("expected view to still exist without CASCADE, but it was removed")
278+
}
279+
}
280+
281+
func viewExists(schema *catalog.Schema, name string) bool {
282+
for _, tbl := range schema.Tables {
283+
if tbl.Rel.Name == name {
284+
return true
285+
}
286+
}
287+
return false
288+
}

internal/engine/postgresql/parse.go

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -578,6 +578,7 @@ func translate(node *nodes.Node) (ast.Node, error) {
578578

579579
case nodes.ObjectType_OBJECT_TABLE, nodes.ObjectType_OBJECT_VIEW, nodes.ObjectType_OBJECT_MATVIEW:
580580
drop := &ast.DropTableStmt{
581+
Behavior: ast.DropBehavior(n.Behavior),
581582
IfExists: n.MissingOk,
582583
}
583584
for _, obj := range n.Objects {

internal/sql/ast/drop_table_stmt.go

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -1,8 +1,9 @@
11
package ast
22

33
type DropTableStmt struct {
4-
IfExists bool
5-
Tables []*TableName
4+
Behavior DropBehavior
5+
IfExists bool
6+
Tables []*TableName
67
}
78

89
func (n *DropTableStmt) Pos() int {

internal/sql/catalog/table.go

Lines changed: 35 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -16,6 +16,10 @@ type Table struct {
1616
Rel *ast.TableName
1717
Columns []*Column
1818
Comment string
19+
20+
// If non-nil, this Table represents a view and depends on the listed tables.
21+
// Only set when the Table is created via CREATE VIEW.
22+
DependsOnTables []*ast.TableName
1923
}
2024

2125
func checkMissing(err error, missingOK bool) error {
@@ -373,6 +377,24 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
373377
return err
374378
}
375379

380+
// When CASCADE, drop dependent views first
381+
// DROP_CASCADE = 2 in pg_query_go protobuf enum
382+
if stmt.Behavior == 2 {
383+
for i := len(schema.Tables) - 1; i >= 0; i-- {
384+
view := schema.Tables[i]
385+
if len(view.DependsOnTables) == 0 {
386+
continue
387+
}
388+
for _, dep := range view.DependsOnTables {
389+
if dep.Name == name.Name && tablesSameSchema(dep, name, c.DefaultSchema) {
390+
// This view depends on the table being dropped
391+
schema.Tables = append(schema.Tables[:i], schema.Tables[i+1:]...)
392+
break
393+
}
394+
}
395+
}
396+
}
397+
376398
drop := &ast.DropTypeStmt{}
377399
for _, col := range tbl.Columns {
378400
if !col.linkedType {
@@ -389,6 +411,19 @@ func (c *Catalog) dropTable(stmt *ast.DropTableStmt) error {
389411
return nil
390412
}
391413

414+
// tablesSameSchema checks if two table references point to the same schema.
415+
func tablesSameSchema(a, b *ast.TableName, defaultSchema string) bool {
416+
aSchema := a.Schema
417+
bSchema := b.Schema
418+
if aSchema == "" {
419+
aSchema = defaultSchema
420+
}
421+
if bSchema == "" {
422+
bSchema = defaultSchema
423+
}
424+
return aSchema == bSchema
425+
}
426+
392427
func (c *Catalog) renameColumn(stmt *ast.RenameColumnStmt) error {
393428
_, tbl, err := c.getTable(stmt.Table)
394429
if err != nil {

internal/sql/catalog/view.go

Lines changed: 38 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -2,6 +2,7 @@ package catalog
22

33
import (
44
"github.com/sqlc-dev/sqlc/internal/sql/ast"
5+
"github.com/sqlc-dev/sqlc/internal/sql/astutils"
56
"github.com/sqlc-dev/sqlc/internal/sql/sqlerr"
67
)
78

@@ -29,6 +30,9 @@ func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error {
2930
Columns: cols,
3031
}
3132

33+
// Extract table dependencies from the view's SELECT query
34+
tbl.DependsOnTables = extractTableDeps(stmt.Query)
35+
3236
ns := tbl.Rel.Schema
3337
if ns == "" {
3438
ns = c.DefaultSchema
@@ -50,3 +54,37 @@ func (c *Catalog) createView(stmt *ast.ViewStmt, colGen columnGenerator) error {
5054

5155
return nil
5256
}
57+
58+
// extractTableDeps walks the SELECT query AST and returns all table references (RangeVar nodes).
59+
func extractTableDeps(node ast.Node) []*ast.TableName {
60+
var deps []*ast.TableName
61+
seen := make(map[string]bool)
62+
63+
astutils.Walk(astutils.VisitorFunc(func(n ast.Node) {
64+
rv, ok := n.(*ast.RangeVar)
65+
if !ok || rv.Relname == nil {
66+
return
67+
}
68+
schema := ""
69+
if rv.Schemaname != nil {
70+
schema = *rv.Schemaname
71+
}
72+
key := schema + "." + *rv.Relname
73+
if seen[key] {
74+
return
75+
}
76+
seen[key] = true
77+
78+
// Skip system catalogs and information schema
79+
if schema == "pg_catalog" || schema == "information_schema" {
80+
return
81+
}
82+
83+
deps = append(deps, &ast.TableName{
84+
Schema: schema,
85+
Name: *rv.Relname,
86+
})
87+
}), node)
88+
89+
return deps
90+
}

0 commit comments

Comments
 (0)