fix gotchas with pg 16&17
[smlar.git] / smlar_gist.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/gist.h"
5 #include "access/skey.h"
6 #if PG_VERSION_NUM < 130000
7 #include "access/tuptoaster.h"
8 #else
9 #include "access/heaptoast.h"
10 #endif
11 #include "utils/memutils.h"
12
13 typedef struct SmlSign {
14         int32   vl_len_; /* varlena header (do not touch directly!) */
15         int32   flag:8,
16                         size:24;
17         int32   maxrepeat;
18         char    data[1];
19 } SmlSign;
20
/* Size of the fixed SmlSign header that precedes the payload */
#define SMLSIGNHDRSZ	(offsetof(SmlSign, data))

/*
 * Signature geometry: SIGLEN bytes, of which bit number SIGLENBIT (the very
 * last one) is never produced by HASHVAL().
 */
#define BITBYTE 8
#define SIGLENINT  61
#define SIGLEN	( sizeof(int)*SIGLENINT )
#define SIGLENBIT (SIGLEN*BITBYTE - 1)	/* see makesign */
typedef char BITVEC[SIGLEN];
typedef char *BITVECP;
/* Iterates 'i' (must be declared by the user) over all signature bytes */
#define LOOPBYTE \
		for(i=0;i<SIGLEN;i++)

/*
 * Bit access helpers.  Every macro argument is parenthesized so that
 * compound expressions can be passed safely.
 */
#define GETBYTE(x,i) ( *( (BITVECP)(x) + (int)( (i) / BITBYTE ) ) )
#define GETBITBYTE(x,i) ( ( ((char)(x)) >> (i) ) & 0x01 )
#define CLRBIT(x,i)   ( GETBYTE(x,i) &= ~( 0x01 << ( (i) % BITBYTE ) ) )
#define SETBIT(x,i)   ( GETBYTE(x,i) |=  ( 0x01 << ( (i) % BITBYTE ) ) )
#define GETBIT(x,i) ( (GETBYTE(x,i) >> ( (i) % BITBYTE )) & 0x01 )

/* Map a hash value to a bit number in [0, SIGLENBIT) and set that bit */
#define HASHVAL(val) (((unsigned int)(val)) % SIGLENBIT)
#define HASH(sign, val) SETBIT((sign), HASHVAL(val))

/* Values stored in SmlSign.flag */
#define ARRKEY			0x01
#define SIGNKEY			0x02
#define ALLISTRUE		0x04

#define ISARRKEY(x)		( ((SmlSign*)(x))->flag & ARRKEY )
#define ISSIGNKEY(x)	( ((SmlSign*)(x))->flag & SIGNKEY )
#define ISALLTRUE(x)	( ((SmlSign*)(x))->flag & ALLISTRUE )

/*
 * Total key size in bytes: header plus 'len' hashes for an array key, plus
 * a full signature for a signature key, or the header alone for an
 * all-true signature.
 */
#define CALCGTSIZE(flag, len)	( SMLSIGNHDRSZ + ( ( (flag) & ARRKEY ) ? ((len)*sizeof(uint32)) : (((flag) & ALLISTRUE) ? 0 : SIGLEN) ) )
#define GETSIGN(x)				( (BITVECP)( (char*)(x)+SMLSIGNHDRSZ ) )
#define GETARR(x)				( (uint32*)( (char*)(x)+SMLSIGNHDRSZ ) )

#define GETENTRY(vec,pos) ((SmlSign *) DatumGetPointer((vec)->vector[(pos)].key))

/* Fallback for environments whose headers do not provide Abs() */
#ifndef Abs
#define Abs(x)	abs(x)
#endif
58
59 /*
60  * Fake IO
61  */
62 PG_FUNCTION_INFO_V1(gsmlsign_in);
63 Datum   gsmlsign_in(PG_FUNCTION_ARGS);
64 Datum
65 gsmlsign_in(PG_FUNCTION_ARGS)
66 {
67         elog(ERROR, "not implemented");
68         PG_RETURN_DATUM(0);
69 }
70
71 PG_FUNCTION_INFO_V1(gsmlsign_out);
72 Datum   gsmlsign_out(PG_FUNCTION_ARGS);
73 Datum
74 gsmlsign_out(PG_FUNCTION_ARGS)
75 {
76         elog(ERROR, "not implemented");
77         PG_RETURN_DATUM(0);
78 }
79
80 /*
81  * Compress/decompress
82  */
83
84 /* Number of one-bits in an unsigned byte */
85 static const uint8 number_of_ones[256] = {
86         0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
87         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
88         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
89         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
90         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
91         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
92         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
93         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
94         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
95         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
96         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
97         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
98         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
99         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
100         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
101         4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
102 };
103
104 static int
105 compareint(const void *va, const void *vb)
106 {
107         uint32  a = *((uint32 *) va);
108         uint32  b = *((uint32 *) vb);
109
110         if (a == b)
111                 return 0;
112         return (a > b) ? 1 : -1;
113 }
114
115 /*
116  * Removes duplicates from an array of int32. 'l' is
117  * size of the input array. Returns the new size of the array.
118  */
119 static int
120 uniqueint(uint32 *a, int32 l, int32 *max)
121 {
122         uint32          *ptr,
123                                 *res;
124         int32           cnt = 0;
125
126         *max = 1;
127
128         if (l <= 1)
129                 return l;
130
131         ptr = res = a;
132
133         qsort((void *) a, l, sizeof(uint32), compareint);
134
135         while (ptr - a < l)
136                 if (*ptr != *res)
137                 {
138                         cnt = 1;
139                         *(++res) = *ptr++;
140                 }
141                 else
142                 {
143                         cnt++;
144                         if ( cnt > *max )
145                                 *max = cnt;
146                         ptr++;
147                 }
148
149         if ( cnt > *max )
150                 *max = cnt;
151
152         return res + 1 - a;
153 }
154
155 SmlSign*
156 Array2HashedArray(ProcTypeInfo info, ArrayType *a)
157 {
158         SimpleArray *s = Array2SimpleArray(info, a);
159         SmlSign         *sign;
160         int32           len, i;
161         uint32          *ptr;
162
163         len = CALCGTSIZE( ARRKEY, s->nelems );
164
165         getFmgrInfoHash(s->info);
166         if (s->info->tupDesc)
167                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
168
169         sign = palloc( len );
170         sign->flag = ARRKEY;
171         sign->size = s->nelems;
172
173         ptr = GETARR(sign);
174         for(i=0;i<s->nelems;i++)
175                 ptr[i] = DatumGetUInt32(FunctionCall1Coll(&s->info->hashFunc,
176                                                                                                   DEFAULT_COLLATION_OID,
177                                                                                                   s->elems[i]));
178
179         /*
180          * there is a collision of hash-function; len is always equal or less than
181          * s->nelems
182          */
183         sign->size = uniqueint( GETARR(sign), sign->size, &sign->maxrepeat );
184         len = CALCGTSIZE( ARRKEY, sign->size );
185         SET_VARSIZE(sign, len);
186
187         return sign;
188 }
189
190 static int
191 HashedElemCmp(const void *va, const void *vb)
192 {
193         uint32  a = ((HashedElem *) va)->hash;
194         uint32  b = ((HashedElem *) vb)->hash;
195
196         if (a == b)
197         {
198                 double  ma = ((HashedElem *) va)->idfMin;
199                 double  mb = ((HashedElem *) va)->idfMin;
200
201                 if (ma == mb)
202                         return 0;
203
204                 return ( ma > mb ) ? 1 : -1;
205         }
206
207         return (a > b) ? 1 : -1;
208 }
209
210 static int
211 uniqueHashedElem(HashedElem *a, int32 l)
212 {
213         HashedElem      *ptr,
214                                 *res;
215
216         if (l <= 1)
217                 return l;
218
219         ptr = res = a;
220
221         qsort(a, l, sizeof(HashedElem), HashedElemCmp);
222
223         while (ptr - a < l)
224                 if (ptr->hash != res->hash)
225                         *(++res) = *ptr++;
226                 else
227                 {
228                         res->idfMax = ptr->idfMax;
229                         ptr++;
230                 }
231
232         return res + 1 - a;
233 }
234
235 static StatCache*
236 getHashedCache(void *cache)
237 {
238         StatCache *stat = getStat(cache, SIGLENBIT);
239
240         if ( stat->nhelems < 0 )
241         {
242                 int i;
243                 /*
244                  * Init
245                  */
246
247                 if (stat->info->tupDesc)
248                         elog(ERROR, "GiST  doesn't support composite (weighted) type");
249                 getFmgrInfoHash(stat->info);
250                 for(i=0;i<stat->nelems;i++)
251                 {
252                         uint32  hash;
253                         int             index;
254
255                         hash = DatumGetUInt32(FunctionCall1Coll(&stat->info->hashFunc,
256                                                                                                         DEFAULT_COLLATION_OID,
257                                                                                                         stat->elems[i].datum));
258                         index = HASHVAL(hash);
259
260                         stat->helems[i].hash = hash;
261                         stat->helems[i].idfMin = stat->helems[i].idfMax = stat->elems[i].idf;   
262                         if ( stat->selems[index].idfMin == 0.0 )
263                                 stat->selems[index].idfMin = stat->selems[index].idfMax = stat->elems[i].idf;
264                         else if ( stat->selems[index].idfMin > stat->elems[i].idf )
265                                 stat->selems[index].idfMin = stat->elems[i].idf;
266                         else if ( stat->selems[index].idfMax < stat->elems[i].idf )
267                                 stat->selems[index].idfMax = stat->elems[i].idf;
268                 }
269
270                 stat->nhelems = uniqueHashedElem( stat->helems, stat->nelems);
271         }
272
273         return stat;
274 }
275
276 static HashedElem*
277 getHashedElemIdf(StatCache *stat, uint32 hash, HashedElem *StopLow)
278 {
279         HashedElem      *StopMiddle,
280                                 *StopHigh = stat->helems + stat->nhelems;
281
282         if ( !StopLow )
283                 StopLow = stat->helems;
284
285         while (StopLow < StopHigh) {
286                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
287
288                 if ( StopMiddle->hash == hash )
289                         return StopMiddle;
290                 else if ( StopMiddle->hash < hash )
291                         StopLow = StopMiddle + 1;
292                 else
293                         StopHigh = StopMiddle;
294         }
295
296         return NULL;
297 }
298
299 static void
300 fillHashVal(void *cache, SimpleArray *a)
301 {
302         int i;
303
304         if (a->hash)
305                 return;
306
307         allocateHash(cache, a);
308
309         if (a->info->tupDesc)
310                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
311         getFmgrInfoHash(a->info);
312
313         for(i=0;i<a->nelems;i++)
314                 a->hash[i] = DatumGetUInt32(FunctionCall1Coll(&a->info->hashFunc,
315                                                                                                           DEFAULT_COLLATION_OID,
316                                                                                                           a->elems[i]));
317 }
318
319
320 static bool
321 hasHashedElem(SmlSign  *a, uint32 h)
322 {
323         uint32  *StopLow = GETARR(a),
324                         *StopHigh = GETARR(a) + a->size,
325                         *StopMiddle;
326
327         while (StopLow < StopHigh) {
328                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
329
330                 if ( *StopMiddle == h )
331                         return true;
332                 else if ( *StopMiddle < h )
333                         StopLow = StopMiddle + 1;
334                 else
335                         StopHigh = StopMiddle;
336         }
337
338         return false;
339 }
340
341 static void
342 makesign(BITVECP sign, SmlSign  *a)
343 {
344         int32   i;
345         uint32  *ptr = GETARR(a);
346
347         MemSet((void *) sign, 0, sizeof(BITVEC));
348         SETBIT(sign, SIGLENBIT);   /* set last unused bit */
349
350         for (i = 0; i < a->size; i++)
351                 HASH(sign, ptr[i]);
352 }
353
354 static int32
355 sizebitvec(BITVECP sign)
356 {
357         int32   size = 0,
358                         i;
359
360         LOOPBYTE
361                 size += number_of_ones[(unsigned char) sign[i]];
362
363         return size;
364 }
365
366 PG_FUNCTION_INFO_V1(gsmlsign_compress);
367 Datum gsmlsign_compress(PG_FUNCTION_ARGS);
368 Datum
369 gsmlsign_compress(PG_FUNCTION_ARGS)
370 {
371         GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
372         GISTENTRY  *retval = entry;
373
374         if (entry->leafkey) /* new key */
375         {
376                 SmlSign         *sign;
377                 ArrayType       *a = DatumGetArrayTypeP(entry->key);
378
379                 sign = Array2HashedArray(NULL, a);
380
381                 if ( VARSIZE(sign) > TOAST_INDEX_TARGET )
382                 {       /* make signature due to its big size */
383                         SmlSign *tmpsign;
384                         int             len;
385
386                         len = CALCGTSIZE( SIGNKEY, sign->size );
387                         tmpsign = palloc( len );
388                         tmpsign->flag = SIGNKEY;
389                         SET_VARSIZE(tmpsign, len);
390
391                         makesign(GETSIGN(tmpsign), sign);
392                         tmpsign->size = sizebitvec(GETSIGN(tmpsign));
393                         tmpsign->maxrepeat = sign->maxrepeat;
394                         sign = tmpsign;
395                 }
396
397                 retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
398                 gistentryinit(*retval, PointerGetDatum(sign),
399                                                 entry->rel, entry->page,
400                                                 entry->offset, false);
401         }
402         else if ( ISSIGNKEY(DatumGetPointer(entry->key)) &&
403                                 !ISALLTRUE(DatumGetPointer(entry->key)) )
404         {
405                 SmlSign *sign = (SmlSign*)DatumGetPointer(entry->key);
406
407                 Assert( sign->size == sizebitvec(GETSIGN(sign)) );
408
409                 if ( sign->size == SIGLENBIT )
410                 {
411                         int32   len = CALCGTSIZE(SIGNKEY | ALLISTRUE, 0);
412                         int32   maxrepeat = sign->maxrepeat;
413
414                         sign = (SmlSign *) palloc(len);
415                         SET_VARSIZE(sign, len);
416                         sign->flag = SIGNKEY | ALLISTRUE;
417                         sign->size = SIGLENBIT;
418                         sign->maxrepeat = maxrepeat;
419
420                         retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
421
422                         gistentryinit(*retval, PointerGetDatum(sign),
423                                                         entry->rel, entry->page,
424                                                         entry->offset, false);
425                 }
426         }
427
428         PG_RETURN_POINTER(retval);
429 }
430
431 PG_FUNCTION_INFO_V1(gsmlsign_decompress);
432 Datum gsmlsign_decompress(PG_FUNCTION_ARGS);
433 Datum
434 gsmlsign_decompress(PG_FUNCTION_ARGS)
435 {
436         GISTENTRY       *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
437         SmlSign         *key =  (SmlSign*)PG_DETOAST_DATUM(entry->key);
438
439         if (key != (SmlSign *) DatumGetPointer(entry->key))
440         {
441                 GISTENTRY  *retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
442
443                 gistentryinit(*retval, PointerGetDatum(key),
444                                                 entry->rel, entry->page,
445                                                 entry->offset, false);
446
447                 PG_RETURN_POINTER(retval);
448         }
449
450         PG_RETURN_POINTER(entry);
451 }
452
453 /*
454  * Union method
455  */
456 static bool
457 unionkey(BITVECP sbase, SmlSign *add)
458 {
459         int32   i;
460
461         if (ISSIGNKEY(add))
462         {
463                 BITVECP sadd = GETSIGN(add);
464
465                 if (ISALLTRUE(add))
466                         return true;
467
468                 LOOPBYTE
469                         sbase[i] |= sadd[i];
470         }
471         else
472         {
473                 uint32  *ptr = GETARR(add);
474
475                 for (i = 0; i < add->size; i++)
476                         HASH(sbase, ptr[i]);
477         }
478
479         return false;
480 }
481
482 PG_FUNCTION_INFO_V1(gsmlsign_union);
483 Datum gsmlsign_union(PG_FUNCTION_ARGS);
484 Datum
485 gsmlsign_union(PG_FUNCTION_ARGS)
486 {
487         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
488         int                             *size = (int *) PG_GETARG_POINTER(1);
489         BITVEC                  base;
490         int32           i,
491                                 len,
492                                 maxrepeat = 1;
493         int32           flag = 0;
494         SmlSign    *result;
495
496         MemSet((void *) base, 0, sizeof(BITVEC));
497         for (i = 0; i < entryvec->n; i++)
498         {
499                 if (GETENTRY(entryvec, i)->maxrepeat > maxrepeat)
500                         maxrepeat = GETENTRY(entryvec, i)->maxrepeat;
501                 if (unionkey(base, GETENTRY(entryvec, i)))
502                 {
503                         flag = ALLISTRUE;
504                         break;
505                 }
506         }
507
508         flag |= SIGNKEY;
509         len = CALCGTSIZE(flag, 0);
510         result = (SmlSign *) palloc(len);
511         *size = len;
512         SET_VARSIZE(result, len);
513         result->flag = flag;
514         result->maxrepeat = maxrepeat;
515
516         if (!ISALLTRUE(result))
517         {
518                 memcpy((void *) GETSIGN(result), (void *) base, sizeof(BITVEC));
519                 result->size = sizebitvec(GETSIGN(result));
520         }
521         else
522                 result->size = SIGLENBIT;
523
524         PG_RETURN_POINTER(result);
525 }
526
527 /*
528  * Same method
529  */
530
531 PG_FUNCTION_INFO_V1(gsmlsign_same);
532 Datum gsmlsign_same(PG_FUNCTION_ARGS);
533 Datum
534 gsmlsign_same(PG_FUNCTION_ARGS)
535 {
536         SmlSign *a = (SmlSign*)PG_GETARG_POINTER(0);
537         SmlSign *b = (SmlSign*)PG_GETARG_POINTER(1);
538         bool    *result = (bool *) PG_GETARG_POINTER(2);
539
540         if (a->size != b->size)
541         {
542                 *result = false;
543         }
544         else if (ISSIGNKEY(a))
545         {       /* then b also ISSIGNKEY */
546                 if ( ISALLTRUE(a) )
547                 {
548                         /* in this case b is all true too - other cases is catched
549                            upper */
550                         *result = true;
551                 }
552                 else
553                 {
554                         int32   i;
555                         BITVECP sa = GETSIGN(a),
556                                         sb = GETSIGN(b);
557
558                         *result = true;
559
560                         if ( !ISALLTRUE(a) )
561                         {
562                                 LOOPBYTE
563                                 {
564                                         if (sa[i] != sb[i])
565                                         {
566                                                 *result = false;
567                                                 break;
568                                         }
569                                 }
570                         }
571                 }
572         }
573         else
574         {
575                 uint32  *ptra = GETARR(a),
576                                 *ptrb = GETARR(b);
577                 int32   i;
578
579                 *result = true;
580                 for (i = 0; i < a->size; i++)
581                 {
582                         if ( ptra[i] != ptrb[i])
583                         {
584                                 *result = false;
585                                 break;
586                         }
587                 }
588         }
589
590         PG_RETURN_POINTER(result);
591 }
592
593 /*
594  * Penalty method
595  */
596 static int
597 hemdistsign(BITVECP a, BITVECP b)
598 {
599         int     i,
600                 diff,
601                 dist = 0;
602
603         LOOPBYTE
604         {
605                 diff = (unsigned char) (a[i] ^ b[i]);
606                 dist += number_of_ones[diff];
607         }
608         return dist;
609 }
610
611 static int
612 hemdist(SmlSign *a, SmlSign *b)
613 {
614         if (ISALLTRUE(a))
615         {
616                 if (ISALLTRUE(b))
617                         return 0;
618                 else
619                         return SIGLENBIT - b->size;
620         }
621         else if (ISALLTRUE(b))
622                 return SIGLENBIT - b->size;
623
624         return hemdistsign(GETSIGN(a), GETSIGN(b));
625 }
626
627 PG_FUNCTION_INFO_V1(gsmlsign_penalty);
628 Datum gsmlsign_penalty(PG_FUNCTION_ARGS);
629 Datum
630 gsmlsign_penalty(PG_FUNCTION_ARGS)
631 {
632         GISTENTRY       *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
633         GISTENTRY       *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
634         float           *penalty = (float *) PG_GETARG_POINTER(2);
635         SmlSign         *origval = (SmlSign *) DatumGetPointer(origentry->key);
636         SmlSign         *newval = (SmlSign *) DatumGetPointer(newentry->key);
637         BITVECP         orig = GETSIGN(origval);
638
639         *penalty = 0.0;
640
641         if (ISARRKEY(newval))
642         {
643                 BITVEC  sign;
644
645                 makesign(sign, newval);
646
647                 if (ISALLTRUE(origval))
648                         *penalty = ((float) (SIGLENBIT - sizebitvec(sign))) / (float) (SIGLENBIT + 1);
649                 else
650                         *penalty = hemdistsign(sign, orig);
651         }
652         else
653                 *penalty = hemdist(origval, newval);
654
655         PG_RETURN_POINTER(penalty);
656 }
657
658 /*
659  * Picksplit method
660  */
661
662 typedef struct
663 {
664         bool    allistrue;
665         int32   size;
666         BITVEC  sign;
667 } CACHESIGN;
668
669 static void
670 fillcache(CACHESIGN *item, SmlSign *key)
671 {
672         item->allistrue = false;
673         item->size = key->size;
674
675         if (ISARRKEY(key))
676         {
677                 makesign(item->sign, key);
678                 item->size = sizebitvec( item->sign );
679         }
680         else if (ISALLTRUE(key))
681                 item->allistrue = true;
682         else
683                 memcpy((void *) item->sign, (void *) GETSIGN(key), sizeof(BITVEC));
684 }
685
/*
 * Balancing term used by picksplit: minus the cube of (a - b), scaled by c
 * and returned as a double.  The cube itself is computed in the type of
 * (a - b) before the conversion to double.
 */
#define WISH_F(a,b,c) (double)( -(double)(((a)-(b))*((a)-(b))*((a)-(b)))*(c) )
687
688 typedef struct
689 {
690         OffsetNumber pos;
691         int32           cost;
692 } SPLITCOST;
693
694 static int
695 comparecost(const void *va, const void *vb)
696 {
697         SPLITCOST  *a = (SPLITCOST *) va;
698         SPLITCOST  *b = (SPLITCOST *) vb;
699
700         if (a->cost == b->cost)
701                 return 0;
702         else
703                 return (a->cost > b->cost) ? 1 : -1;
704 }
705
706 static int
707 hemdistcache(CACHESIGN *a, CACHESIGN *b)
708 {
709         if (a->allistrue)
710         {
711                 if (b->allistrue)
712                         return 0;
713                 else
714                         return SIGLENBIT - b->size;
715         }
716         else if (b->allistrue)
717                 return SIGLENBIT - a->size;
718
719         return hemdistsign(a->sign, b->sign);
720 }
721
722 PG_FUNCTION_INFO_V1(gsmlsign_picksplit);
723 Datum gsmlsign_picksplit(PG_FUNCTION_ARGS);
724 Datum
725 gsmlsign_picksplit(PG_FUNCTION_ARGS)
726 {
727         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
728         GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
729         OffsetNumber k,
730                                 j;
731         SmlSign         *datum_l,
732                                 *datum_r;
733         BITVECP         union_l,
734                                 union_r;
735         int32           size_alpha,
736                                 size_beta;
737         int32           size_waste,
738                                 waste = -1;
739         int32           nbytes;
740         OffsetNumber seed_1 = 0,
741                                 seed_2 = 0;
742         OffsetNumber *left,
743                                 *right;
744         OffsetNumber maxoff;
745         BITVECP         ptr;
746         int                     i;
747         CACHESIGN  *cache;
748         SPLITCOST  *costvector;
749
750         maxoff = entryvec->n - 2;
751         nbytes = (maxoff + 2) * sizeof(OffsetNumber);
752         v->spl_left = (OffsetNumber *) palloc(nbytes);
753         v->spl_right = (OffsetNumber *) palloc(nbytes);
754
755         cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 2));
756         fillcache(&cache[FirstOffsetNumber], GETENTRY(entryvec, FirstOffsetNumber));
757
758         for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
759         {
760                 for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
761                 {
762                         if (k == FirstOffsetNumber)
763                                 fillcache(&cache[j], GETENTRY(entryvec, j));
764
765                         size_waste = hemdistcache(&(cache[j]), &(cache[k]));
766                         if (size_waste > waste)
767                         {
768                                 waste = size_waste;
769                                 seed_1 = k;
770                                 seed_2 = j;
771                         }
772                 }
773         }
774
775         left = v->spl_left;
776         v->spl_nleft = 0;
777         right = v->spl_right;
778         v->spl_nright = 0;
779
780         if (seed_1 == 0 || seed_2 == 0)
781         {
782                 seed_1 = 1;
783                 seed_2 = 2;
784         }
785
786         /* form initial .. */
787         if (cache[seed_1].allistrue)
788         {
789                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
790                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
791                 datum_l->flag = SIGNKEY | ALLISTRUE;
792                 datum_l->size = SIGLENBIT;
793         }
794         else
795         {
796                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
797                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY, 0));
798                 datum_l->flag = SIGNKEY;
799                 memcpy((void *) GETSIGN(datum_l), (void *) cache[seed_1].sign, sizeof(BITVEC));
800                 datum_l->size = cache[seed_1].size;
801         }
802         if (cache[seed_2].allistrue)
803         {
804                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
805                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
806                 datum_r->flag = SIGNKEY | ALLISTRUE;
807                 datum_r->size = SIGLENBIT;
808         }
809         else
810         {
811                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
812                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY, 0));
813                 datum_r->flag = SIGNKEY;
814                 memcpy((void *) GETSIGN(datum_r), (void *) cache[seed_2].sign, sizeof(BITVEC));
815                 datum_r->size = cache[seed_2].size;
816         }
817
818         union_l = GETSIGN(datum_l);
819         union_r = GETSIGN(datum_r);
820         maxoff = OffsetNumberNext(maxoff);
821         fillcache(&cache[maxoff], GETENTRY(entryvec, maxoff));
822         /* sort before ... */
823         costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
824         for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
825         {
826                 costvector[j - 1].pos = j;
827                 size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]));
828                 size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]));
829                 costvector[j - 1].cost = Abs(size_alpha - size_beta);
830         }
831         qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
832
833         datum_l->maxrepeat = datum_r->maxrepeat = 1;
834
835         for (k = 0; k < maxoff; k++)
836         {
837                 j = costvector[k].pos;
838                 if (j == seed_1)
839                 {
840                         *left++ = j;
841                         v->spl_nleft++;
842                         continue;
843                 }
844                 else if (j == seed_2)
845                 {
846                         *right++ = j;
847                         v->spl_nright++;
848                         continue;
849                 }
850
851                 if (ISALLTRUE(datum_l) || cache[j].allistrue)
852                 {
853                         if (ISALLTRUE(datum_l) && cache[j].allistrue)
854                                 size_alpha = 0;
855                         else
856                                 size_alpha = SIGLENBIT - (
857                                                                         (cache[j].allistrue) ? datum_l->size : cache[j].size
858                                                         );
859                 }
860                 else
861                         size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l));
862
863                 if (ISALLTRUE(datum_r) || cache[j].allistrue)
864                 {
865                         if (ISALLTRUE(datum_r) && cache[j].allistrue)
866                                 size_beta = 0;
867                         else
868                                 size_beta = SIGLENBIT - (
869                                                                         (cache[j].allistrue) ? datum_r->size : cache[j].size
870                                                         );
871                 }
872                 else
873                         size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r));
874
875                 if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
876                 {
877                         if (ISALLTRUE(datum_l) || cache[j].allistrue)
878                         {
879                                 if (!ISALLTRUE(datum_l))
880                                         MemSet((void *) GETSIGN(datum_l), 0xff, sizeof(BITVEC));
881                                 datum_l->size = SIGLENBIT;
882                         }
883                         else
884                         {
885                                 ptr = cache[j].sign;
886                                 LOOPBYTE
887                                         union_l[i] |= ptr[i];
888                                 datum_l->size = sizebitvec(union_l);
889                         }
890                         *left++ = j;
891                         v->spl_nleft++;
892                 }
893                 else
894                 {
895                         if (ISALLTRUE(datum_r) || cache[j].allistrue)
896                         {
897                                 if (!ISALLTRUE(datum_r))
898                                         MemSet((void *) GETSIGN(datum_r), 0xff, sizeof(BITVEC));
899                                 datum_r->size = SIGLENBIT;
900                         }
901                         else
902                         {
903                                 ptr = cache[j].sign;
904                                 LOOPBYTE
905                                         union_r[i] |= ptr[i];
906                                 datum_r->size = sizebitvec(union_r);
907                         }
908                         *right++ = j;
909                         v->spl_nright++;
910                 }
911         }
912         *right = *left = FirstOffsetNumber;
913         v->spl_ldatum = PointerGetDatum(datum_l);
914         v->spl_rdatum = PointerGetDatum(datum_r);
915
916         Assert( datum_l->size = sizebitvec(GETSIGN(datum_l)) );
917         Assert( datum_r->size = sizebitvec(GETSIGN(datum_r)) );
918
919         PG_RETURN_POINTER(v);
920 }
921
922 static double
923 getIdfMaxLimit(SmlSign *key)
924 {
925         switch( getTFMethod() )
926         {
927                 case TF_CONST:
928                         return 1.0;
929                         break;
930                 case TF_N:
931                         return (double)(key->maxrepeat);
932                         break;
933                 case TF_LOG:
934                         return 1.0 + log( (double)(key->maxrepeat) );
935                         break;
936                 default:
937                         elog(ERROR,"Unknown TF method: %d", getTFMethod());
938         }
939
940         return 0.0;
941 }
942
943 /*
944  * Consistent function
945  */
946 PG_FUNCTION_INFO_V1(gsmlsign_consistent);
947 Datum gsmlsign_consistent(PG_FUNCTION_ARGS);
948 Datum
949 gsmlsign_consistent(PG_FUNCTION_ARGS)
950 {
951         GISTENTRY               *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
952         StrategyNumber  strategy = (StrategyNumber) PG_GETARG_UINT16(2);
953         bool                    *recheck = (bool *) PG_GETARG_POINTER(4);
954         ArrayType               *a;
955         SmlSign                 *key = (SmlSign*)DatumGetPointer(entry->key);
956         int                             res = false;
957         SmlSign                 *query;
958         SimpleArray             *s;
959         int32                   i;
960
961         fcinfo->flinfo->fn_extra = SearchArrayCache(
962                                                                         fcinfo->flinfo->fn_extra,
963                                                                         fcinfo->flinfo->fn_mcxt,
964                                                                         PG_GETARG_DATUM(1), &a, &s, &query);
965
966         *recheck = true;
967
968         if ( ARRISVOID(a) )
969                 PG_RETURN_BOOL(res);
970
971         if ( strategy == SmlarOverlapStrategy )
972         {
973                 if (ISALLTRUE(key))
974                 {
975                         res = true;
976                 }
977                 else if (ISARRKEY(key))
978                 {
979                         uint32  *kptr = GETARR(key),
980                                         *qptr = GETARR(query);
981
982                         while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
983                         {
984                                 if ( *kptr < *qptr )
985                                         kptr++;
986                                 else if ( *kptr > *qptr )
987                                         qptr++;
988                                 else
989                                 {
990                                         res = true;
991                                         break;
992                                 }
993                         }
994                         *recheck = false;
995                 }
996                 else
997                 {
998                         BITVECP sign = GETSIGN(key);
999
1000                         fillHashVal(fcinfo->flinfo->fn_extra, s);
1001
1002                         for(i=0; i<s->nelems; i++)
1003                         {
1004                                 if ( GETBIT(sign, HASHVAL(s->hash[i])) )
1005                                 {
1006                                         res = true;
1007                                         break;
1008                                 }
1009                         }
1010                 }
1011         }
1012         /*
1013          *  SmlarSimilarityStrategy
1014          */
1015         else if (ISALLTRUE(key))
1016         {
1017                 if ( GIST_LEAF(entry) )
1018                 {
1019                         /*
1020                          * With TF/IDF similarity we cannot say anything useful
1021                          */
1022                         if ( query->size < SIGLENBIT && getSmlType() != ST_TFIDF )
1023                         {
1024                                 double power = ((double)(query->size)) * ((double)(SIGLENBIT));
1025
1026                                 if ( ((double)(query->size)) / sqrt(power) >= GetSmlarLimit() ) 
1027                                         res = true;
1028                         }
1029                         else
1030                         {
1031                                 res = true;     
1032                         }
1033                 }
1034                 else
1035                         res = true;
1036         }
1037         else if (ISARRKEY(key))
1038         {
1039                 uint32  *kptr = GETARR(key),
1040                                 *qptr = GETARR(query);
1041
1042                 Assert( GIST_LEAF(entry) );
1043
1044                 switch(getSmlType())
1045                 {
1046                         case ST_TFIDF:
1047                                 {
1048                                         StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1049                                         double          sumU = 0.0,
1050                                                                 sumQ = 0.0,
1051                                                                 sumK = 0.0;
1052                                         double          maxKTF = getIdfMaxLimit(key);
1053                                         HashedElem  *h;
1054
1055                                         Assert( s->df );
1056                                         fillHashVal(fcinfo->flinfo->fn_extra, s);
1057                                         if ( stat->info != s->info )
1058                                                 elog(ERROR,"Statistic and actual argument have different type");
1059
1060                                         for(i=0;i<s->nelems;i++)
1061                                         {
1062                                                 sumQ += s->df[i] * s->df[i];
1063
1064                                                 h = getHashedElemIdf(stat, s->hash[i], NULL);
1065                                                 if ( h && hasHashedElem(key, s->hash[i]) )
1066                                                 {
1067                                                         sumK += h->idfMin * h->idfMin;
1068                                                         sumU += h->idfMax * maxKTF * s->df[i];
1069                                                 } 
1070                                         }
1071
1072                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1073                                         {
1074                                                 /* 
1075                                                  * More precisely calculate sumK
1076                                                  */
1077                                                 h = NULL;
1078                                                 sumK = 0.0;
1079
1080                                                 for(i=0;i<key->size;i++)
1081                                                 {
1082                                                         h = getHashedElemIdf(stat, GETARR(key)[i], h);                                  
1083                                                         if (h)
1084                                                                 sumK += h->idfMin * h->idfMin;
1085                                                 }
1086
1087                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1088                                                         res = true;
1089                                         }
1090                                 }
1091                                 break;
1092                         case ST_COSINE:
1093                                 {
1094                                         double                  power;
1095                                         power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1096
1097                                         if (  ((double)Min(key->size, s->nelems)) / power >= GetSmlarLimit() )
1098                                         {
1099                                                 int  cnt = 0;
1100
1101                                                 while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
1102                                                 {
1103                                                         if ( *kptr < *qptr )
1104                                                                 kptr++;
1105                                                         else if ( *kptr > *qptr )
1106                                                                 qptr++;
1107                                                         else
1108                                                         {
1109                                                                 cnt++;
1110                                                                 kptr++;
1111                                                                 qptr++;
1112                                                         }
1113                                                 }
1114
1115                                                 if ( ((double)cnt) / power >= GetSmlarLimit() )
1116                                                         res = true;
1117                                         }
1118                                 }
1119                                 break;
1120                         default:
1121                                 elog(ERROR,"GiST doesn't support current formula type of similarity");
1122                 }
1123         }
1124         else
1125         {       /* signature */
1126                 BITVECP sign = GETSIGN(key);
1127                 int32   count = 0;
1128
1129                 fillHashVal(fcinfo->flinfo->fn_extra, s);
1130
1131                 if ( GIST_LEAF(entry) )
1132                 {
1133                         switch(getSmlType())
1134                         {
1135                                 case ST_TFIDF:
1136                                         {
1137                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1138                                                 double          sumU = 0.0,
1139                                                                         sumQ = 0.0,
1140                                                                         sumK = 0.0;
1141                                                 double          maxKTF = getIdfMaxLimit(key);
1142
1143                                                 Assert( s->df );
1144                                                 if ( stat->info != s->info )
1145                                                         elog(ERROR,"Statistic and actual argument have different type");
1146
1147                                                 for(i=0;i<s->nelems;i++)
1148                                                 {
1149                                                         int32           hbit = HASHVAL(s->hash[i]);
1150
1151                                                         sumQ += s->df[i] * s->df[i];
1152                                                         if ( GETBIT(sign, hbit) )
1153                                                         {
1154                                                                 sumK += stat->selems[ hbit ].idfMin * stat->selems[ hbit ].idfMin;
1155                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1156                                                         }
1157                                                 }
1158
1159                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1160                                                 {
1161                                                         /* 
1162                                                          * More precisely calculate sumK
1163                                                          */
1164                                                         sumK = 0.0;
1165
1166                                                         for(i=0;i<SIGLENBIT;i++)
1167                                                                 if ( GETBIT(sign,i) )
1168                                                                         sumK += stat->selems[ i ].idfMin * stat->selems[ i ].idfMin;
1169
1170                                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1171                                                                 res = true;
1172                                                 }
1173                                         }
1174                                         break;
1175                                 case ST_COSINE:
1176                                         {
1177                                                 double  power;
1178
1179                                                 power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1180
1181                                                 for(i=0; i<s->nelems; i++)
1182                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1183
1184                                                 if ( ((double)count) / power >= GetSmlarLimit() )
1185                                                         res = true;
1186                                         }
1187                                         break;
1188                                 default:
1189                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1190                         }
1191                 }
1192                 else /* non-leaf page */
1193                 {
1194                         switch(getSmlType())
1195                         {
1196                                 case ST_TFIDF:
1197                                         {
1198                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1199                                                 double          sumU = 0.0,
1200                                                                         sumQ = 0.0,
1201                                                                         minK = -1.0;
1202                                                 double          maxKTF = getIdfMaxLimit(key);
1203
1204                                                 Assert( s->df );
1205                                                 if ( stat->info != s->info )
1206                                                         elog(ERROR,"Statistic and actual argument have different type");
1207
1208                                                 for(i=0;i<s->nelems;i++)
1209                                                 {
1210                                                         int32           hbit = HASHVAL(s->hash[i]);
1211
1212                                                         sumQ += s->df[i] * s->df[i];
1213                                                         if ( GETBIT(sign, hbit) )
1214                                                         {
1215                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1216                                                                 if ( minK > stat->selems[ hbit ].idfMin  || minK < 0.0 )
1217                                                                         minK = stat->selems[ hbit ].idfMin;
1218                                                         }
1219                                                 }
1220
1221                                                 if ( sumQ > 0.0 && minK > 0.0 && sumU / sqrt( sumQ * minK ) >= GetSmlarLimit() )
1222                                                         res = true;
1223                                         }
1224                                         break;
1225                                 case ST_COSINE:
1226                                         {
1227                                                 for(i=0; i<s->nelems; i++)
1228                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1229
1230                                                 if ( s->nelems == count  || sqrt(((double)count) / ((double)(s->nelems))) >= GetSmlarLimit() )
1231                                                         res = true;
1232                                         }
1233                                         break;
1234                                 default:
1235                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1236                         }
1237                 }
1238         }
1239
1240 #if 0
1241         {
1242                 static int nnres = 0;
1243                 static int nres = 0;
1244                 if ( GIST_LEAF(entry) ) {
1245                         if ( res )
1246                                 nres++;
1247                         else
1248                                 nnres++;
1249                         elog(NOTICE,"%s nn:%d n:%d", (ISARRKEY(key)) ? "ARR" : ( (ISALLTRUE(key)) ? "TRUE" : "SIGN" ), nnres, nres  );
1250                 }
1251         }
1252 #endif
1253
1254         PG_RETURN_BOOL(res);
1255 }