Fix compiler warnings
[smlar.git] / smlar_gist.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/gist.h"
5 #include "access/skey.h"
6 #if PG_VERSION_NUM < 130000
7 #include "access/tuptoaster.h"
8 #else
9 #include "access/heaptoast.h"
10 #endif
11 #include "utils/memutils.h"
12
13 typedef struct SmlSign {
14         int32   vl_len_; /* varlena header (do not touch directly!) */
15         int32   flag:8,
16                         size:24;
17         int32   maxrepeat;
18         char    data[1];
19 } SmlSign;
20
/* Size of the fixed SmlSign header that precedes the payload */
#define SMLSIGNHDRSZ	(offsetof(SmlSign, data))

/*
 * Signature geometry: a signature is SIGLEN bytes long; bit number SIGLENBIT
 * (the very last one) is never produced by HASHVAL, see makesign.
 */
#define BITBYTE 8
#define SIGLENINT  61
#define SIGLEN	( sizeof(int)*SIGLENINT )
#define SIGLENBIT (SIGLEN*BITBYTE - 1)	/* see makesign */
typedef char BITVEC[SIGLEN];
typedef char *BITVECP;

/* Loop header walking 'i' (declared by the user of the macro) over all bytes */
#define LOOPBYTE \
		for(i=0;i<SIGLEN;i++)

/*
 * Bit accessors.  Every macro argument and every expansion is fully
 * parenthesized so the macros stay correct when given compound expressions
 * or when used inside a larger expression.
 */
#define GETBYTE(x,i) ( *( (BITVECP)(x) + (int)( (i) / BITBYTE ) ) )
#define GETBITBYTE(x,i) ( ( ((char)(x)) >> (i) ) & 0x01 )
#define CLRBIT(x,i)   ( GETBYTE(x,i) &= ~( 0x01 << ( (i) % BITBYTE ) ) )
#define SETBIT(x,i)   ( GETBYTE(x,i) |=  ( 0x01 << ( (i) % BITBYTE ) ) )
#define GETBIT(x,i) ( (GETBYTE(x,i) >> ( (i) % BITBYTE )) & 0x01 )

/* Map a hash value onto a bit position in [0, SIGLENBIT) and set that bit */
#define HASHVAL(val) (((unsigned int)(val)) % SIGLENBIT)
#define HASH(sign, val) SETBIT((sign), HASHVAL(val))

/* Values stored in SmlSign.flag */
#define ARRKEY			0x01
#define SIGNKEY			0x02
#define ALLISTRUE		0x04

#define ISARRKEY(x)		( ((SmlSign*)(x))->flag & ARRKEY )
#define ISSIGNKEY(x)	( ((SmlSign*)(x))->flag & SIGNKEY )
#define ISALLTRUE(x)	( ((SmlSign*)(x))->flag & ALLISTRUE )

/*
 * Total key size in bytes: header plus 'len' hashes for an array key, plus a
 * full signature for a signature key, or nothing for an all-true key.
 */
#define CALCGTSIZE(flag, len)	( SMLSIGNHDRSZ + ( ( (flag) & ARRKEY ) ? ((len)*sizeof(uint32)) : (((flag) & ALLISTRUE) ? 0 : SIGLEN) ) )
#define GETSIGN(x)				( (BITVECP)( (char*)(x)+SMLSIGNHDRSZ ) )
#define GETARR(x)				( (uint32*)( (char*)(x)+SMLSIGNHDRSZ ) )

/* Key of the pos-th entry of a GistEntryVector */
#define GETENTRY(vec,pos) ((SmlSign *) DatumGetPointer((vec)->vector[(pos)].key))
54
55 /*
56  * Fake IO
57  */
58 PG_FUNCTION_INFO_V1(gsmlsign_in);
59 Datum   gsmlsign_in(PG_FUNCTION_ARGS);
60 Datum
61 gsmlsign_in(PG_FUNCTION_ARGS)
62 {
63         elog(ERROR, "not implemented");
64         PG_RETURN_DATUM(0);
65 }
66
67 PG_FUNCTION_INFO_V1(gsmlsign_out);
68 Datum   gsmlsign_out(PG_FUNCTION_ARGS);
69 Datum
70 gsmlsign_out(PG_FUNCTION_ARGS)
71 {
72         elog(ERROR, "not implemented");
73         PG_RETURN_DATUM(0);
74 }
75
76 /*
77  * Compress/decompress
78  */
79
80 /* Number of one-bits in an unsigned byte */
81 static const uint8 number_of_ones[256] = {
82         0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
83         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
84         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
85         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
86         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
87         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
88         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
89         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
90         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
91         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
92         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
93         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
94         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
95         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
96         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
97         4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
98 };
99
100 static int
101 compareint(const void *va, const void *vb)
102 {
103         uint32  a = *((uint32 *) va);
104         uint32  b = *((uint32 *) vb);
105
106         if (a == b)
107                 return 0;
108         return (a > b) ? 1 : -1;
109 }
110
111 /*
112  * Removes duplicates from an array of int32. 'l' is
113  * size of the input array. Returns the new size of the array.
114  */
115 static int
116 uniqueint(uint32 *a, int32 l, int32 *max)
117 {
118         uint32          *ptr,
119                                 *res;
120         int32           cnt = 0;
121
122         *max = 1;
123
124         if (l <= 1)
125                 return l;
126
127         ptr = res = a;
128
129         qsort((void *) a, l, sizeof(uint32), compareint);
130
131         while (ptr - a < l)
132                 if (*ptr != *res)
133                 {
134                         cnt = 1;
135                         *(++res) = *ptr++;
136                 }
137                 else
138                 {
139                         cnt++;
140                         if ( cnt > *max )
141                                 *max = cnt;
142                         ptr++;
143                 }
144
145         if ( cnt > *max )
146                 *max = cnt;
147
148         return res + 1 - a;
149 }
150
151 SmlSign*
152 Array2HashedArray(ProcTypeInfo info, ArrayType *a)
153 {
154         SimpleArray *s = Array2SimpleArray(info, a);
155         SmlSign         *sign;
156         int32           len, i;
157         uint32          *ptr;
158
159         len = CALCGTSIZE( ARRKEY, s->nelems );
160
161         getFmgrInfoHash(s->info);
162         if (s->info->tupDesc)
163                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
164
165         sign = palloc( len );
166         sign->flag = ARRKEY;
167         sign->size = s->nelems;
168
169         ptr = GETARR(sign);
170         for(i=0;i<s->nelems;i++)
171                 ptr[i] = DatumGetUInt32(FunctionCall1Coll(&s->info->hashFunc,
172                                                                                                   DEFAULT_COLLATION_OID,
173                                                                                                   s->elems[i]));
174
175         /*
176          * there is a collision of hash-function; len is always equal or less than
177          * s->nelems
178          */
179         sign->size = uniqueint( GETARR(sign), sign->size, &sign->maxrepeat );
180         len = CALCGTSIZE( ARRKEY, sign->size );
181         SET_VARSIZE(sign, len);
182
183         return sign;
184 }
185
186 static int
187 HashedElemCmp(const void *va, const void *vb)
188 {
189         uint32  a = ((HashedElem *) va)->hash;
190         uint32  b = ((HashedElem *) vb)->hash;
191
192         if (a == b)
193         {
194                 double  ma = ((HashedElem *) va)->idfMin;
195                 double  mb = ((HashedElem *) va)->idfMin;
196
197                 if (ma == mb)
198                         return 0;
199
200                 return ( ma > mb ) ? 1 : -1;
201         }
202
203         return (a > b) ? 1 : -1;
204 }
205
206 static int
207 uniqueHashedElem(HashedElem *a, int32 l)
208 {
209         HashedElem      *ptr,
210                                 *res;
211
212         if (l <= 1)
213                 return l;
214
215         ptr = res = a;
216
217         qsort(a, l, sizeof(HashedElem), HashedElemCmp);
218
219         while (ptr - a < l)
220                 if (ptr->hash != res->hash)
221                         *(++res) = *ptr++;
222                 else
223                 {
224                         res->idfMax = ptr->idfMax;
225                         ptr++;
226                 }
227
228         return res + 1 - a;
229 }
230
231 static StatCache*
232 getHashedCache(void *cache)
233 {
234         StatCache *stat = getStat(cache, SIGLENBIT);
235
236         if ( stat->nhelems < 0 )
237         {
238                 int i;
239                 /*
240                  * Init
241                  */
242
243                 if (stat->info->tupDesc)
244                         elog(ERROR, "GiST  doesn't support composite (weighted) type");
245                 getFmgrInfoHash(stat->info);
246                 for(i=0;i<stat->nelems;i++)
247                 {
248                         uint32  hash;
249                         int             index;
250
251                         hash = DatumGetUInt32(FunctionCall1Coll(&stat->info->hashFunc,
252                                                                                                         DEFAULT_COLLATION_OID,
253                                                                                                         stat->elems[i].datum));
254                         index = HASHVAL(hash);
255
256                         stat->helems[i].hash = hash;
257                         stat->helems[i].idfMin = stat->helems[i].idfMax = stat->elems[i].idf;   
258                         if ( stat->selems[index].idfMin == 0.0 )
259                                 stat->selems[index].idfMin = stat->selems[index].idfMax = stat->elems[i].idf;
260                         else if ( stat->selems[index].idfMin > stat->elems[i].idf )
261                                 stat->selems[index].idfMin = stat->elems[i].idf;
262                         else if ( stat->selems[index].idfMax < stat->elems[i].idf )
263                                 stat->selems[index].idfMax = stat->elems[i].idf;
264                 }
265
266                 stat->nhelems = uniqueHashedElem( stat->helems, stat->nelems);
267         }
268
269         return stat;
270 }
271
272 static HashedElem*
273 getHashedElemIdf(StatCache *stat, uint32 hash, HashedElem *StopLow)
274 {
275         HashedElem      *StopMiddle,
276                                 *StopHigh = stat->helems + stat->nhelems;
277
278         if ( !StopLow )
279                 StopLow = stat->helems;
280
281         while (StopLow < StopHigh) {
282                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
283
284                 if ( StopMiddle->hash == hash )
285                         return StopMiddle;
286                 else if ( StopMiddle->hash < hash )
287                         StopLow = StopMiddle + 1;
288                 else
289                         StopHigh = StopMiddle;
290         }
291
292         return NULL;
293 }
294
295 static void
296 fillHashVal(void *cache, SimpleArray *a)
297 {
298         int i;
299
300         if (a->hash)
301                 return;
302
303         allocateHash(cache, a);
304
305         if (a->info->tupDesc)
306                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
307         getFmgrInfoHash(a->info);
308
309         for(i=0;i<a->nelems;i++)
310                 a->hash[i] = DatumGetUInt32(FunctionCall1Coll(&a->info->hashFunc,
311                                                                                                           DEFAULT_COLLATION_OID,
312                                                                                                           a->elems[i]));
313 }
314
315
316 static bool
317 hasHashedElem(SmlSign  *a, uint32 h)
318 {
319         uint32  *StopLow = GETARR(a),
320                         *StopHigh = GETARR(a) + a->size,
321                         *StopMiddle;
322
323         while (StopLow < StopHigh) {
324                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
325
326                 if ( *StopMiddle == h )
327                         return true;
328                 else if ( *StopMiddle < h )
329                         StopLow = StopMiddle + 1;
330                 else
331                         StopHigh = StopMiddle;
332         }
333
334         return false;
335 }
336
337 static void
338 makesign(BITVECP sign, SmlSign  *a)
339 {
340         int32   i;
341         uint32  *ptr = GETARR(a);
342
343         MemSet((void *) sign, 0, sizeof(BITVEC));
344         SETBIT(sign, SIGLENBIT);   /* set last unused bit */
345
346         for (i = 0; i < a->size; i++)
347                 HASH(sign, ptr[i]);
348 }
349
350 static int32
351 sizebitvec(BITVECP sign)
352 {
353         int32   size = 0,
354                         i;
355
356         LOOPBYTE
357                 size += number_of_ones[(unsigned char) sign[i]];
358
359         return size;
360 }
361
362 PG_FUNCTION_INFO_V1(gsmlsign_compress);
363 Datum gsmlsign_compress(PG_FUNCTION_ARGS);
364 Datum
365 gsmlsign_compress(PG_FUNCTION_ARGS)
366 {
367         GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
368         GISTENTRY  *retval = entry;
369
370         if (entry->leafkey) /* new key */
371         {
372                 SmlSign         *sign;
373                 ArrayType       *a = DatumGetArrayTypeP(entry->key);
374
375                 sign = Array2HashedArray(NULL, a);
376
377                 if ( VARSIZE(sign) > TOAST_INDEX_TARGET )
378                 {       /* make signature due to its big size */
379                         SmlSign *tmpsign;
380                         int             len;
381
382                         len = CALCGTSIZE( SIGNKEY, sign->size );
383                         tmpsign = palloc( len );
384                         tmpsign->flag = SIGNKEY;
385                         SET_VARSIZE(tmpsign, len);
386
387                         makesign(GETSIGN(tmpsign), sign);
388                         tmpsign->size = sizebitvec(GETSIGN(tmpsign));
389                         tmpsign->maxrepeat = sign->maxrepeat;
390                         sign = tmpsign;
391                 }
392
393                 retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
394                 gistentryinit(*retval, PointerGetDatum(sign),
395                                                 entry->rel, entry->page,
396                                                 entry->offset, false);
397         }
398         else if ( ISSIGNKEY(DatumGetPointer(entry->key)) &&
399                                 !ISALLTRUE(DatumGetPointer(entry->key)) )
400         {
401                 SmlSign *sign = (SmlSign*)DatumGetPointer(entry->key);
402
403                 Assert( sign->size == sizebitvec(GETSIGN(sign)) );
404
405                 if ( sign->size == SIGLENBIT )
406                 {
407                         int32   len = CALCGTSIZE(SIGNKEY | ALLISTRUE, 0);
408                         int32   maxrepeat = sign->maxrepeat;
409
410                         sign = (SmlSign *) palloc(len);
411                         SET_VARSIZE(sign, len);
412                         sign->flag = SIGNKEY | ALLISTRUE;
413                         sign->size = SIGLENBIT;
414                         sign->maxrepeat = maxrepeat;
415
416                         retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
417
418                         gistentryinit(*retval, PointerGetDatum(sign),
419                                                         entry->rel, entry->page,
420                                                         entry->offset, false);
421                 }
422         }
423
424         PG_RETURN_POINTER(retval);
425 }
426
427 PG_FUNCTION_INFO_V1(gsmlsign_decompress);
428 Datum gsmlsign_decompress(PG_FUNCTION_ARGS);
429 Datum
430 gsmlsign_decompress(PG_FUNCTION_ARGS)
431 {
432         GISTENTRY       *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
433         SmlSign         *key =  (SmlSign*)DatumGetPointer(PG_DETOAST_DATUM(entry->key));
434
435         if (key != (SmlSign *) DatumGetPointer(entry->key))
436         {
437                 GISTENTRY  *retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
438
439                 gistentryinit(*retval, PointerGetDatum(key),
440                                                 entry->rel, entry->page,
441                                                 entry->offset, false);
442
443                 PG_RETURN_POINTER(retval);
444         }
445
446         PG_RETURN_POINTER(entry);
447 }
448
449 /*
450  * Union method
451  */
452 static bool
453 unionkey(BITVECP sbase, SmlSign *add)
454 {
455         int32   i;
456
457         if (ISSIGNKEY(add))
458         {
459                 BITVECP sadd = GETSIGN(add);
460
461                 if (ISALLTRUE(add))
462                         return true;
463
464                 LOOPBYTE
465                         sbase[i] |= sadd[i];
466         }
467         else
468         {
469                 uint32  *ptr = GETARR(add);
470
471                 for (i = 0; i < add->size; i++)
472                         HASH(sbase, ptr[i]);
473         }
474
475         return false;
476 }
477
478 PG_FUNCTION_INFO_V1(gsmlsign_union);
479 Datum gsmlsign_union(PG_FUNCTION_ARGS);
480 Datum
481 gsmlsign_union(PG_FUNCTION_ARGS)
482 {
483         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
484         int                             *size = (int *) PG_GETARG_POINTER(1);
485         BITVEC                  base;
486         int32           i,
487                                 len,
488                                 maxrepeat = 1;
489         int32           flag = 0;
490         SmlSign    *result;
491
492         MemSet((void *) base, 0, sizeof(BITVEC));
493         for (i = 0; i < entryvec->n; i++)
494         {
495                 if (GETENTRY(entryvec, i)->maxrepeat > maxrepeat)
496                         maxrepeat = GETENTRY(entryvec, i)->maxrepeat;
497                 if (unionkey(base, GETENTRY(entryvec, i)))
498                 {
499                         flag = ALLISTRUE;
500                         break;
501                 }
502         }
503
504         flag |= SIGNKEY;
505         len = CALCGTSIZE(flag, 0);
506         result = (SmlSign *) palloc(len);
507         *size = len;
508         SET_VARSIZE(result, len);
509         result->flag = flag;
510         result->maxrepeat = maxrepeat;
511
512         if (!ISALLTRUE(result))
513         {
514                 memcpy((void *) GETSIGN(result), (void *) base, sizeof(BITVEC));
515                 result->size = sizebitvec(GETSIGN(result));
516         }
517         else
518                 result->size = SIGLENBIT;
519
520         PG_RETURN_POINTER(result);
521 }
522
523 /*
524  * Same method
525  */
526
527 PG_FUNCTION_INFO_V1(gsmlsign_same);
528 Datum gsmlsign_same(PG_FUNCTION_ARGS);
529 Datum
530 gsmlsign_same(PG_FUNCTION_ARGS)
531 {
532         SmlSign *a = (SmlSign*)PG_GETARG_POINTER(0);
533         SmlSign *b = (SmlSign*)PG_GETARG_POINTER(1);
534         bool    *result = (bool *) PG_GETARG_POINTER(2);
535
536         if (a->size != b->size)
537         {
538                 *result = false;
539         }
540         else if (ISSIGNKEY(a))
541         {       /* then b also ISSIGNKEY */
542                 if ( ISALLTRUE(a) )
543                 {
544                         /* in this case b is all true too - other cases is catched
545                            upper */
546                         *result = true;
547                 }
548                 else
549                 {
550                         int32   i;
551                         BITVECP sa = GETSIGN(a),
552                                         sb = GETSIGN(b);
553
554                         *result = true;
555
556                         if ( !ISALLTRUE(a) )
557                         {
558                                 LOOPBYTE
559                                 {
560                                         if (sa[i] != sb[i])
561                                         {
562                                                 *result = false;
563                                                 break;
564                                         }
565                                 }
566                         }
567                 }
568         }
569         else
570         {
571                 uint32  *ptra = GETARR(a),
572                                 *ptrb = GETARR(b);
573                 int32   i;
574
575                 *result = true;
576                 for (i = 0; i < a->size; i++)
577                 {
578                         if ( ptra[i] != ptrb[i])
579                         {
580                                 *result = false;
581                                 break;
582                         }
583                 }
584         }
585
586         PG_RETURN_POINTER(result);
587 }
588
589 /*
590  * Penalty method
591  */
592 static int
593 hemdistsign(BITVECP a, BITVECP b)
594 {
595         int     i,
596                 diff,
597                 dist = 0;
598
599         LOOPBYTE
600         {
601                 diff = (unsigned char) (a[i] ^ b[i]);
602                 dist += number_of_ones[diff];
603         }
604         return dist;
605 }
606
607 static int
608 hemdist(SmlSign *a, SmlSign *b)
609 {
610         if (ISALLTRUE(a))
611         {
612                 if (ISALLTRUE(b))
613                         return 0;
614                 else
615                         return SIGLENBIT - b->size;
616         }
617         else if (ISALLTRUE(b))
618                 return SIGLENBIT - b->size;
619
620         return hemdistsign(GETSIGN(a), GETSIGN(b));
621 }
622
623 PG_FUNCTION_INFO_V1(gsmlsign_penalty);
624 Datum gsmlsign_penalty(PG_FUNCTION_ARGS);
625 Datum
626 gsmlsign_penalty(PG_FUNCTION_ARGS)
627 {
628         GISTENTRY       *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
629         GISTENTRY       *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
630         float           *penalty = (float *) PG_GETARG_POINTER(2);
631         SmlSign         *origval = (SmlSign *) DatumGetPointer(origentry->key);
632         SmlSign         *newval = (SmlSign *) DatumGetPointer(newentry->key);
633         BITVECP         orig = GETSIGN(origval);
634
635         *penalty = 0.0;
636
637         if (ISARRKEY(newval))
638         {
639                 BITVEC  sign;
640
641                 makesign(sign, newval);
642
643                 if (ISALLTRUE(origval))
644                         *penalty = ((float) (SIGLENBIT - sizebitvec(sign))) / (float) (SIGLENBIT + 1);
645                 else
646                         *penalty = hemdistsign(sign, orig);
647         }
648         else
649                 *penalty = hemdist(origval, newval);
650
651         PG_RETURN_POINTER(penalty);
652 }
653
654 /*
655  * Picksplit method
656  */
657
658 typedef struct
659 {
660         bool    allistrue;
661         int32   size;
662         BITVEC  sign;
663 } CACHESIGN;
664
665 static void
666 fillcache(CACHESIGN *item, SmlSign *key)
667 {
668         item->allistrue = false;
669         item->size = key->size;
670
671         if (ISARRKEY(key))
672         {
673                 makesign(item->sign, key);
674                 item->size = sizebitvec( item->sign );
675         }
676         else if (ISALLTRUE(key))
677                 item->allistrue = true;
678         else
679                 memcpy((void *) item->sign, (void *) GETSIGN(key), sizeof(BITVEC));
680 }
681
682 #define WISH_F(a,b,c) (double)( -(double)(((a)-(b))*((a)-(b))*((a)-(b)))*(c) )
683
684 typedef struct
685 {
686         OffsetNumber pos;
687         int32           cost;
688 } SPLITCOST;
689
690 static int
691 comparecost(const void *va, const void *vb)
692 {
693         SPLITCOST  *a = (SPLITCOST *) va;
694         SPLITCOST  *b = (SPLITCOST *) vb;
695
696         if (a->cost == b->cost)
697                 return 0;
698         else
699                 return (a->cost > b->cost) ? 1 : -1;
700 }
701
702 static int
703 hemdistcache(CACHESIGN *a, CACHESIGN *b)
704 {
705         if (a->allistrue)
706         {
707                 if (b->allistrue)
708                         return 0;
709                 else
710                         return SIGLENBIT - b->size;
711         }
712         else if (b->allistrue)
713                 return SIGLENBIT - a->size;
714
715         return hemdistsign(a->sign, b->sign);
716 }
717
718 PG_FUNCTION_INFO_V1(gsmlsign_picksplit);
719 Datum gsmlsign_picksplit(PG_FUNCTION_ARGS);
720 Datum
721 gsmlsign_picksplit(PG_FUNCTION_ARGS)
722 {
723         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
724         GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
725         OffsetNumber k,
726                                 j;
727         SmlSign         *datum_l,
728                                 *datum_r;
729         BITVECP         union_l,
730                                 union_r;
731         int32           size_alpha,
732                                 size_beta;
733         int32           size_waste,
734                                 waste = -1;
735         int32           nbytes;
736         OffsetNumber seed_1 = 0,
737                                 seed_2 = 0;
738         OffsetNumber *left,
739                                 *right;
740         OffsetNumber maxoff;
741         BITVECP         ptr;
742         int                     i;
743         CACHESIGN  *cache;
744         SPLITCOST  *costvector;
745
746         maxoff = entryvec->n - 2;
747         nbytes = (maxoff + 2) * sizeof(OffsetNumber);
748         v->spl_left = (OffsetNumber *) palloc(nbytes);
749         v->spl_right = (OffsetNumber *) palloc(nbytes);
750
751         cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 2));
752         fillcache(&cache[FirstOffsetNumber], GETENTRY(entryvec, FirstOffsetNumber));
753
754         for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
755         {
756                 for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
757                 {
758                         if (k == FirstOffsetNumber)
759                                 fillcache(&cache[j], GETENTRY(entryvec, j));
760
761                         size_waste = hemdistcache(&(cache[j]), &(cache[k]));
762                         if (size_waste > waste)
763                         {
764                                 waste = size_waste;
765                                 seed_1 = k;
766                                 seed_2 = j;
767                         }
768                 }
769         }
770
771         left = v->spl_left;
772         v->spl_nleft = 0;
773         right = v->spl_right;
774         v->spl_nright = 0;
775
776         if (seed_1 == 0 || seed_2 == 0)
777         {
778                 seed_1 = 1;
779                 seed_2 = 2;
780         }
781
782         /* form initial .. */
783         if (cache[seed_1].allistrue)
784         {
785                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
786                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
787                 datum_l->flag = SIGNKEY | ALLISTRUE;
788                 datum_l->size = SIGLENBIT;
789         }
790         else
791         {
792                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
793                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY, 0));
794                 datum_l->flag = SIGNKEY;
795                 memcpy((void *) GETSIGN(datum_l), (void *) cache[seed_1].sign, sizeof(BITVEC));
796                 datum_l->size = cache[seed_1].size;
797         }
798         if (cache[seed_2].allistrue)
799         {
800                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
801                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
802                 datum_r->flag = SIGNKEY | ALLISTRUE;
803                 datum_r->size = SIGLENBIT;
804         }
805         else
806         {
807                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
808                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY, 0));
809                 datum_r->flag = SIGNKEY;
810                 memcpy((void *) GETSIGN(datum_r), (void *) cache[seed_2].sign, sizeof(BITVEC));
811                 datum_r->size = cache[seed_2].size;
812         }
813
814         union_l = GETSIGN(datum_l);
815         union_r = GETSIGN(datum_r);
816         maxoff = OffsetNumberNext(maxoff);
817         fillcache(&cache[maxoff], GETENTRY(entryvec, maxoff));
818         /* sort before ... */
819         costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
820         for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
821         {
822                 costvector[j - 1].pos = j;
823                 size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]));
824                 size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]));
825                 costvector[j - 1].cost = Abs(size_alpha - size_beta);
826         }
827         qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
828
829         datum_l->maxrepeat = datum_r->maxrepeat = 1;
830
831         for (k = 0; k < maxoff; k++)
832         {
833                 j = costvector[k].pos;
834                 if (j == seed_1)
835                 {
836                         *left++ = j;
837                         v->spl_nleft++;
838                         continue;
839                 }
840                 else if (j == seed_2)
841                 {
842                         *right++ = j;
843                         v->spl_nright++;
844                         continue;
845                 }
846
847                 if (ISALLTRUE(datum_l) || cache[j].allistrue)
848                 {
849                         if (ISALLTRUE(datum_l) && cache[j].allistrue)
850                                 size_alpha = 0;
851                         else
852                                 size_alpha = SIGLENBIT - (
853                                                                         (cache[j].allistrue) ? datum_l->size : cache[j].size
854                                                         );
855                 }
856                 else
857                         size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l));
858
859                 if (ISALLTRUE(datum_r) || cache[j].allistrue)
860                 {
861                         if (ISALLTRUE(datum_r) && cache[j].allistrue)
862                                 size_beta = 0;
863                         else
864                                 size_beta = SIGLENBIT - (
865                                                                         (cache[j].allistrue) ? datum_r->size : cache[j].size
866                                                         );
867                 }
868                 else
869                         size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r));
870
871                 if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
872                 {
873                         if (ISALLTRUE(datum_l) || cache[j].allistrue)
874                         {
875                                 if (!ISALLTRUE(datum_l))
876                                         MemSet((void *) GETSIGN(datum_l), 0xff, sizeof(BITVEC));
877                                 datum_l->size = SIGLENBIT;
878                         }
879                         else
880                         {
881                                 ptr = cache[j].sign;
882                                 LOOPBYTE
883                                         union_l[i] |= ptr[i];
884                                 datum_l->size = sizebitvec(union_l);
885                         }
886                         *left++ = j;
887                         v->spl_nleft++;
888                 }
889                 else
890                 {
891                         if (ISALLTRUE(datum_r) || cache[j].allistrue)
892                         {
893                                 if (!ISALLTRUE(datum_r))
894                                         MemSet((void *) GETSIGN(datum_r), 0xff, sizeof(BITVEC));
895                                 datum_r->size = SIGLENBIT;
896                         }
897                         else
898                         {
899                                 ptr = cache[j].sign;
900                                 LOOPBYTE
901                                         union_r[i] |= ptr[i];
902                                 datum_r->size = sizebitvec(union_r);
903                         }
904                         *right++ = j;
905                         v->spl_nright++;
906                 }
907         }
908         *right = *left = FirstOffsetNumber;
909         v->spl_ldatum = PointerGetDatum(datum_l);
910         v->spl_rdatum = PointerGetDatum(datum_r);
911
912         Assert( datum_l->size = sizebitvec(GETSIGN(datum_l)) );
913         Assert( datum_r->size = sizebitvec(GETSIGN(datum_r)) );
914
915         PG_RETURN_POINTER(v);
916 }
917
918 static double
919 getIdfMaxLimit(SmlSign *key)
920 {
921         switch( getTFMethod() )
922         {
923                 case TF_CONST:
924                         return 1.0;
925                         break;
926                 case TF_N:
927                         return (double)(key->maxrepeat);
928                         break;
929                 case TF_LOG:
930                         return 1.0 + log( (double)(key->maxrepeat) );
931                         break;
932                 default:
933                         elog(ERROR,"Unknown TF method: %d", getTFMethod());
934         }
935
936         return 0.0;
937 }
938
939 /*
940  * Consistent function
941  */
/*
 * gsmlsign_consistent
 *
 * Arguments as read below: 0 = GISTENTRY holding the index key, 1 = query
 * datum, 2 = strategy number, 4 = pointer to the recheck flag (argument 3 is
 * not read here).
 *
 * Returns a boolean telling whether the key may match the query.
 *   - SmlarOverlapStrategy: true when key and query share at least one
 *     element (exact test for array keys, bit test for signatures).
 *   - any other strategy is treated as the similarity search: the key is
 *     accepted when a similarity estimate reaches GetSmlarLimit().
 *
 * The key is dispatched on its flag: ALLISTRUE, ARRKEY (array of uint32
 * values) or, otherwise, a bit signature.  In the array-key and signature
 * similarity branches only ST_TFIDF and ST_COSINE are handled; any other
 * getSmlType() value raises an error.
 *
 * *recheck is set to true on every path except the array-key overlap test,
 * which clears it.
 */
942 PG_FUNCTION_INFO_V1(gsmlsign_consistent);
943 Datum gsmlsign_consistent(PG_FUNCTION_ARGS);
944 Datum
945 gsmlsign_consistent(PG_FUNCTION_ARGS)
946 {
947         GISTENTRY               *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
948         StrategyNumber  strategy = (StrategyNumber) PG_GETARG_UINT16(2);
949         bool                    *recheck = (bool *) PG_GETARG_POINTER(4);
950         ArrayType               *a;
951         SmlSign                 *key = (SmlSign*)DatumGetPointer(entry->key);
            /* NOTE(review): res is declared int but only ever holds false/true and is returned through PG_RETURN_BOOL */
952         int                             res = false;
953         SmlSign                 *query;
954         SimpleArray             *s;
955         int32                   i;
956
            /*
             * SearchArrayCache gets the previous fn_extra, the function's memory
             * context and the query datum; its result is stored back into fn_extra
             * and it fills a, s and query.  Presumably this caches the prepared
             * query between calls -- confirm in its definition.
             */
957         fcinfo->flinfo->fn_extra = SearchArrayCache(
958                                                                         fcinfo->flinfo->fn_extra,
959                                                                         fcinfo->flinfo->fn_mcxt,
960                                                                         PG_GETARG_DATUM(1), &a, &s, &query);
961
            /* default: request a recheck; only the array-key overlap branch clears this */
962         *recheck = true;
963
            /* void query array: answer false (res still has its initial value) */
964         if ( ARRISVOID(a) )
965                 PG_RETURN_BOOL(res);
966
967         if ( strategy == SmlarOverlapStrategy )
968         {
969                 if (ISALLTRUE(key))
970                 {
971                         res = true;
972                 }
973                 else if (ISARRKEY(key))
974                 {
975                         uint32  *kptr = GETARR(key),
976                                         *qptr = GETARR(query);
977
                            /*
                             * Walk both uint32 arrays in step, advancing the side with the
                             * smaller value, and stop at the first equal pair.  This only
                             * finds every common value if both arrays are sorted ascending
                             * -- presumably guaranteed by the code that builds them; confirm.
                             */
978                         while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
979                         {
980                                 if ( *kptr < *qptr )
981                                         kptr++;
982                                 else if ( *kptr > *qptr )
983                                         qptr++;
984                                 else
985                                 {
986                                         res = true;
987                                         break;
988                                 }
989                         }
                            /* the element values themselves were compared, so no recheck is requested */
990                         *recheck = false;
991                 }
992                 else
993                 {
994                         BITVECP sign = GETSIGN(key);
995
                            /* s->hash[] is read right below; fillHashVal is called first, presumably to populate it -- confirm */
996                         fillHashVal(fcinfo->flinfo->fn_extra, s);
997
                            /* true as soon as the signature bit of any query element's hash is set */
998                         for(i=0; i<s->nelems; i++)
999                         {
1000                                 if ( GETBIT(sign, HASHVAL(s->hash[i])) )
1001                                 {
1002                                         res = true;
1003                                         break;
1004                                 }
1005                         }
1006                 }
1007         }
1008         /*
1009          *  SmlarSimilarityStrategy
1010          */
1011         else if (ISALLTRUE(key))
1012         {
1013                 if ( GIST_LEAF(entry) )
1014                 {
1015                         /*
1016                          * With TF/IDF similarity we cannot say anything useful
1017                          */
1018                         if ( query->size < SIGLENBIT && getSmlType() != ST_TFIDF )
1019                         {
                                     /*
                                      * Same ratio shape as the ST_COSINE tests further down, with
                                      * SIGLENBIT used as the key size and query->size as the
                                      * number of common elements.
                                      */
1020                                 double power = ((double)(query->size)) * ((double)(SIGLENBIT));
1021
1022                                 if ( ((double)(query->size)) / sqrt(power) >= GetSmlarLimit() ) 
1023                                         res = true;
1024                         }
1025                         else
1026                         {
1027                                 res = true;     
1028                         }
1029                 }
1030                 else
                             /* non-leaf all-true key: always accepted */
1031                         res = true;
1032         }
1033         else if (ISARRKEY(key))
1034         {
1035                 uint32  *kptr = GETARR(key),
1036                                 *qptr = GETARR(query);
1037
                     /* array keys are only expected on leaf entries in this branch */
1038                 Assert( GIST_LEAF(entry) );
1039
1040                 switch(getSmlType())
1041                 {
1042                         case ST_TFIDF:
1043                                 {
1044                                         StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1045                                         double          sumU = 0.0,
1046                                                                 sumQ = 0.0,
1047                                                                 sumK = 0.0;
1048                                         double          maxKTF = getIdfMaxLimit(key);
1049                                         HashedElem  *h;
1050
1051                                         Assert( s->df );
1052                                         fillHashVal(fcinfo->flinfo->fn_extra, s);
1053                                         if ( stat->info != s->info )
1054                                                 elog(ERROR,"Statistic and actual argument have different type");
1055
                                             /*
                                              * First pass over the query elements:
                                              *   sumQ - sum of squared s->df[i] over all query elements
                                              *   sumK - sum of squared idfMin over query elements that have
                                              *          statistics and are found in the key
                                              *   sumU - sum of idfMax * maxKTF * s->df[i] over the same elements
                                              */
1056                                         for(i=0;i<s->nelems;i++)
1057                                         {
1058                                                 sumQ += s->df[i] * s->df[i];
1059
1060                                                 h = getHashedElemIdf(stat, s->hash[i], NULL);
1061                                                 if ( h && hasHashedElem(key, s->hash[i]) )
1062                                                 {
1063                                                         sumK += h->idfMin * h->idfMin;
1064                                                         sumU += h->idfMax * maxKTF * s->df[i];
1065                                                 } 
1066                                         }
1067
                                             /* cheap test first; sumK is recomputed over the whole key only if it passes */
1068                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1069                                         {
1070                                                 /* 
1071                                                  * More precisely calculate sumK
1072                                                  */
1073                                                 h = NULL;
1074                                                 sumK = 0.0;
1075
                                                     /*
                                                      * Every key element with statistics now contributes.  The
                                                      * previous lookup result is passed back in as the third
                                                      * argument -- presumably a search hint; confirm.
                                                      */
1076                                                 for(i=0;i<key->size;i++)
1077                                                 {
1078                                                         h = getHashedElemIdf(stat, GETARR(key)[i], h);                                  
1079                                                         if (h)
1080                                                                 sumK += h->idfMin * h->idfMin;
1081                                                 }
1082
1083                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1084                                                         res = true;
1085                                         }
1086                                 }
1087                                 break;
1088                         case ST_COSINE:
1089                                 {
1090                                         double                  power;
1091                                         power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1092
                                             /* the common-element count cannot exceed the smaller size; skip the walk if even that is below the limit */
1093                                         if (  ((double)Min(key->size, s->nelems)) / power >= GetSmlarLimit() )
1094                                         {
1095                                                 int  cnt = 0;
1096
                                                     /* count equal values with the same two-pointer walk as in the overlap branch (same sortedness assumption) */
1097                                                 while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
1098                                                 {
1099                                                         if ( *kptr < *qptr )
1100                                                                 kptr++;
1101                                                         else if ( *kptr > *qptr )
1102                                                                 qptr++;
1103                                                         else
1104                                                         {
1105                                                                 cnt++;
1106                                                                 kptr++;
1107                                                                 qptr++;
1108                                                         }
1109                                                 }
1110
1111                                                 if ( ((double)cnt) / power >= GetSmlarLimit() )
1112                                                         res = true;
1113                                         }
1114                                 }
1115                                 break;
1116                         default:
1117                                 elog(ERROR,"GiST doesn't support current formula type of similarity");
1118                 }
1119         }
1120         else
1121         {       /* signature */
1122                 BITVECP sign = GETSIGN(key);
1123                 int32   count = 0;
1124
1125                 fillHashVal(fcinfo->flinfo->fn_extra, s);
1126
1127                 if ( GIST_LEAF(entry) )
1128                 {
1129                         switch(getSmlType())
1130                         {
1131                                 case ST_TFIDF:
1132                                         {
1133                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1134                                                 double          sumU = 0.0,
1135                                                                         sumQ = 0.0,
1136                                                                         sumK = 0.0;
1137                                                 double          maxKTF = getIdfMaxLimit(key);
1138
1139                                                 Assert( s->df );
1140                                                 if ( stat->info != s->info )
1141                                                         elog(ERROR,"Statistic and actual argument have different type");
1142
                                                     /*
                                                      * Same sums as in the array-key case, except that presence is
                                                      * tested through the signature bit and the idf values are taken
                                                      * from stat->selems[] indexed by that bit number.
                                                      */
1143                                                 for(i=0;i<s->nelems;i++)
1144                                                 {
1145                                                         int32           hbit = HASHVAL(s->hash[i]);
1146
1147                                                         sumQ += s->df[i] * s->df[i];
1148                                                         if ( GETBIT(sign, hbit) )
1149                                                         {
1150                                                                 sumK += stat->selems[ hbit ].idfMin * stat->selems[ hbit ].idfMin;
1151                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1152                                                         }
1153                                                 }
1154
1155                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1156                                                 {
1157                                                         /* 
1158                                                          * More precisely calculate sumK
1159                                                          */
1160                                                         sumK = 0.0;
1161
                                                             /* this time every set bit of the signature contributes, not only those hit by the query */
1162                                                         for(i=0;i<SIGLENBIT;i++)
1163                                                                 if ( GETBIT(sign,i) )
1164                                                                         sumK += stat->selems[ i ].idfMin * stat->selems[ i ].idfMin;
1165
1166                                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1167                                                                 res = true;
1168                                                 }
1169                                         }
1170                                         break;
1171                                 case ST_COSINE:
1172                                         {
1173                                                 double  power;
1174
1175                                                 power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1176
                                                     /* count = number of query elements whose hash bit is set in the signature */
1177                                                 for(i=0; i<s->nelems; i++)
1178                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1179
1180                                                 if ( ((double)count) / power >= GetSmlarLimit() )
1181                                                         res = true;
1182                                         }
1183                                         break;
1184                                 default:
1185                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1186                         }
1187                 }
1188                 else /* non-leaf page */
1189                 {
1190                         switch(getSmlType())
1191                         {
1192                                 case ST_TFIDF:
1193                                         {
1194                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1195                                                 double          sumU = 0.0,
1196                                                                         sumQ = 0.0,
1197                                                                         minK = -1.0;
1198                                                 double          maxKTF = getIdfMaxLimit(key);
1199
1200                                                 Assert( s->df );
1201                                                 if ( stat->info != s->info )
1202                                                         elog(ERROR,"Statistic and actual argument have different type");
1203
                                                     /*
                                                      * minK ends up as the smallest idfMin among the query elements
                                                      * whose bit is set; -1.0 marks "nothing matched yet".
                                                      */
1204                                                 for(i=0;i<s->nelems;i++)
1205                                                 {
1206                                                         int32           hbit = HASHVAL(s->hash[i]);
1207
1208                                                         sumQ += s->df[i] * s->df[i];
1209                                                         if ( GETBIT(sign, hbit) )
1210                                                         {
1211                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1212                                                                 if ( minK > stat->selems[ hbit ].idfMin  || minK < 0.0 )
1213                                                                         minK = stat->selems[ hbit ].idfMin;
1214                                                         }
1215                                                 }
1216
                                                     /*
                                                      * NOTE(review): the leaf branches put a sum of idfMin * idfMin
                                                      * under the square root, while minK is used un-squared here --
                                                      * confirm this is intended.
                                                      */
1217                                                 if ( sumQ > 0.0 && minK > 0.0 && sumU / sqrt( sumQ * minK ) >= GetSmlarLimit() )
1218                                                         res = true;
1219                                         }
1220                                         break;
1221                                 case ST_COSINE:
1222                                         {
1223                                                 for(i=0; i<s->nelems; i++)
1224                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1225
                                                     /* accept when every query element hits a set bit, or when sqrt(count / nelems) reaches the limit */
1226                                                 if ( s->nelems == count  || sqrt(((double)count) / ((double)(s->nelems))) >= GetSmlarLimit() )
1227                                                         res = true;
1228                                         }
1229                                         break;
1230                                 default:
1231                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1232                         }
1233                 }
1234         }
1235
             /* disabled debugging aid: counts accepted / rejected leaf entries and prints them with the key kind */
1236 #if 0
1237         {
1238                 static int nnres = 0;
1239                 static int nres = 0;
1240                 if ( GIST_LEAF(entry) ) {
1241                         if ( res )
1242                                 nres++;
1243                         else
1244                                 nnres++;
1245                         elog(NOTICE,"%s nn:%d n:%d", (ISARRKEY(key)) ? "ARR" : ( (ISALLTRUE(key)) ? "TRUE" : "SIGN" ), nnres, nres  );
1246                 }
1247         }
1248 #endif
1249
1250         PG_RETURN_BOOL(res);
1251 }