Skip to content

Commit 09455a8

Browse files
add backfill using insert with select from
1 parent 5c5b87f commit 09455a8

1 file changed

Lines changed: 260 additions & 0 deletions

File tree

Lines changed: 260 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,260 @@
1+
import logging
2+
import time
3+
4+
from django.apps import apps
5+
from django.core.management.base import BaseCommand
6+
from django.db import connection
7+
8+
logger = logging.getLogger(__name__)
9+
10+
11+
class Command(BaseCommand):

    """
    Backfill pghistory event tables by writing an ``initial_import`` event row
    for every tracked record that does not already have one.

    Instead of instantiating Django model objects, each batch is a single
    ``INSERT INTO <event_table> SELECT ... FROM <source_table>`` statement,
    which keeps the whole copy inside PostgreSQL and is far faster than an
    ORM-based backfill.
    """

    help = "Backfill pghistory events using direct SQL INSERT - much simpler and faster!"

    # Models wired up for pghistory tracking; only these are eligible for backfill.
    TRACKED_MODELS = [
        "Test",
        "Product",
        "Finding",
        "Endpoint",
        "Dojo_User",
        "Product_Type",
        "Finding_Group",
        "Risk_Acceptance",
        "Finding_Template",
        "Cred_User",
        "Notification_Webhooks",
    ]

    def add_arguments(self, parser):
        """Register CLI options: batch size, dry-run mode and model selection."""
        parser.add_argument(
            "--batch-size",
            type=int,
            default=10000,
            help="Number of records to process in each batch",
        )
        parser.add_argument(
            "--dry-run",
            action="store_true",
            help="Show what would be processed without making changes",
        )
        parser.add_argument(
            "--models",
            nargs="+",
            help="Specific models to process (default: all configured models)",
        )

    def handle(self, *args, **options):
        """Entry point: backfill each selected model and report total throughput."""
        batch_size = options["batch_size"]
        dry_run = options["dry_run"]
        specific_models = options.get("models")

        models_to_process = list(self.TRACKED_MODELS)
        if specific_models:
            # Warn about requested names we do not track instead of silently
            # dropping them — a typo would otherwise look like a no-op run.
            unknown = sorted(set(specific_models) - set(models_to_process))
            if unknown:
                self.stdout.write(
                    self.style.WARNING(f"Ignoring unknown model(s): {', '.join(unknown)}"),
                )
            models_to_process = [m for m in models_to_process if m in specific_models]

        self.stdout.write(
            self.style.SUCCESS(
                f"Starting backfill for {len(models_to_process)} model(s) using direct SQL INSERT...",
            ),
        )

        total_processed = 0
        total_start_time = time.time()

        for model_name in models_to_process:
            self.stdout.write(f"\nProcessing {model_name}...")
            processed, _records_per_second = self.process_model_simple(
                model_name, batch_size, dry_run,
            )
            total_processed += processed

        total_duration = time.time() - total_start_time
        total_records_per_second = total_processed / total_duration if total_duration > 0 else 0

        self.stdout.write(
            self.style.SUCCESS(
                f"\n✓ Backfill completed: {total_processed:,} total records in {total_duration:.2f}s "
                f"({total_records_per_second:.1f} records/sec)",
            ),
        )

    def get_excluded_fields(self, model_name):
        """Get the list of excluded fields for a specific model from pghistory configuration."""
        # Sensitive or noisy columns that pghistory is configured not to track;
        # they must be omitted from the event rows as well.
        excluded_fields_map = {
            "Dojo_User": ["password"],
            "Product": ["updated"],
            "Cred_User": ["password"],
            "Notification_Webhooks": ["header_name", "header_value"],
        }
        return excluded_fields_map.get(model_name, [])

    def process_model_simple(self, model_name, batch_size, dry_run):
        """
        Backfill one model with batched ``INSERT INTO ... SELECT`` statements.

        Returns a ``(processed_count, records_per_second)`` tuple; ``(0, 0.0)``
        when the model/event table is missing or nothing needs backfilling.
        """
        try:
            # Quote identifiers through Django so reserved-word or mixed-case
            # table/column names cannot break (or alter) the generated SQL.
            quote = connection.ops.quote_name

            table_name, event_table_name = self.get_table_names(model_name)

            if not table_name or not event_table_name:
                self.stdout.write(f" Skipping {model_name}: table not found")
                return 0, 0.0

            qt = quote(table_name)
            qet = quote(event_table_name)

            # Check if event table exists
            with connection.cursor() as cursor:
                cursor.execute("""
                    SELECT EXISTS (
                        SELECT 1 FROM information_schema.tables
                        WHERE table_name = %s
                    )
                """, [event_table_name])
                if not cursor.fetchone()[0]:
                    self.stdout.write(f" Skipping {model_name}: event table {event_table_name} not found")
                    return 0, 0.0

            # Get counts: total rows, and rows lacking an 'initial_import' event.
            with connection.cursor() as cursor:
                cursor.execute(f"SELECT COUNT(*) FROM {qt}")  # noqa: S608 - identifier quoted via quote_name
                total_count = cursor.fetchone()[0]

                cursor.execute(f"""
                    SELECT COUNT(*) FROM {qt} t
                    WHERE NOT EXISTS (
                        SELECT 1 FROM {qet} e
                        WHERE e.pgh_obj_id = t.id AND e.pgh_label = 'initial_import'
                    )
                """)  # noqa: S608
                backfill_count = cursor.fetchone()[0]

            if backfill_count == 0:
                self.stdout.write(f" No records need backfill for {model_name}")
                return 0, 0.0

            self.stdout.write(f" {backfill_count:,} records need backfill out of {total_count:,} total")

            if dry_run:
                self.stdout.write(f" [DRY RUN] Would process {backfill_count:,} records")
                return backfill_count, 0.0

            # Get source columns, minus fields pghistory is configured to exclude.
            excluded_fields = self.get_excluded_fields(model_name)
            with connection.cursor() as cursor:
                cursor.execute("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = %s
                    ORDER BY ordinal_position
                """, [table_name])
                source_columns = [row[0] for row in cursor.fetchall()]
            source_columns = [col for col in source_columns if col not in excluded_fields]

            # Get event table columns (excluding pgh_id which is auto-generated).
            with connection.cursor() as cursor:
                cursor.execute("""
                    SELECT column_name
                    FROM information_schema.columns
                    WHERE table_name = %s AND column_name != 'pgh_id'
                    ORDER BY ordinal_position
                """, [event_table_name])
                event_columns = [row[0] for row in cursor.fetchall()]

            # Build the SELECT list that produces one event row per source row:
            # pghistory bookkeeping columns get synthesized values, tracked
            # source columns are copied through, anything else becomes NULL.
            select_columns = []
            for col in event_columns:
                if col == "pgh_created_at":
                    select_columns.append("NOW() as pgh_created_at")
                elif col == "pgh_label":
                    select_columns.append("'initial_import' as pgh_label")
                elif col == "pgh_obj_id":
                    select_columns.append("t.id as pgh_obj_id")
                elif col == "pgh_context_id":
                    select_columns.append("NULL as pgh_context_id")
                elif col in source_columns:
                    select_columns.append(f"t.{quote(col)}")
                else:
                    # Excluded (e.g. password) or unmapped column: store NULL.
                    select_columns.append(f"NULL as {quote(col)}")

            # Collect all IDs still missing an 'initial_import' event.
            with connection.cursor() as cursor:
                cursor.execute(f"""
                    SELECT t.id FROM {qt} t
                    WHERE NOT EXISTS (
                        SELECT 1 FROM {qet} e
                        WHERE e.pgh_obj_id = t.id AND e.pgh_label = 'initial_import'
                    )
                    ORDER BY t.id
                """)  # noqa: S608
                ids_to_process = [row[0] for row in cursor.fetchall()]

            if not ids_to_process:
                self.stdout.write(" No records need backfill")
                return 0, 0.0

            # Process in batches using direct SQL.
            insert_sql = f"""
                INSERT INTO {qet} ({', '.join(quote(c) for c in event_columns)})
                SELECT {', '.join(select_columns)}
                FROM {qt} t
                WHERE t.id = ANY(%s)
                ORDER BY t.id
            """  # noqa: S608

            processed = 0
            model_start_time = time.time()

            for i in range(0, len(ids_to_process), batch_size):
                batch_ids = ids_to_process[i:i + batch_size]

                # Emit progress once every 10 batches (single shared flag so
                # the "starting" and "processed" lines stay in sync).
                log_progress = i > 0 and i % (batch_size * 10) == 0
                if log_progress:
                    self.stdout.write(f" Processing batch starting at index {i:,}...")

                with connection.cursor() as cursor:
                    cursor.execute(insert_sql, [batch_ids])
                    batch_processed = cursor.rowcount
                processed += batch_processed

                if log_progress:
                    progress = (i + batch_size) / len(ids_to_process) * 100
                    self.stdout.write(f" Processed {processed:,}/{backfill_count:,} records ({progress:.1f}%)")

            # Calculate timing
            total_duration = time.time() - model_start_time
            records_per_second = processed / total_duration if total_duration > 0 else 0

            self.stdout.write(
                self.style.SUCCESS(
                    f" ✓ Completed {model_name}: {processed:,} records in {total_duration:.2f}s "
                    f"({records_per_second:.1f} records/sec)",
                ),
            )

            return processed, records_per_second  # noqa: TRY300

        except Exception as e:
            self.stdout.write(
                self.style.ERROR(f" ✗ Failed to process {model_name}: {e}"),
            )
            # Lazy %-style args keep formatting out of the hot path and match
            # logging best practice (the traceback is attached automatically).
            logger.exception("Error processing %s", model_name)
            return 0, 0.0

    def get_table_names(self, model_name):
        """
        Get the actual table names for a model using Django's model metadata.

        Returns ``(source_table, event_table)``, or ``(None, None)`` when either
        the model or its pghistory ``<Model>Event`` companion is not registered.
        """
        try:
            # Get the Django model
            Model = apps.get_model("dojo", model_name)
            table_name = Model._meta.db_table

            # pghistory registers the companion event model as "<Model>Event".
            EventModel = apps.get_model("dojo", f"{model_name}Event")
            event_table_name = EventModel._meta.db_table

            return table_name, event_table_name  # noqa: TRY300
        except LookupError:
            # Model not found, return None
            return None, None

0 commit comments

Comments
 (0)