Skip to content

Commit a860867

Browse files
committed
Fix issues in vector_concat and vector_slice implementation
1 parent 418b593 commit a860867

2 files changed

Lines changed: 282 additions & 0 deletions

File tree

libsql-sqlite3/src/vector.c

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -29,6 +29,9 @@
2929
#include "sqliteInt.h"
3030
#include "vectorInt.h"
3131

32+
/* Include the implementation of vector_concat and vector_slice functions */
33+
#include "../../vector_func_impl.c"
34+
3235
#define MAX_FLOAT_CHAR_SZ 1024
3336

3437
/**************************************************************************

vector_func_impl.c

Lines changed: 279 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,279 @@
1+
/*
2+
** This file contains implementations of the vector_concat and vector_slice functions.
3+
** It is included by vector.c and not compiled separately.
4+
*/
5+
6+
#ifndef VECTOR_FUNC_IMPL_C
7+
#define VECTOR_FUNC_IMPL_C
8+
9+
/*
10+
** Implementation of vector_concat(X, Y) function.
11+
** Concatenates two vectors of same type.
12+
*/
13+
static void vectorConcatFunc(
14+
sqlite3_context *context,
15+
int argc,
16+
sqlite3_value **argv
17+
){
18+
char *pzErrMsg = NULL;
19+
Vector *pVector1 = NULL, *pVector2 = NULL, *pTarget = NULL;
20+
int type1, dims1, type2, dims2;
21+
22+
if( argc != 2 ){
23+
sqlite3_result_error(context, "vector_concat requires exactly two arguments", -1);
24+
goto out;
25+
}
26+
27+
/* Parse first vector */
28+
if( detectVectorParameters(argv[0], 0, &type1, &dims1, &pzErrMsg) != 0 ){
29+
sqlite3_result_error(context, pzErrMsg, -1);
30+
sqlite3_free(pzErrMsg);
31+
goto out;
32+
}
33+
pVector1 = vectorContextAlloc(context, type1, dims1);
34+
if( pVector1 == NULL ){
35+
goto out;
36+
}
37+
if( vectorParseWithType(argv[0], pVector1, &pzErrMsg) < 0 ){
38+
sqlite3_result_error(context, pzErrMsg, -1);
39+
sqlite3_free(pzErrMsg);
40+
goto out;
41+
}
42+
43+
/* Parse second vector */
44+
if( detectVectorParameters(argv[1], 0, &type2, &dims2, &pzErrMsg) != 0 ){
45+
sqlite3_result_error(context, pzErrMsg, -1);
46+
sqlite3_free(pzErrMsg);
47+
goto out;
48+
}
49+
pVector2 = vectorContextAlloc(context, type2, dims2);
50+
if( pVector2 == NULL ){
51+
goto out;
52+
}
53+
if( vectorParseWithType(argv[1], pVector2, &pzErrMsg) < 0 ){
54+
sqlite3_result_error(context, pzErrMsg, -1);
55+
sqlite3_free(pzErrMsg);
56+
goto out;
57+
}
58+
59+
/* Check if both vectors are of the same type */
60+
if( type1 != type2 ){
61+
sqlite3_result_error(context, "vector_concat: vectors must be of the same type", -1);
62+
goto out;
63+
}
64+
65+
/* Allocate target vector */
66+
pTarget = vectorContextAlloc(context, type1, dims1 + dims2);
67+
if( pTarget == NULL ){
68+
goto out;
69+
}
70+
71+
/* Copy data from both vectors into the target vector */
72+
switch( type1 ){
73+
case VECTOR_TYPE_FLOAT32: {
74+
float *pDst = (float*)pTarget->data;
75+
float *pSrc1 = (float*)pVector1->data;
76+
float *pSrc2 = (float*)pVector2->data;
77+
memcpy(pDst, pSrc1, dims1 * sizeof(float));
78+
memcpy(pDst + dims1, pSrc2, dims2 * sizeof(float));
79+
break;
80+
}
81+
case VECTOR_TYPE_FLOAT64: {
82+
double *pDst = (double*)pTarget->data;
83+
double *pSrc1 = (double*)pVector1->data;
84+
double *pSrc2 = (double*)pVector2->data;
85+
memcpy(pDst, pSrc1, dims1 * sizeof(double));
86+
memcpy(pDst + dims1, pSrc2, dims2 * sizeof(double));
87+
break;
88+
}
89+
case VECTOR_TYPE_FLOAT1BIT: {
90+
u8 *pDst = (u8*)pTarget->data;
91+
u8 *pSrc1 = (u8*)pVector1->data;
92+
u8 *pSrc2 = (u8*)pVector2->data;
93+
size_t size1 = (dims1 + 7) / 8;
94+
size_t size2 = (dims2 + 7) / 8;
95+
memcpy(pDst, pSrc1, size1);
96+
memcpy(pDst + size1, pSrc2, size2);
97+
break;
98+
}
99+
case VECTOR_TYPE_FLOAT8: {
100+
u8 *pDst = (u8*)pTarget->data;
101+
u8 *pSrc1 = (u8*)pVector1->data;
102+
u8 *pSrc2 = (u8*)pVector2->data;
103+
size_t size1 = dims1;
104+
size_t size2 = dims2;
105+
memcpy(pDst, pSrc1, size1);
106+
memcpy(pDst + size1, pSrc2, size2);
107+
108+
/* Copy parameters (alpha and shift) from the first vector */
109+
float *pParams1 = (float*)(pSrc1 + ALIGN(dims1, sizeof(float)));
110+
float *pParams = (float*)(pDst + ALIGN(dims1 + dims2, sizeof(float)));
111+
memcpy(pParams, pParams1, 2 * sizeof(float));
112+
break;
113+
}
114+
case VECTOR_TYPE_FLOAT16: {
115+
u16 *pDst = (u16*)pTarget->data;
116+
u16 *pSrc1 = (u16*)pVector1->data;
117+
u16 *pSrc2 = (u16*)pVector2->data;
118+
memcpy(pDst, pSrc1, dims1 * sizeof(u16));
119+
memcpy(pDst + dims1, pSrc2, dims2 * sizeof(u16));
120+
break;
121+
}
122+
case VECTOR_TYPE_FLOATB16: {
123+
u16 *pDst = (u16*)pTarget->data;
124+
u16 *pSrc1 = (u16*)pVector1->data;
125+
u16 *pSrc2 = (u16*)pVector2->data;
126+
memcpy(pDst, pSrc1, dims1 * sizeof(u16));
127+
memcpy(pDst + dims1, pSrc2, dims2 * sizeof(u16));
128+
break;
129+
}
130+
default:
131+
sqlite3_result_error(context, "vector_concat: unsupported vector type", -1);
132+
goto out;
133+
}
134+
135+
vectorSerializeWithMeta(context, pTarget);
136+
137+
out:
138+
if( pTarget ){
139+
vectorFree(pTarget);
140+
}
141+
if( pVector2 ){
142+
vectorFree(pVector2);
143+
}
144+
if( pVector1 ){
145+
vectorFree(pVector1);
146+
}
147+
}
148+
149+
/*
150+
** Implementation of vector_slice(X, start_idx, end_idx) function.
151+
** Extracts a subvector from start_idx (inclusive) to end_idx (exclusive).
152+
*/
153+
static void vectorSliceFunc(
154+
sqlite3_context *context,
155+
int argc,
156+
sqlite3_value **argv
157+
){
158+
char *pzErrMsg = NULL;
159+
Vector *pVector = NULL, *pTarget = NULL;
160+
int type, dims;
161+
sqlite3_int64 start_idx, end_idx;
162+
int new_dims;
163+
164+
if( argc != 3 ){
165+
sqlite3_result_error(context, "vector_slice requires exactly three arguments", -1);
166+
goto out;
167+
}
168+
169+
/* Parse the vector */
170+
if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){
171+
sqlite3_result_error(context, pzErrMsg, -1);
172+
sqlite3_free(pzErrMsg);
173+
goto out;
174+
}
175+
pVector = vectorContextAlloc(context, type, dims);
176+
if( pVector == NULL ){
177+
goto out;
178+
}
179+
if( vectorParseWithType(argv[0], pVector, &pzErrMsg) < 0 ){
180+
sqlite3_result_error(context, pzErrMsg, -1);
181+
sqlite3_free(pzErrMsg);
182+
goto out;
183+
}
184+
185+
/* Get start and end indices */
186+
if( sqlite3_value_type(argv[1]) != SQLITE_INTEGER ){
187+
sqlite3_result_error(context, "vector_slice: start_idx must be an integer", -1);
188+
goto out;
189+
}
190+
start_idx = sqlite3_value_int64(argv[1]);
191+
192+
if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){
193+
sqlite3_result_error(context, "vector_slice: end_idx must be an integer", -1);
194+
goto out;
195+
}
196+
end_idx = sqlite3_value_int64(argv[2]);
197+
198+
/* Validate indices */
199+
if( start_idx < 0 || end_idx < 0 ){
200+
sqlite3_result_error(context, "vector_slice: indices must be non-negative", -1);
201+
goto out;
202+
}
203+
204+
if( start_idx > end_idx ){
205+
sqlite3_result_error(context, "vector_slice: start_idx must not be greater than end_idx", -1);
206+
goto out;
207+
}
208+
209+
if( start_idx >= dims || end_idx > dims ){
210+
sqlite3_result_error(context, "vector_slice: indices out of bounds", -1);
211+
goto out;
212+
}
213+
214+
new_dims = (int)(end_idx - start_idx);
215+
pTarget = vectorContextAlloc(context, type, new_dims);
216+
if( pTarget == NULL ){
217+
goto out;
218+
}
219+
220+
/* Copy the appropriate slice of data */
221+
switch( type ){
222+
case VECTOR_TYPE_FLOAT32: {
223+
float *pDst = (float*)pTarget->data;
224+
float *pSrc = (float*)pVector->data;
225+
memcpy(pDst, pSrc + start_idx, new_dims * sizeof(float));
226+
break;
227+
}
228+
case VECTOR_TYPE_FLOAT64: {
229+
double *pDst = (double*)pTarget->data;
230+
double *pSrc = (double*)pVector->data;
231+
memcpy(pDst, pSrc + start_idx, new_dims * sizeof(double));
232+
break;
233+
}
234+
case VECTOR_TYPE_FLOAT1BIT: {
235+
/* For FLOAT1BIT, we need bit-by-bit extraction, which is more complex */
236+
sqlite3_result_error(context, "vector_slice: FLOAT1BIT vectors not yet supported", -1);
237+
goto out;
238+
}
239+
case VECTOR_TYPE_FLOAT8: {
240+
/* For FLOAT8, copy data and parameters */
241+
u8 *pDst = (u8*)pTarget->data;
242+
u8 *pSrc = (u8*)pVector->data;
243+
memcpy(pDst, pSrc + start_idx, new_dims);
244+
245+
/* Copy parameters (alpha and shift) */
246+
float *pParams = (float*)(pSrc + ALIGN(dims, sizeof(float)));
247+
float *pNewParams = (float*)(pDst + ALIGN(new_dims, sizeof(float)));
248+
memcpy(pNewParams, pParams, 2 * sizeof(float));
249+
break;
250+
}
251+
case VECTOR_TYPE_FLOAT16: {
252+
u16 *pDst = (u16*)pTarget->data;
253+
u16 *pSrc = (u16*)pVector->data;
254+
memcpy(pDst, pSrc + start_idx, new_dims * sizeof(u16));
255+
break;
256+
}
257+
case VECTOR_TYPE_FLOATB16: {
258+
u16 *pDst = (u16*)pTarget->data;
259+
u16 *pSrc = (u16*)pVector->data;
260+
memcpy(pDst, pSrc + start_idx, new_dims * sizeof(u16));
261+
break;
262+
}
263+
default:
264+
sqlite3_result_error(context, "vector_slice: unsupported vector type", -1);
265+
goto out;
266+
}
267+
268+
vectorSerializeWithMeta(context, pTarget);
269+
270+
out:
271+
if( pTarget ){
272+
vectorFree(pTarget);
273+
}
274+
if( pVector ){
275+
vectorFree(pVector);
276+
}
277+
}
278+
279+
#endif /* VECTOR_FUNC_IMPL_C */

0 commit comments

Comments
 (0)