|
| 1 | +/* |
| 2 | +** This file contains implementations of the vector_concat and vector_slice functions. |
| 3 | +** It is included by vector.c and not compiled separately. |
| 4 | +*/ |
| 5 | + |
| 6 | +#ifndef VECTOR_FUNC_IMPL_C |
| 7 | +#define VECTOR_FUNC_IMPL_C |
| 8 | + |
| 9 | +/* |
| 10 | +** Implementation of vector_concat(X, Y) function. |
| 11 | +** Concatenates two vectors of same type. |
| 12 | +*/ |
| 13 | +static void vectorConcatFunc( |
| 14 | + sqlite3_context *context, |
| 15 | + int argc, |
| 16 | + sqlite3_value **argv |
| 17 | +){ |
| 18 | + char *pzErrMsg = NULL; |
| 19 | + Vector *pVector1 = NULL, *pVector2 = NULL, *pTarget = NULL; |
| 20 | + int type1, dims1, type2, dims2; |
| 21 | + |
| 22 | + if( argc != 2 ){ |
| 23 | + sqlite3_result_error(context, "vector_concat requires exactly two arguments", -1); |
| 24 | + goto out; |
| 25 | + } |
| 26 | + |
| 27 | + /* Parse first vector */ |
| 28 | + if( detectVectorParameters(argv[0], 0, &type1, &dims1, &pzErrMsg) != 0 ){ |
| 29 | + sqlite3_result_error(context, pzErrMsg, -1); |
| 30 | + sqlite3_free(pzErrMsg); |
| 31 | + goto out; |
| 32 | + } |
| 33 | + pVector1 = vectorContextAlloc(context, type1, dims1); |
| 34 | + if( pVector1 == NULL ){ |
| 35 | + goto out; |
| 36 | + } |
| 37 | + if( vectorParseWithType(argv[0], pVector1, &pzErrMsg) < 0 ){ |
| 38 | + sqlite3_result_error(context, pzErrMsg, -1); |
| 39 | + sqlite3_free(pzErrMsg); |
| 40 | + goto out; |
| 41 | + } |
| 42 | + |
| 43 | + /* Parse second vector */ |
| 44 | + if( detectVectorParameters(argv[1], 0, &type2, &dims2, &pzErrMsg) != 0 ){ |
| 45 | + sqlite3_result_error(context, pzErrMsg, -1); |
| 46 | + sqlite3_free(pzErrMsg); |
| 47 | + goto out; |
| 48 | + } |
| 49 | + pVector2 = vectorContextAlloc(context, type2, dims2); |
| 50 | + if( pVector2 == NULL ){ |
| 51 | + goto out; |
| 52 | + } |
| 53 | + if( vectorParseWithType(argv[1], pVector2, &pzErrMsg) < 0 ){ |
| 54 | + sqlite3_result_error(context, pzErrMsg, -1); |
| 55 | + sqlite3_free(pzErrMsg); |
| 56 | + goto out; |
| 57 | + } |
| 58 | + |
| 59 | + /* Check if both vectors are of the same type */ |
| 60 | + if( type1 != type2 ){ |
| 61 | + sqlite3_result_error(context, "vector_concat: vectors must be of the same type", -1); |
| 62 | + goto out; |
| 63 | + } |
| 64 | + |
| 65 | + /* Allocate target vector */ |
| 66 | + pTarget = vectorContextAlloc(context, type1, dims1 + dims2); |
| 67 | + if( pTarget == NULL ){ |
| 68 | + goto out; |
| 69 | + } |
| 70 | + |
| 71 | + /* Copy data from both vectors into the target vector */ |
| 72 | + switch( type1 ){ |
| 73 | + case VECTOR_TYPE_FLOAT32: { |
| 74 | + float *pDst = (float*)pTarget->data; |
| 75 | + float *pSrc1 = (float*)pVector1->data; |
| 76 | + float *pSrc2 = (float*)pVector2->data; |
| 77 | + memcpy(pDst, pSrc1, dims1 * sizeof(float)); |
| 78 | + memcpy(pDst + dims1, pSrc2, dims2 * sizeof(float)); |
| 79 | + break; |
| 80 | + } |
| 81 | + case VECTOR_TYPE_FLOAT64: { |
| 82 | + double *pDst = (double*)pTarget->data; |
| 83 | + double *pSrc1 = (double*)pVector1->data; |
| 84 | + double *pSrc2 = (double*)pVector2->data; |
| 85 | + memcpy(pDst, pSrc1, dims1 * sizeof(double)); |
| 86 | + memcpy(pDst + dims1, pSrc2, dims2 * sizeof(double)); |
| 87 | + break; |
| 88 | + } |
| 89 | + case VECTOR_TYPE_FLOAT1BIT: { |
| 90 | + u8 *pDst = (u8*)pTarget->data; |
| 91 | + u8 *pSrc1 = (u8*)pVector1->data; |
| 92 | + u8 *pSrc2 = (u8*)pVector2->data; |
| 93 | + size_t size1 = (dims1 + 7) / 8; |
| 94 | + size_t size2 = (dims2 + 7) / 8; |
| 95 | + memcpy(pDst, pSrc1, size1); |
| 96 | + memcpy(pDst + size1, pSrc2, size2); |
| 97 | + break; |
| 98 | + } |
| 99 | + case VECTOR_TYPE_FLOAT8: { |
| 100 | + u8 *pDst = (u8*)pTarget->data; |
| 101 | + u8 *pSrc1 = (u8*)pVector1->data; |
| 102 | + u8 *pSrc2 = (u8*)pVector2->data; |
| 103 | + size_t size1 = dims1; |
| 104 | + size_t size2 = dims2; |
| 105 | + memcpy(pDst, pSrc1, size1); |
| 106 | + memcpy(pDst + size1, pSrc2, size2); |
| 107 | + |
| 108 | + /* Copy parameters (alpha and shift) from the first vector */ |
| 109 | + float *pParams1 = (float*)(pSrc1 + ALIGN(dims1, sizeof(float))); |
| 110 | + float *pParams = (float*)(pDst + ALIGN(dims1 + dims2, sizeof(float))); |
| 111 | + memcpy(pParams, pParams1, 2 * sizeof(float)); |
| 112 | + break; |
| 113 | + } |
| 114 | + case VECTOR_TYPE_FLOAT16: { |
| 115 | + u16 *pDst = (u16*)pTarget->data; |
| 116 | + u16 *pSrc1 = (u16*)pVector1->data; |
| 117 | + u16 *pSrc2 = (u16*)pVector2->data; |
| 118 | + memcpy(pDst, pSrc1, dims1 * sizeof(u16)); |
| 119 | + memcpy(pDst + dims1, pSrc2, dims2 * sizeof(u16)); |
| 120 | + break; |
| 121 | + } |
| 122 | + case VECTOR_TYPE_FLOATB16: { |
| 123 | + u16 *pDst = (u16*)pTarget->data; |
| 124 | + u16 *pSrc1 = (u16*)pVector1->data; |
| 125 | + u16 *pSrc2 = (u16*)pVector2->data; |
| 126 | + memcpy(pDst, pSrc1, dims1 * sizeof(u16)); |
| 127 | + memcpy(pDst + dims1, pSrc2, dims2 * sizeof(u16)); |
| 128 | + break; |
| 129 | + } |
| 130 | + default: |
| 131 | + sqlite3_result_error(context, "vector_concat: unsupported vector type", -1); |
| 132 | + goto out; |
| 133 | + } |
| 134 | + |
| 135 | + vectorSerializeWithMeta(context, pTarget); |
| 136 | + |
| 137 | +out: |
| 138 | + if( pTarget ){ |
| 139 | + vectorFree(pTarget); |
| 140 | + } |
| 141 | + if( pVector2 ){ |
| 142 | + vectorFree(pVector2); |
| 143 | + } |
| 144 | + if( pVector1 ){ |
| 145 | + vectorFree(pVector1); |
| 146 | + } |
| 147 | +} |
| 148 | + |
| 149 | +/* |
| 150 | +** Implementation of vector_slice(X, start_idx, end_idx) function. |
| 151 | +** Extracts a subvector from start_idx (inclusive) to end_idx (exclusive). |
| 152 | +*/ |
| 153 | +static void vectorSliceFunc( |
| 154 | + sqlite3_context *context, |
| 155 | + int argc, |
| 156 | + sqlite3_value **argv |
| 157 | +){ |
| 158 | + char *pzErrMsg = NULL; |
| 159 | + Vector *pVector = NULL, *pTarget = NULL; |
| 160 | + int type, dims; |
| 161 | + sqlite3_int64 start_idx, end_idx; |
| 162 | + int new_dims; |
| 163 | + |
| 164 | + if( argc != 3 ){ |
| 165 | + sqlite3_result_error(context, "vector_slice requires exactly three arguments", -1); |
| 166 | + goto out; |
| 167 | + } |
| 168 | + |
| 169 | + /* Parse the vector */ |
| 170 | + if( detectVectorParameters(argv[0], 0, &type, &dims, &pzErrMsg) != 0 ){ |
| 171 | + sqlite3_result_error(context, pzErrMsg, -1); |
| 172 | + sqlite3_free(pzErrMsg); |
| 173 | + goto out; |
| 174 | + } |
| 175 | + pVector = vectorContextAlloc(context, type, dims); |
| 176 | + if( pVector == NULL ){ |
| 177 | + goto out; |
| 178 | + } |
| 179 | + if( vectorParseWithType(argv[0], pVector, &pzErrMsg) < 0 ){ |
| 180 | + sqlite3_result_error(context, pzErrMsg, -1); |
| 181 | + sqlite3_free(pzErrMsg); |
| 182 | + goto out; |
| 183 | + } |
| 184 | + |
| 185 | + /* Get start and end indices */ |
| 186 | + if( sqlite3_value_type(argv[1]) != SQLITE_INTEGER ){ |
| 187 | + sqlite3_result_error(context, "vector_slice: start_idx must be an integer", -1); |
| 188 | + goto out; |
| 189 | + } |
| 190 | + start_idx = sqlite3_value_int64(argv[1]); |
| 191 | + |
| 192 | + if( sqlite3_value_type(argv[2]) != SQLITE_INTEGER ){ |
| 193 | + sqlite3_result_error(context, "vector_slice: end_idx must be an integer", -1); |
| 194 | + goto out; |
| 195 | + } |
| 196 | + end_idx = sqlite3_value_int64(argv[2]); |
| 197 | + |
| 198 | + /* Validate indices */ |
| 199 | + if( start_idx < 0 || end_idx < 0 ){ |
| 200 | + sqlite3_result_error(context, "vector_slice: indices must be non-negative", -1); |
| 201 | + goto out; |
| 202 | + } |
| 203 | + |
| 204 | + if( start_idx > end_idx ){ |
| 205 | + sqlite3_result_error(context, "vector_slice: start_idx must not be greater than end_idx", -1); |
| 206 | + goto out; |
| 207 | + } |
| 208 | + |
| 209 | + if( start_idx >= dims || end_idx > dims ){ |
| 210 | + sqlite3_result_error(context, "vector_slice: indices out of bounds", -1); |
| 211 | + goto out; |
| 212 | + } |
| 213 | + |
| 214 | + new_dims = (int)(end_idx - start_idx); |
| 215 | + pTarget = vectorContextAlloc(context, type, new_dims); |
| 216 | + if( pTarget == NULL ){ |
| 217 | + goto out; |
| 218 | + } |
| 219 | + |
| 220 | + /* Copy the appropriate slice of data */ |
| 221 | + switch( type ){ |
| 222 | + case VECTOR_TYPE_FLOAT32: { |
| 223 | + float *pDst = (float*)pTarget->data; |
| 224 | + float *pSrc = (float*)pVector->data; |
| 225 | + memcpy(pDst, pSrc + start_idx, new_dims * sizeof(float)); |
| 226 | + break; |
| 227 | + } |
| 228 | + case VECTOR_TYPE_FLOAT64: { |
| 229 | + double *pDst = (double*)pTarget->data; |
| 230 | + double *pSrc = (double*)pVector->data; |
| 231 | + memcpy(pDst, pSrc + start_idx, new_dims * sizeof(double)); |
| 232 | + break; |
| 233 | + } |
| 234 | + case VECTOR_TYPE_FLOAT1BIT: { |
| 235 | + /* For FLOAT1BIT, we need bit-by-bit extraction, which is more complex */ |
| 236 | + sqlite3_result_error(context, "vector_slice: FLOAT1BIT vectors not yet supported", -1); |
| 237 | + goto out; |
| 238 | + } |
| 239 | + case VECTOR_TYPE_FLOAT8: { |
| 240 | + /* For FLOAT8, copy data and parameters */ |
| 241 | + u8 *pDst = (u8*)pTarget->data; |
| 242 | + u8 *pSrc = (u8*)pVector->data; |
| 243 | + memcpy(pDst, pSrc + start_idx, new_dims); |
| 244 | + |
| 245 | + /* Copy parameters (alpha and shift) */ |
| 246 | + float *pParams = (float*)(pSrc + ALIGN(dims, sizeof(float))); |
| 247 | + float *pNewParams = (float*)(pDst + ALIGN(new_dims, sizeof(float))); |
| 248 | + memcpy(pNewParams, pParams, 2 * sizeof(float)); |
| 249 | + break; |
| 250 | + } |
| 251 | + case VECTOR_TYPE_FLOAT16: { |
| 252 | + u16 *pDst = (u16*)pTarget->data; |
| 253 | + u16 *pSrc = (u16*)pVector->data; |
| 254 | + memcpy(pDst, pSrc + start_idx, new_dims * sizeof(u16)); |
| 255 | + break; |
| 256 | + } |
| 257 | + case VECTOR_TYPE_FLOATB16: { |
| 258 | + u16 *pDst = (u16*)pTarget->data; |
| 259 | + u16 *pSrc = (u16*)pVector->data; |
| 260 | + memcpy(pDst, pSrc + start_idx, new_dims * sizeof(u16)); |
| 261 | + break; |
| 262 | + } |
| 263 | + default: |
| 264 | + sqlite3_result_error(context, "vector_slice: unsupported vector type", -1); |
| 265 | + goto out; |
| 266 | + } |
| 267 | + |
| 268 | + vectorSerializeWithMeta(context, pTarget); |
| 269 | + |
| 270 | +out: |
| 271 | + if( pTarget ){ |
| 272 | + vectorFree(pTarget); |
| 273 | + } |
| 274 | + if( pVector ){ |
| 275 | + vectorFree(pVector); |
| 276 | + } |
| 277 | +} |
| 278 | + |
| 279 | +#endif /* VECTOR_FUNC_IMPL_C */ |
0 commit comments