Fix compiler warnings
[smlar.git] / smlar_gist.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/gist.h"
5 #include "access/skey.h"
6 #if PG_VERSION_NUM < 130000
7 #include "access/tuptoaster.h"
8 #else
9 #include "access/heaptoast.h"
10 #endif
11 #include "utils/memutils.h"
12
13 typedef struct SmlSign {
14         int32   vl_len_; /* varlena header (do not touch directly!) */
15         int32   flag:8,
16                         size:24;
17         int32   maxrepeat;
18         char    data[1];
19 } SmlSign;
20
#define SMLSIGNHDRSZ	(offsetof(SmlSign, data))

/*
 * Signature geometry.  A signature occupies SIGLEN bytes.  HASHVAL() only
 * produces bit positions below SIGLENBIT; the very last bit (index SIGLENBIT)
 * is reserved and is set by makesign().
 */
#define BITBYTE		8
#define SIGLENINT	61
#define SIGLEN		( sizeof(int) * SIGLENINT )
#define SIGLENBIT	( SIGLEN * BITBYTE - 1 )	/* see makesign */
typedef char BITVEC[SIGLEN];
typedef char *BITVECP;

/* Loops the caller-declared variable "i" over every signature byte. */
#define LOOPBYTE \
			for (i = 0; i < SIGLEN; i++)

/*
 * Bit access helpers.  Every macro argument and every whole expansion is
 * parenthesized so that expression arguments (e.g. "a ? b : c") and use of
 * the result inside larger expressions behave as expected.
 */
#define GETBYTE(x,i)	( *( (BITVECP) (x) + (int) ( (i) / BITBYTE ) ) )
#define GETBITBYTE(x,i)	( ( ((char) (x)) >> (i) ) & 0x01 )
#define CLRBIT(x,i)		( GETBYTE(x,i) &= ~( 0x01 << ( (i) % BITBYTE ) ) )
#define SETBIT(x,i)		( GETBYTE(x,i) |= ( 0x01 << ( (i) % BITBYTE ) ) )
#define GETBIT(x,i)		( ( GETBYTE(x,i) >> ( (i) % BITBYTE ) ) & 0x01 )

/* Maps a 32-bit hash onto a bit position in [0, SIGLENBIT). */
#define HASHVAL(val)	( ((unsigned int) (val)) % SIGLENBIT )
#define HASH(sign, val)	SETBIT((sign), HASHVAL(val))

/* SmlSign.flag bits */
#define ARRKEY			0x01
#define SIGNKEY			0x02
#define ALLISTRUE		0x04

#define ISARRKEY(x)		( ((SmlSign *) (x))->flag & ARRKEY )
#define ISSIGNKEY(x)	( ((SmlSign *) (x))->flag & SIGNKEY )
#define ISALLTRUE(x)	( ((SmlSign *) (x))->flag & ALLISTRUE )

/* Total key size: header plus hash array, full signature, or nothing. */
#define CALCGTSIZE(flag, len)	( SMLSIGNHDRSZ + ( ( (flag) & ARRKEY ) ? ((len) * sizeof(uint32)) : ( ((flag) & ALLISTRUE) ? 0 : SIGLEN ) ) )
#define GETSIGN(x)		( (BITVECP) ( (char *) (x) + SMLSIGNHDRSZ ) )
#define GETARR(x)		( (uint32 *) ( (char *) (x) + SMLSIGNHDRSZ ) )

#define GETENTRY(vec,pos) ((SmlSign *) DatumGetPointer((vec)->vector[(pos)].key))
54
55 /*
56  * Fake IO
57  */
58 PG_FUNCTION_INFO_V1(gsmlsign_in);
59 Datum   gsmlsign_in(PG_FUNCTION_ARGS);
60 Datum
61 gsmlsign_in(PG_FUNCTION_ARGS)
62 {
63         elog(ERROR, "not implemented");
64         PG_RETURN_DATUM(0);
65 }
66
67 PG_FUNCTION_INFO_V1(gsmlsign_out);
68 Datum   gsmlsign_out(PG_FUNCTION_ARGS);
69 Datum
70 gsmlsign_out(PG_FUNCTION_ARGS)
71 {
72         elog(ERROR, "not implemented");
73         PG_RETURN_DATUM(0);
74 }
75
76 /*
77  * Compress/decompress
78  */
79
80 /* Number of one-bits in an unsigned byte */
81 static const uint8 number_of_ones[256] = {
82         0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
83         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
84         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
85         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
86         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
87         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
88         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
89         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
90         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
91         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
92         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
93         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
94         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
95         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
96         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
97         4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
98 };
99
100 static int
101 compareint(const void *va, const void *vb)
102 {
103         uint32  a = *((uint32 *) va);
104         uint32  b = *((uint32 *) vb);
105
106         if (a == b)
107                 return 0;
108         return (a > b) ? 1 : -1;
109 }
110
111 /*
112  * Removes duplicates from an array of int32. 'l' is
113  * size of the input array. Returns the new size of the array.
114  */
115 static int
116 uniqueint(uint32 *a, int32 l, int32 *max)
117 {
118         uint32          *ptr,
119                                 *res;
120         int32           cnt = 0;
121
122         *max = 1;
123
124         if (l <= 1)
125                 return l;
126
127         ptr = res = a;
128
129         qsort((void *) a, l, sizeof(uint32), compareint);
130
131         while (ptr - a < l)
132                 if (*ptr != *res)
133                 {
134                         cnt = 1;
135                         *(++res) = *ptr++;
136                 }
137                 else
138                 {
139                         cnt++;
140                         if ( cnt > *max )
141                                 *max = cnt;
142                         ptr++;
143                 }
144
145         if ( cnt > *max )
146                 *max = cnt;
147
148         return res + 1 - a;
149 }
150
151 SmlSign*
152 Array2HashedArray(ProcTypeInfo info, ArrayType *a)
153 {
154         SimpleArray *s = Array2SimpleArray(info, a);
155         SmlSign         *sign;
156         int32           len, i;
157         uint32          *ptr;
158
159         len = CALCGTSIZE( ARRKEY, s->nelems );
160
161         getFmgrInfoHash(s->info);
162         if (s->info->tupDesc)
163                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
164
165         sign = palloc( len );
166         sign->flag = ARRKEY;
167         sign->size = s->nelems;
168
169         ptr = GETARR(sign);
170         for(i=0;i<s->nelems;i++)
171                 ptr[i] = DatumGetUInt32(FunctionCall1Coll(&s->info->hashFunc,
172                                                                                                   DEFAULT_COLLATION_OID,
173                                                                                                   s->elems[i]));
174
175         /*
176          * there is a collision of hash-function; len is always equal or less than
177          * s->nelems
178          */
179         sign->size = uniqueint( GETARR(sign), sign->size, &sign->maxrepeat );
180         len = CALCGTSIZE( ARRKEY, sign->size );
181         SET_VARSIZE(sign, len);
182
183         return sign;
184 }
185
186 static int
187 HashedElemCmp(const void *va, const void *vb)
188 {
189         uint32  a = ((HashedElem *) va)->hash;
190         uint32  b = ((HashedElem *) vb)->hash;
191
192         if (a == b)
193         {
194                 double  ma = ((HashedElem *) va)->idfMin;
195                 double  mb = ((HashedElem *) va)->idfMin;
196
197                 if (ma == mb)
198                         return 0;
199
200                 return ( ma > mb ) ? 1 : -1;
201         }
202
203         return (a > b) ? 1 : -1;
204 }
205
206 static int
207 uniqueHashedElem(HashedElem *a, int32 l)
208 {
209         HashedElem      *ptr,
210                                 *res;
211
212         if (l <= 1)
213                 return l;
214
215         ptr = res = a;
216
217         qsort(a, l, sizeof(HashedElem), HashedElemCmp);
218
219         while (ptr - a < l)
220                 if (ptr->hash != res->hash)
221                         *(++res) = *ptr++;
222                 else
223                 {
224                         res->idfMax = ptr->idfMax;
225                         ptr++;
226                 }
227
228         return res + 1 - a;
229 }
230
231 static StatCache*
232 getHashedCache(void *cache)
233 {
234         StatCache *stat = getStat(cache, SIGLENBIT);
235
236         if ( stat->nhelems < 0 )
237         {
238                 int i;
239                 /*
240                  * Init
241                  */
242
243                 if (stat->info->tupDesc)
244                         elog(ERROR, "GiST  doesn't support composite (weighted) type");
245                 getFmgrInfoHash(stat->info);
246                 for(i=0;i<stat->nelems;i++)
247                 {
248                         uint32  hash;
249
250                         hash = DatumGetUInt32(FunctionCall1Coll(&stat->info->hashFunc,
251                                                                                                         DEFAULT_COLLATION_OID,
252                                                                                                         stat->elems[i].datum));
253                         int             index = HASHVAL(hash);
254
255                         stat->helems[i].hash = hash;
256                         stat->helems[i].idfMin = stat->helems[i].idfMax = stat->elems[i].idf;   
257                         if ( stat->selems[index].idfMin == 0.0 )
258                                 stat->selems[index].idfMin = stat->selems[index].idfMax = stat->elems[i].idf;
259                         else if ( stat->selems[index].idfMin > stat->elems[i].idf )
260                                 stat->selems[index].idfMin = stat->elems[i].idf;
261                         else if ( stat->selems[index].idfMax < stat->elems[i].idf )
262                                 stat->selems[index].idfMax = stat->elems[i].idf;
263                 }
264
265                 stat->nhelems = uniqueHashedElem( stat->helems, stat->nelems);
266         }
267
268         return stat;
269 }
270
271 static HashedElem*
272 getHashedElemIdf(StatCache *stat, uint32 hash, HashedElem *StopLow)
273 {
274         HashedElem      *StopMiddle,
275                                 *StopHigh = stat->helems + stat->nhelems;
276
277         if ( !StopLow )
278                 StopLow = stat->helems;
279
280         while (StopLow < StopHigh) {
281                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
282
283                 if ( StopMiddle->hash == hash )
284                         return StopMiddle;
285                 else if ( StopMiddle->hash < hash )
286                         StopLow = StopMiddle + 1;
287                 else
288                         StopHigh = StopMiddle;
289         }
290
291         return NULL;
292 }
293
294 static void
295 fillHashVal(void *cache, SimpleArray *a)
296 {
297         int i;
298
299         if (a->hash)
300                 return;
301
302         allocateHash(cache, a);
303
304         if (a->info->tupDesc)
305                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
306         getFmgrInfoHash(a->info);
307
308         for(i=0;i<a->nelems;i++)
309                 a->hash[i] = DatumGetUInt32(FunctionCall1Coll(&a->info->hashFunc,
310                                                                                                           DEFAULT_COLLATION_OID,
311                                                                                                           a->elems[i]));
312 }
313
314
315 static bool
316 hasHashedElem(SmlSign  *a, uint32 h)
317 {
318         uint32  *StopLow = GETARR(a),
319                         *StopHigh = GETARR(a) + a->size,
320                         *StopMiddle;
321
322         while (StopLow < StopHigh) {
323                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
324
325                 if ( *StopMiddle == h )
326                         return true;
327                 else if ( *StopMiddle < h )
328                         StopLow = StopMiddle + 1;
329                 else
330                         StopHigh = StopMiddle;
331         }
332
333         return false;
334 }
335
336 static void
337 makesign(BITVECP sign, SmlSign  *a)
338 {
339         int32   i;
340         uint32  *ptr = GETARR(a);
341
342         MemSet((void *) sign, 0, sizeof(BITVEC));
343         SETBIT(sign, SIGLENBIT);   /* set last unused bit */
344
345         for (i = 0; i < a->size; i++)
346                 HASH(sign, ptr[i]);
347 }
348
349 static int32
350 sizebitvec(BITVECP sign)
351 {
352         int32   size = 0,
353                         i;
354
355         LOOPBYTE
356                 size += number_of_ones[(unsigned char) sign[i]];
357
358         return size;
359 }
360
361 PG_FUNCTION_INFO_V1(gsmlsign_compress);
362 Datum gsmlsign_compress(PG_FUNCTION_ARGS);
363 Datum
364 gsmlsign_compress(PG_FUNCTION_ARGS)
365 {
366         GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
367         GISTENTRY  *retval = entry;
368
369         if (entry->leafkey) /* new key */
370         {
371                 SmlSign         *sign;
372                 ArrayType       *a = DatumGetArrayTypeP(entry->key);
373
374                 sign = Array2HashedArray(NULL, a);
375
376                 if ( VARSIZE(sign) > TOAST_INDEX_TARGET )
377                 {       /* make signature due to its big size */
378                         SmlSign *tmpsign;
379                         int             len;
380
381                         len = CALCGTSIZE( SIGNKEY, sign->size );
382                         tmpsign = palloc( len );
383                         tmpsign->flag = SIGNKEY;
384                         SET_VARSIZE(tmpsign, len);
385
386                         makesign(GETSIGN(tmpsign), sign);
387                         tmpsign->size = sizebitvec(GETSIGN(tmpsign));
388                         tmpsign->maxrepeat = sign->maxrepeat;
389                         sign = tmpsign;
390                 }
391
392                 retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
393                 gistentryinit(*retval, PointerGetDatum(sign),
394                                                 entry->rel, entry->page,
395                                                 entry->offset, false);
396         }
397         else if ( ISSIGNKEY(DatumGetPointer(entry->key)) &&
398                                 !ISALLTRUE(DatumGetPointer(entry->key)) )
399         {
400                 SmlSign *sign = (SmlSign*)DatumGetPointer(entry->key);
401
402                 Assert( sign->size == sizebitvec(GETSIGN(sign)) );
403
404                 if ( sign->size == SIGLENBIT )
405                 {
406                         int32   len = CALCGTSIZE(SIGNKEY | ALLISTRUE, 0);
407                         int32   maxrepeat = sign->maxrepeat;
408
409                         sign = (SmlSign *) palloc(len);
410                         SET_VARSIZE(sign, len);
411                         sign->flag = SIGNKEY | ALLISTRUE;
412                         sign->size = SIGLENBIT;
413                         sign->maxrepeat = maxrepeat;
414
415                         retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
416
417                         gistentryinit(*retval, PointerGetDatum(sign),
418                                                         entry->rel, entry->page,
419                                                         entry->offset, false);
420                 }
421         }
422
423         PG_RETURN_POINTER(retval);
424 }
425
426 PG_FUNCTION_INFO_V1(gsmlsign_decompress);
427 Datum gsmlsign_decompress(PG_FUNCTION_ARGS);
428 Datum
429 gsmlsign_decompress(PG_FUNCTION_ARGS)
430 {
431         GISTENTRY       *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
432         SmlSign         *key =  (SmlSign*)DatumGetPointer(PG_DETOAST_DATUM(entry->key));
433
434         if (key != (SmlSign *) DatumGetPointer(entry->key))
435         {
436                 GISTENTRY  *retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
437
438                 gistentryinit(*retval, PointerGetDatum(key),
439                                                 entry->rel, entry->page,
440                                                 entry->offset, false);
441
442                 PG_RETURN_POINTER(retval);
443         }
444
445         PG_RETURN_POINTER(entry);
446 }
447
448 /*
449  * Union method
450  */
451 static bool
452 unionkey(BITVECP sbase, SmlSign *add)
453 {
454         int32   i;
455
456         if (ISSIGNKEY(add))
457         {
458                 BITVECP sadd = GETSIGN(add);
459
460                 if (ISALLTRUE(add))
461                         return true;
462
463                 LOOPBYTE
464                         sbase[i] |= sadd[i];
465         }
466         else
467         {
468                 uint32  *ptr = GETARR(add);
469
470                 for (i = 0; i < add->size; i++)
471                         HASH(sbase, ptr[i]);
472         }
473
474         return false;
475 }
476
477 PG_FUNCTION_INFO_V1(gsmlsign_union);
478 Datum gsmlsign_union(PG_FUNCTION_ARGS);
479 Datum
480 gsmlsign_union(PG_FUNCTION_ARGS)
481 {
482         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
483         int                             *size = (int *) PG_GETARG_POINTER(1);
484         BITVEC                  base;
485         int32           i,
486                                 len,
487                                 maxrepeat = 1;
488         int32           flag = 0;
489         SmlSign    *result;
490
491         MemSet((void *) base, 0, sizeof(BITVEC));
492         for (i = 0; i < entryvec->n; i++)
493         {
494                 if (GETENTRY(entryvec, i)->maxrepeat > maxrepeat)
495                         maxrepeat = GETENTRY(entryvec, i)->maxrepeat;
496                 if (unionkey(base, GETENTRY(entryvec, i)))
497                 {
498                         flag = ALLISTRUE;
499                         break;
500                 }
501         }
502
503         flag |= SIGNKEY;
504         len = CALCGTSIZE(flag, 0);
505         result = (SmlSign *) palloc(len);
506         *size = len;
507         SET_VARSIZE(result, len);
508         result->flag = flag;
509         result->maxrepeat = maxrepeat;
510
511         if (!ISALLTRUE(result))
512         {
513                 memcpy((void *) GETSIGN(result), (void *) base, sizeof(BITVEC));
514                 result->size = sizebitvec(GETSIGN(result));
515         }
516         else
517                 result->size = SIGLENBIT;
518
519         PG_RETURN_POINTER(result);
520 }
521
522 /*
523  * Same method
524  */
525
526 PG_FUNCTION_INFO_V1(gsmlsign_same);
527 Datum gsmlsign_same(PG_FUNCTION_ARGS);
528 Datum
529 gsmlsign_same(PG_FUNCTION_ARGS)
530 {
531         SmlSign *a = (SmlSign*)PG_GETARG_POINTER(0);
532         SmlSign *b = (SmlSign*)PG_GETARG_POINTER(1);
533         bool    *result = (bool *) PG_GETARG_POINTER(2);
534
535         if (a->size != b->size)
536         {
537                 *result = false;
538         }
539         else if (ISSIGNKEY(a))
540         {       /* then b also ISSIGNKEY */
541                 if ( ISALLTRUE(a) )
542                 {
543                         /* in this case b is all true too - other cases is catched
544                            upper */
545                         *result = true;
546                 }
547                 else
548                 {
549                         int32   i;
550                         BITVECP sa = GETSIGN(a),
551                                         sb = GETSIGN(b);
552
553                         *result = true;
554
555                         if ( !ISALLTRUE(a) )
556                         {
557                                 LOOPBYTE
558                                 {
559                                         if (sa[i] != sb[i])
560                                         {
561                                                 *result = false;
562                                                 break;
563                                         }
564                                 }
565                         }
566                 }
567         }
568         else
569         {
570                 uint32  *ptra = GETARR(a),
571                                 *ptrb = GETARR(b);
572                 int32   i;
573
574                 *result = true;
575                 for (i = 0; i < a->size; i++)
576                 {
577                         if ( ptra[i] != ptrb[i])
578                         {
579                                 *result = false;
580                                 break;
581                         }
582                 }
583         }
584
585         PG_RETURN_POINTER(result);
586 }
587
588 /*
589  * Penalty method
590  */
591 static int
592 hemdistsign(BITVECP a, BITVECP b)
593 {
594         int     i,
595                 diff,
596                 dist = 0;
597
598         LOOPBYTE
599         {
600                 diff = (unsigned char) (a[i] ^ b[i]);
601                 dist += number_of_ones[diff];
602         }
603         return dist;
604 }
605
606 static int
607 hemdist(SmlSign *a, SmlSign *b)
608 {
609         if (ISALLTRUE(a))
610         {
611                 if (ISALLTRUE(b))
612                         return 0;
613                 else
614                         return SIGLENBIT - b->size;
615         }
616         else if (ISALLTRUE(b))
617                 return SIGLENBIT - b->size;
618
619         return hemdistsign(GETSIGN(a), GETSIGN(b));
620 }
621
622 PG_FUNCTION_INFO_V1(gsmlsign_penalty);
623 Datum gsmlsign_penalty(PG_FUNCTION_ARGS);
624 Datum
625 gsmlsign_penalty(PG_FUNCTION_ARGS)
626 {
627         GISTENTRY       *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
628         GISTENTRY       *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
629         float           *penalty = (float *) PG_GETARG_POINTER(2);
630         SmlSign         *origval = (SmlSign *) DatumGetPointer(origentry->key);
631         SmlSign         *newval = (SmlSign *) DatumGetPointer(newentry->key);
632         BITVECP         orig = GETSIGN(origval);
633
634         *penalty = 0.0;
635
636         if (ISARRKEY(newval))
637         {
638                 BITVEC  sign;
639
640                 makesign(sign, newval);
641
642                 if (ISALLTRUE(origval))
643                         *penalty = ((float) (SIGLENBIT - sizebitvec(sign))) / (float) (SIGLENBIT + 1);
644                 else
645                         *penalty = hemdistsign(sign, orig);
646         }
647         else
648                 *penalty = hemdist(origval, newval);
649
650         PG_RETURN_POINTER(penalty);
651 }
652
653 /*
654  * Picksplit method
655  */
656
657 typedef struct
658 {
659         bool    allistrue;
660         int32   size;
661         BITVEC  sign;
662 } CACHESIGN;
663
664 static void
665 fillcache(CACHESIGN *item, SmlSign *key)
666 {
667         item->allistrue = false;
668         item->size = key->size;
669
670         if (ISARRKEY(key))
671         {
672                 makesign(item->sign, key);
673                 item->size = sizebitvec( item->sign );
674         }
675         else if (ISALLTRUE(key))
676                 item->allistrue = true;
677         else
678                 memcpy((void *) item->sign, (void *) GETSIGN(key), sizeof(BITVEC));
679 }
680
/*
 * Balancing bias: -(a - b)^3 * c, computed in double after the integer cube.
 * Picksplit passes the left/right counts as a and b, so the bias becomes
 * more negative as the left side outgrows the right.
 */
#define WISH_F(a, b, c) ((double) (-(double) (((a) - (b)) * ((a) - (b)) * ((a) - (b))) * (c)))
682
683 typedef struct
684 {
685         OffsetNumber pos;
686         int32           cost;
687 } SPLITCOST;
688
689 static int
690 comparecost(const void *va, const void *vb)
691 {
692         SPLITCOST  *a = (SPLITCOST *) va;
693         SPLITCOST  *b = (SPLITCOST *) vb;
694
695         if (a->cost == b->cost)
696                 return 0;
697         else
698                 return (a->cost > b->cost) ? 1 : -1;
699 }
700
701 static int
702 hemdistcache(CACHESIGN *a, CACHESIGN *b)
703 {
704         if (a->allistrue)
705         {
706                 if (b->allistrue)
707                         return 0;
708                 else
709                         return SIGLENBIT - b->size;
710         }
711         else if (b->allistrue)
712                 return SIGLENBIT - a->size;
713
714         return hemdistsign(a->sign, b->sign);
715 }
716
717 PG_FUNCTION_INFO_V1(gsmlsign_picksplit);
718 Datum gsmlsign_picksplit(PG_FUNCTION_ARGS);
719 Datum
720 gsmlsign_picksplit(PG_FUNCTION_ARGS)
721 {
722         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
723         GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
724         OffsetNumber k,
725                                 j;
726         SmlSign         *datum_l,
727                                 *datum_r;
728         BITVECP         union_l,
729                                 union_r;
730         int32           size_alpha,
731                                 size_beta;
732         int32           size_waste,
733                                 waste = -1;
734         int32           nbytes;
735         OffsetNumber seed_1 = 0,
736                                 seed_2 = 0;
737         OffsetNumber *left,
738                                 *right;
739         OffsetNumber maxoff;
740         BITVECP         ptr;
741         int                     i;
742         CACHESIGN  *cache;
743         SPLITCOST  *costvector;
744
745         maxoff = entryvec->n - 2;
746         nbytes = (maxoff + 2) * sizeof(OffsetNumber);
747         v->spl_left = (OffsetNumber *) palloc(nbytes);
748         v->spl_right = (OffsetNumber *) palloc(nbytes);
749
750         cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 2));
751         fillcache(&cache[FirstOffsetNumber], GETENTRY(entryvec, FirstOffsetNumber));
752
753         for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
754         {
755                 for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
756                 {
757                         if (k == FirstOffsetNumber)
758                                 fillcache(&cache[j], GETENTRY(entryvec, j));
759
760                         size_waste = hemdistcache(&(cache[j]), &(cache[k]));
761                         if (size_waste > waste)
762                         {
763                                 waste = size_waste;
764                                 seed_1 = k;
765                                 seed_2 = j;
766                         }
767                 }
768         }
769
770         left = v->spl_left;
771         v->spl_nleft = 0;
772         right = v->spl_right;
773         v->spl_nright = 0;
774
775         if (seed_1 == 0 || seed_2 == 0)
776         {
777                 seed_1 = 1;
778                 seed_2 = 2;
779         }
780
781         /* form initial .. */
782         if (cache[seed_1].allistrue)
783         {
784                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
785                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
786                 datum_l->flag = SIGNKEY | ALLISTRUE;
787                 datum_l->size = SIGLENBIT;
788         }
789         else
790         {
791                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
792                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY, 0));
793                 datum_l->flag = SIGNKEY;
794                 memcpy((void *) GETSIGN(datum_l), (void *) cache[seed_1].sign, sizeof(BITVEC));
795                 datum_l->size = cache[seed_1].size;
796         }
797         if (cache[seed_2].allistrue)
798         {
799                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
800                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
801                 datum_r->flag = SIGNKEY | ALLISTRUE;
802                 datum_r->size = SIGLENBIT;
803         }
804         else
805         {
806                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
807                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY, 0));
808                 datum_r->flag = SIGNKEY;
809                 memcpy((void *) GETSIGN(datum_r), (void *) cache[seed_2].sign, sizeof(BITVEC));
810                 datum_r->size = cache[seed_2].size;
811         }
812
813         union_l = GETSIGN(datum_l);
814         union_r = GETSIGN(datum_r);
815         maxoff = OffsetNumberNext(maxoff);
816         fillcache(&cache[maxoff], GETENTRY(entryvec, maxoff));
817         /* sort before ... */
818         costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
819         for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
820         {
821                 costvector[j - 1].pos = j;
822                 size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]));
823                 size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]));
824                 costvector[j - 1].cost = Abs(size_alpha - size_beta);
825         }
826         qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
827
828         datum_l->maxrepeat = datum_r->maxrepeat = 1;
829
830         for (k = 0; k < maxoff; k++)
831         {
832                 j = costvector[k].pos;
833                 if (j == seed_1)
834                 {
835                         *left++ = j;
836                         v->spl_nleft++;
837                         continue;
838                 }
839                 else if (j == seed_2)
840                 {
841                         *right++ = j;
842                         v->spl_nright++;
843                         continue;
844                 }
845
846                 if (ISALLTRUE(datum_l) || cache[j].allistrue)
847                 {
848                         if (ISALLTRUE(datum_l) && cache[j].allistrue)
849                                 size_alpha = 0;
850                         else
851                                 size_alpha = SIGLENBIT - (
852                                                                         (cache[j].allistrue) ? datum_l->size : cache[j].size
853                                                         );
854                 }
855                 else
856                         size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l));
857
858                 if (ISALLTRUE(datum_r) || cache[j].allistrue)
859                 {
860                         if (ISALLTRUE(datum_r) && cache[j].allistrue)
861                                 size_beta = 0;
862                         else
863                                 size_beta = SIGLENBIT - (
864                                                                         (cache[j].allistrue) ? datum_r->size : cache[j].size
865                                                         );
866                 }
867                 else
868                         size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r));
869
870                 if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
871                 {
872                         if (ISALLTRUE(datum_l) || cache[j].allistrue)
873                         {
874                                 if (!ISALLTRUE(datum_l))
875                                         MemSet((void *) GETSIGN(datum_l), 0xff, sizeof(BITVEC));
876                                 datum_l->size = SIGLENBIT;
877                         }
878                         else
879                         {
880                                 ptr = cache[j].sign;
881                                 LOOPBYTE
882                                         union_l[i] |= ptr[i];
883                                 datum_l->size = sizebitvec(union_l);
884                         }
885                         *left++ = j;
886                         v->spl_nleft++;
887                 }
888                 else
889                 {
890                         if (ISALLTRUE(datum_r) || cache[j].allistrue)
891                         {
892                                 if (!ISALLTRUE(datum_r))
893                                         MemSet((void *) GETSIGN(datum_r), 0xff, sizeof(BITVEC));
894                                 datum_r->size = SIGLENBIT;
895                         }
896                         else
897                         {
898                                 ptr = cache[j].sign;
899                                 LOOPBYTE
900                                         union_r[i] |= ptr[i];
901                                 datum_r->size = sizebitvec(union_r);
902                         }
903                         *right++ = j;
904                         v->spl_nright++;
905                 }
906         }
907         *right = *left = FirstOffsetNumber;
908         v->spl_ldatum = PointerGetDatum(datum_l);
909         v->spl_rdatum = PointerGetDatum(datum_r);
910
911         Assert( datum_l->size = sizebitvec(GETSIGN(datum_l)) );
912         Assert( datum_r->size = sizebitvec(GETSIGN(datum_r)) );
913
914         PG_RETURN_POINTER(v);
915 }
916
917 static double
918 getIdfMaxLimit(SmlSign *key)
919 {
920         switch( getTFMethod() )
921         {
922                 case TF_CONST:
923                         return 1.0;
924                         break;
925                 case TF_N:
926                         return (double)(key->maxrepeat);
927                         break;
928                 case TF_LOG:
929                         return 1.0 + log( (double)(key->maxrepeat) );
930                         break;
931                 default:
932                         elog(ERROR,"Unknown TF method: %d", getTFMethod());
933         }
934
935         return 0.0;
936 }
937
938 /*
939  * Consistent function
940  */
941 PG_FUNCTION_INFO_V1(gsmlsign_consistent);
942 Datum gsmlsign_consistent(PG_FUNCTION_ARGS);
943 Datum
944 gsmlsign_consistent(PG_FUNCTION_ARGS)
945 {
946         GISTENTRY               *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
947         StrategyNumber  strategy = (StrategyNumber) PG_GETARG_UINT16(2);
948         bool                    *recheck = (bool *) PG_GETARG_POINTER(4);
949         ArrayType               *a;
950         SmlSign                 *key = (SmlSign*)DatumGetPointer(entry->key);
951         int                             res = false;
952         SmlSign                 *query;
953         SimpleArray             *s;
954         int32                   i;
955
956         fcinfo->flinfo->fn_extra = SearchArrayCache(
957                                                                         fcinfo->flinfo->fn_extra,
958                                                                         fcinfo->flinfo->fn_mcxt,
959                                                                         PG_GETARG_DATUM(1), &a, &s, &query);
960
961         *recheck = true;
962
963         if ( ARRISVOID(a) )
964                 PG_RETURN_BOOL(res);
965
966         if ( strategy == SmlarOverlapStrategy )
967         {
968                 if (ISALLTRUE(key))
969                 {
970                         res = true;
971                 }
972                 else if (ISARRKEY(key))
973                 {
974                         uint32  *kptr = GETARR(key),
975                                         *qptr = GETARR(query);
976
977                         while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
978                         {
979                                 if ( *kptr < *qptr )
980                                         kptr++;
981                                 else if ( *kptr > *qptr )
982                                         qptr++;
983                                 else
984                                 {
985                                         res = true;
986                                         break;
987                                 }
988                         }
989                         *recheck = false;
990                 }
991                 else
992                 {
993                         BITVECP sign = GETSIGN(key);
994
995                         fillHashVal(fcinfo->flinfo->fn_extra, s);
996
997                         for(i=0; i<s->nelems; i++)
998                         {
999                                 if ( GETBIT(sign, HASHVAL(s->hash[i])) )
1000                                 {
1001                                         res = true;
1002                                         break;
1003                                 }
1004                         }
1005                 }
1006         }
1007         /*
1008          *  SmlarSimilarityStrategy
1009          */
1010         else if (ISALLTRUE(key))
1011         {
1012                 if ( GIST_LEAF(entry) )
1013                 {
1014                         /*
1015                          * With TF/IDF similarity we cannot say anything useful
1016                          */
1017                         if ( query->size < SIGLENBIT && getSmlType() != ST_TFIDF )
1018                         {
1019                                 double power = ((double)(query->size)) * ((double)(SIGLENBIT));
1020
1021                                 if ( ((double)(query->size)) / sqrt(power) >= GetSmlarLimit() ) 
1022                                         res = true;
1023                         }
1024                         else
1025                         {
1026                                 res = true;     
1027                         }
1028                 }
1029                 else
1030                         res = true;
1031         }
1032         else if (ISARRKEY(key))
1033         {
1034                 uint32  *kptr = GETARR(key),
1035                                 *qptr = GETARR(query);
1036
1037                 Assert( GIST_LEAF(entry) );
1038
1039                 switch(getSmlType())
1040                 {
1041                         case ST_TFIDF:
1042                                 {
1043                                         StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1044                                         double          sumU = 0.0,
1045                                                                 sumQ = 0.0,
1046                                                                 sumK = 0.0;
1047                                         double          maxKTF = getIdfMaxLimit(key);
1048                                         HashedElem  *h;
1049
1050                                         Assert( s->df );
1051                                         fillHashVal(fcinfo->flinfo->fn_extra, s);
1052                                         if ( stat->info != s->info )
1053                                                 elog(ERROR,"Statistic and actual argument have different type");
1054
1055                                         for(i=0;i<s->nelems;i++)
1056                                         {
1057                                                 sumQ += s->df[i] * s->df[i];
1058
1059                                                 h = getHashedElemIdf(stat, s->hash[i], NULL);
1060                                                 if ( h && hasHashedElem(key, s->hash[i]) )
1061                                                 {
1062                                                         sumK += h->idfMin * h->idfMin;
1063                                                         sumU += h->idfMax * maxKTF * s->df[i];
1064                                                 } 
1065                                         }
1066
1067                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1068                                         {
1069                                                 /* 
1070                                                  * More precisely calculate sumK
1071                                                  */
1072                                                 h = NULL;
1073                                                 sumK = 0.0;
1074
1075                                                 for(i=0;i<key->size;i++)
1076                                                 {
1077                                                         h = getHashedElemIdf(stat, GETARR(key)[i], h);                                  
1078                                                         if (h)
1079                                                                 sumK += h->idfMin * h->idfMin;
1080                                                 }
1081
1082                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1083                                                         res = true;
1084                                         }
1085                                 }
1086                                 break;
1087                         case ST_COSINE:
1088                                 {
1089                                         double                  power;
1090                                         power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1091
1092                                         if (  ((double)Min(key->size, s->nelems)) / power >= GetSmlarLimit() )
1093                                         {
1094                                                 int  cnt = 0;
1095
1096                                                 while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
1097                                                 {
1098                                                         if ( *kptr < *qptr )
1099                                                                 kptr++;
1100                                                         else if ( *kptr > *qptr )
1101                                                                 qptr++;
1102                                                         else
1103                                                         {
1104                                                                 cnt++;
1105                                                                 kptr++;
1106                                                                 qptr++;
1107                                                         }
1108                                                 }
1109
1110                                                 if ( ((double)cnt) / power >= GetSmlarLimit() )
1111                                                         res = true;
1112                                         }
1113                                 }
1114                                 break;
1115                         default:
1116                                 elog(ERROR,"GiST doesn't support current formula type of similarity");
1117                 }
1118         }
1119         else
1120         {       /* signature */
1121                 BITVECP sign = GETSIGN(key);
1122                 int32   count = 0;
1123
1124                 fillHashVal(fcinfo->flinfo->fn_extra, s);
1125
1126                 if ( GIST_LEAF(entry) )
1127                 {
1128                         switch(getSmlType())
1129                         {
1130                                 case ST_TFIDF:
1131                                         {
1132                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1133                                                 double          sumU = 0.0,
1134                                                                         sumQ = 0.0,
1135                                                                         sumK = 0.0;
1136                                                 double          maxKTF = getIdfMaxLimit(key);
1137
1138                                                 Assert( s->df );
1139                                                 if ( stat->info != s->info )
1140                                                         elog(ERROR,"Statistic and actual argument have different type");
1141
1142                                                 for(i=0;i<s->nelems;i++)
1143                                                 {
1144                                                         int32           hbit = HASHVAL(s->hash[i]);
1145
1146                                                         sumQ += s->df[i] * s->df[i];
1147                                                         if ( GETBIT(sign, hbit) )
1148                                                         {
1149                                                                 sumK += stat->selems[ hbit ].idfMin * stat->selems[ hbit ].idfMin;
1150                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1151                                                         }
1152                                                 }
1153
1154                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1155                                                 {
1156                                                         /* 
1157                                                          * More precisely calculate sumK
1158                                                          */
1159                                                         sumK = 0.0;
1160
1161                                                         for(i=0;i<SIGLENBIT;i++)
1162                                                                 if ( GETBIT(sign,i) )
1163                                                                         sumK += stat->selems[ i ].idfMin * stat->selems[ i ].idfMin;
1164
1165                                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1166                                                                 res = true;
1167                                                 }
1168                                         }
1169                                         break;
1170                                 case ST_COSINE:
1171                                         {
1172                                                 double  power;
1173
1174                                                 power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1175
1176                                                 for(i=0; i<s->nelems; i++)
1177                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1178
1179                                                 if ( ((double)count) / power >= GetSmlarLimit() )
1180                                                         res = true;
1181                                         }
1182                                         break;
1183                                 default:
1184                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1185                         }
1186                 }
1187                 else /* non-leaf page */
1188                 {
1189                         switch(getSmlType())
1190                         {
1191                                 case ST_TFIDF:
1192                                         {
1193                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1194                                                 double          sumU = 0.0,
1195                                                                         sumQ = 0.0,
1196                                                                         minK = -1.0;
1197                                                 double          maxKTF = getIdfMaxLimit(key);
1198
1199                                                 Assert( s->df );
1200                                                 if ( stat->info != s->info )
1201                                                         elog(ERROR,"Statistic and actual argument have different type");
1202
1203                                                 for(i=0;i<s->nelems;i++)
1204                                                 {
1205                                                         int32           hbit = HASHVAL(s->hash[i]);
1206
1207                                                         sumQ += s->df[i] * s->df[i];
1208                                                         if ( GETBIT(sign, hbit) )
1209                                                         {
1210                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1211                                                                 if ( minK > stat->selems[ hbit ].idfMin  || minK < 0.0 )
1212                                                                         minK = stat->selems[ hbit ].idfMin;
1213                                                         }
1214                                                 }
1215
1216                                                 if ( sumQ > 0.0 && minK > 0.0 && sumU / sqrt( sumQ * minK ) >= GetSmlarLimit() )
1217                                                         res = true;
1218                                         }
1219                                         break;
1220                                 case ST_COSINE:
1221                                         {
1222                                                 for(i=0; i<s->nelems; i++)
1223                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1224
1225                                                 if ( s->nelems == count  || sqrt(((double)count) / ((double)(s->nelems))) >= GetSmlarLimit() )
1226                                                         res = true;
1227                                         }
1228                                         break;
1229                                 default:
1230                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1231                         }
1232                 }
1233         }
1234
1235 #if 0
1236         {
1237                 static int nnres = 0;
1238                 static int nres = 0;
1239                 if ( GIST_LEAF(entry) ) {
1240                         if ( res )
1241                                 nres++;
1242                         else
1243                                 nnres++;
1244                         elog(NOTICE,"%s nn:%d n:%d", (ISARRKEY(key)) ? "ARR" : ( (ISALLTRUE(key)) ? "TRUE" : "SIGN" ), nnres, nres  );
1245                 }
1246         }
1247 #endif
1248
1249         PG_RETURN_BOOL(res);
1250 }