support pgsql versions up to 11
[smlar.git] / smlar_gist.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/gist.h"
5 #include "access/skey.h"
6 #include "access/tuptoaster.h"
7 #include "utils/memutils.h"
8
/*
 * GiST key of the smlar opclass.  The payload at 'data' depends on 'flag'
 * (flag bits are defined below):
 *  - ARRKEY:             'size' uint32 element hashes, sorted and
 *                        de-duplicated (see Array2HashedArray);
 *  - SIGNKEY:            a SIGLEN-byte bit signature with 'size' bits set;
 *  - SIGNKEY|ALLISTRUE:  no payload; every bit is considered set.
 */
typedef struct SmlSign {
	int32	vl_len_; /* varlena header (do not touch directly!) */
	int32	flag:8,		/* ARRKEY / SIGNKEY / ALLISTRUE */
			size:24;	/* hash count or set-bit count; signed 24-bit field */
	int32	maxrepeat;	/* longest run of equal hashes (see uniqueint); for
						 * signature keys, the max over the keys they cover */
	char	data[1];	/* variable-length payload, see above */
} SmlSign;
16
/* Size of the fixed SmlSign header, i.e. the offset of its payload. */
#define SMLSIGNHDRSZ	(offsetof(SmlSign, data))

/*
 * Bit-signature geometry.  A signature is SIGLEN bytes.  HASHVAL maps hashes
 * onto bits [0, SIGLENBIT - 1]; bit SIGLENBIT (the last one) is reserved and
 * set by makesign.
 */
#define BITBYTE 8
#define SIGLENINT  61
#define SIGLEN	( sizeof(int) * SIGLENINT )
#define SIGLENBIT (SIGLEN * BITBYTE - 1)	/* see makesign */
typedef char BITVEC[SIGLEN];
typedef char *BITVECP;

/* Iterate over every byte of a signature; the caller must declare 'i'. */
#define LOOPBYTE \
		for (i = 0; i < SIGLEN; i++)

/*
 * Bit accessors.  Every argument and every expansion is parenthesized so
 * that expression arguments (e.g. "k | 1", "p + 1") behave as expected.
 * GETBITBYTE works on the unsigned byte value to avoid right-shifting a
 * negative char.
 */
#define GETBYTE(x,i) ( *( (BITVECP)(x) + (int)( (i) / BITBYTE ) ) )
#define GETBITBYTE(x,i) ( (((unsigned char)(x)) >> (i)) & 0x01 )
#define CLRBIT(x,i)	( GETBYTE(x,i) &= ~( 0x01 << ( (i) % BITBYTE ) ) )
#define SETBIT(x,i)	( GETBYTE(x,i) |=  ( 0x01 << ( (i) % BITBYTE ) ) )
#define GETBIT(x,i) ( (GETBYTE(x,i) >> ( (i) % BITBYTE )) & 0x01 )

/* Map a hash onto a signature bit; never yields the reserved bit SIGLENBIT. */
#define HASHVAL(val) (((unsigned int)(val)) % SIGLENBIT)
#define HASH(sign, val) SETBIT((sign), HASHVAL(val))

/* SmlSign.flag bits */
#define ARRKEY			0x01	/* payload: sorted uint32 hash array */
#define SIGNKEY			0x02	/* payload: SIGLEN-byte bit signature */
#define ALLISTRUE		0x04	/* all bits set; no payload */

#define ISARRKEY(x)		( ((SmlSign *)(x))->flag & ARRKEY )
#define ISSIGNKEY(x)	( ((SmlSign *)(x))->flag & SIGNKEY )
#define ISALLTRUE(x)	( ((SmlSign *)(x))->flag & ALLISTRUE )

/* Total varlena size of a key with the given flags and hash count. */
#define CALCGTSIZE(flag, len)	( SMLSIGNHDRSZ + ( ( (flag) & ARRKEY ) ? ((len)*sizeof(uint32)) : (((flag) & ALLISTRUE) ? 0 : SIGLEN) ) )
#define GETSIGN(x)				( (BITVECP)( (char *)(x) + SMLSIGNHDRSZ ) )
#define GETARR(x)				( (uint32 *)( (char *)(x) + SMLSIGNHDRSZ ) )

/* Key of entry 'pos' of a GistEntryVector. */
#define GETENTRY(vec,pos) ((SmlSign *) DatumGetPointer((vec)->vector[(pos)].key))
50
51 /*
52  * Fake IO
53  */
54 PG_FUNCTION_INFO_V1(gsmlsign_in);
55 Datum   gsmlsign_in(PG_FUNCTION_ARGS);
56 Datum
57 gsmlsign_in(PG_FUNCTION_ARGS)
58 {
59         elog(ERROR, "not implemented");
60         PG_RETURN_DATUM(0);
61 }
62
63 PG_FUNCTION_INFO_V1(gsmlsign_out);
64 Datum   gsmlsign_out(PG_FUNCTION_ARGS);
65 Datum
66 gsmlsign_out(PG_FUNCTION_ARGS)
67 {
68         elog(ERROR, "not implemented");
69         PG_RETURN_DATUM(0);
70 }
71
72 /*
73  * Compress/decompress
74  */
75
/*
 * Number of one-bits in an unsigned byte: number_of_ones[b] is the
 * popcount of b.  Used by sizebitvec and hemdistsign to count signature bits
 * one byte at a time.
 */
static const uint8 number_of_ones[256] = {
	0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
	1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
	2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
	3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
	4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
};
95
96 static int
97 compareint(const void *va, const void *vb)
98 {
99         uint32  a = *((uint32 *) va);
100         uint32  b = *((uint32 *) vb);
101
102         if (a == b)
103                 return 0;
104         return (a > b) ? 1 : -1;
105 }
106
107 /*
108  * Removes duplicates from an array of int32. 'l' is
109  * size of the input array. Returns the new size of the array.
110  */
111 static int
112 uniqueint(uint32 *a, int32 l, int32 *max)
113 {
114         uint32          *ptr,
115                                 *res;
116         int32           cnt = 0;
117
118         *max = 1;
119
120         if (l <= 1)
121                 return l;
122
123         ptr = res = a;
124
125         qsort((void *) a, l, sizeof(uint32), compareint);
126
127         while (ptr - a < l)
128                 if (*ptr != *res)
129                 {
130                         cnt = 1;
131                         *(++res) = *ptr++;
132                 }
133                 else
134                 {
135                         cnt++;
136                         if ( cnt > *max )
137                                 *max = cnt;
138                         ptr++;
139                 }
140
141         if ( cnt > *max )
142                 *max = cnt;
143
144         return res + 1 - a;
145 }
146
147 SmlSign*
148 Array2HashedArray(ProcTypeInfo info, ArrayType *a)
149 {
150         SimpleArray *s = Array2SimpleArray(info, a);
151         SmlSign         *sign;
152         int32           len, i;
153         uint32          *ptr;
154
155         len = CALCGTSIZE( ARRKEY, s->nelems );
156
157         getFmgrInfoHash(s->info);
158         if (s->info->tupDesc)
159                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
160
161         sign = palloc( len );
162         sign->flag = ARRKEY;
163         sign->size = s->nelems;
164
165         ptr = GETARR(sign);
166         for(i=0;i<s->nelems;i++)
167                 ptr[i] = DatumGetUInt32( FunctionCall1( &s->info->hashFunc, s->elems[i] ) );
168
169         /*
170          * there is a collision of hash-function; len is always equal or less than
171          * s->nelems
172          */
173         sign->size = uniqueint( GETARR(sign), sign->size, &sign->maxrepeat );
174         len = CALCGTSIZE( ARRKEY, sign->size );
175         SET_VARSIZE(sign, len);
176
177         return sign;
178 }
179
180 static int
181 HashedElemCmp(const void *va, const void *vb)
182 {
183         uint32  a = ((HashedElem *) va)->hash;
184         uint32  b = ((HashedElem *) vb)->hash;
185
186         if (a == b)
187         {
188                 double  ma = ((HashedElem *) va)->idfMin;
189                 double  mb = ((HashedElem *) va)->idfMin;
190
191                 if (ma == mb)
192                         return 0;
193
194                 return ( ma > mb ) ? 1 : -1;
195         }
196
197         return (a > b) ? 1 : -1;
198 }
199
200 static int
201 uniqueHashedElem(HashedElem *a, int32 l)
202 {
203         HashedElem      *ptr,
204                                 *res;
205
206         if (l <= 1)
207                 return l;
208
209         ptr = res = a;
210
211         qsort(a, l, sizeof(HashedElem), HashedElemCmp);
212
213         while (ptr - a < l)
214                 if (ptr->hash != res->hash)
215                         *(++res) = *ptr++;
216                 else
217                 {
218                         res->idfMax = ptr->idfMax;
219                         ptr++;
220                 }
221
222         return res + 1 - a;
223 }
224
225 static StatCache*
226 getHashedCache(void *cache)
227 {
228         StatCache *stat = getStat(cache, SIGLENBIT);
229
230         if ( stat->nhelems < 0 )
231         {
232                 int i;
233                 /*
234                  * Init
235                  */
236
237                 if (stat->info->tupDesc)
238                         elog(ERROR, "GiST  doesn't support composite (weighted) type");
239                 getFmgrInfoHash(stat->info);
240                 for(i=0;i<stat->nelems;i++)
241                 {
242                         uint32  hash = DatumGetUInt32( FunctionCall1( &stat->info->hashFunc, stat->elems[i].datum ) );
243                         int             index = HASHVAL(hash);
244
245                         stat->helems[i].hash = hash;
246                         stat->helems[i].idfMin = stat->helems[i].idfMax = stat->elems[i].idf;   
247                         if ( stat->selems[index].idfMin == 0.0 )
248                                 stat->selems[index].idfMin = stat->selems[index].idfMax = stat->elems[i].idf;
249                         else if ( stat->selems[index].idfMin > stat->elems[i].idf )
250                                 stat->selems[index].idfMin = stat->elems[i].idf;
251                         else if ( stat->selems[index].idfMax < stat->elems[i].idf )
252                                 stat->selems[index].idfMax = stat->elems[i].idf;
253                 }
254
255                 stat->nhelems = uniqueHashedElem( stat->helems, stat->nelems);
256         }
257
258         return stat;
259 }
260
261 static HashedElem*
262 getHashedElemIdf(StatCache *stat, uint32 hash, HashedElem *StopLow)
263 {
264         HashedElem      *StopMiddle,
265                                 *StopHigh = stat->helems + stat->nhelems;
266
267         if ( !StopLow )
268                 StopLow = stat->helems;
269
270         while (StopLow < StopHigh) {
271                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
272
273                 if ( StopMiddle->hash == hash )
274                         return StopMiddle;
275                 else if ( StopMiddle->hash < hash )
276                         StopLow = StopMiddle + 1;
277                 else
278                         StopHigh = StopMiddle;
279         }
280
281         return NULL;
282 }
283
284 static void
285 fillHashVal(void *cache, SimpleArray *a)
286 {
287         int i;
288
289         if (a->hash)
290                 return;
291
292         allocateHash(cache, a);
293
294         if (a->info->tupDesc)
295                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
296         getFmgrInfoHash(a->info);
297
298         for(i=0;i<a->nelems;i++)
299                 a->hash[i] = DatumGetUInt32( FunctionCall1( &a->info->hashFunc, a->elems[i] ) );
300 }
301
302
303 static bool
304 hasHashedElem(SmlSign  *a, uint32 h)
305 {
306         uint32  *StopLow = GETARR(a),
307                         *StopHigh = GETARR(a) + a->size,
308                         *StopMiddle;
309
310         while (StopLow < StopHigh) {
311                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
312
313                 if ( *StopMiddle == h )
314                         return true;
315                 else if ( *StopMiddle < h )
316                         StopLow = StopMiddle + 1;
317                 else
318                         StopHigh = StopMiddle;
319         }
320
321         return false;
322 }
323
324 static void
325 makesign(BITVECP sign, SmlSign  *a)
326 {
327         int32   i;
328         uint32  *ptr = GETARR(a);
329
330         MemSet((void *) sign, 0, sizeof(BITVEC));
331         SETBIT(sign, SIGLENBIT);   /* set last unused bit */
332
333         for (i = 0; i < a->size; i++)
334                 HASH(sign, ptr[i]);
335 }
336
337 static int32
338 sizebitvec(BITVECP sign)
339 {
340         int32   size = 0,
341                         i;
342
343         LOOPBYTE
344                 size += number_of_ones[(unsigned char) sign[i]];
345
346         return size;
347 }
348
349 PG_FUNCTION_INFO_V1(gsmlsign_compress);
350 Datum gsmlsign_compress(PG_FUNCTION_ARGS);
351 Datum
352 gsmlsign_compress(PG_FUNCTION_ARGS)
353 {
354         GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
355         GISTENTRY  *retval = entry;
356
357         if (entry->leafkey) /* new key */
358         {
359                 SmlSign         *sign;
360                 ArrayType       *a = DatumGetArrayTypeP(entry->key);
361
362                 sign = Array2HashedArray(NULL, a);
363
364                 if ( VARSIZE(sign) > TOAST_INDEX_TARGET )
365                 {       /* make signature due to its big size */
366                         SmlSign *tmpsign;
367                         int             len;
368
369                         len = CALCGTSIZE( SIGNKEY, sign->size );
370                         tmpsign = palloc( len );
371                         tmpsign->flag = SIGNKEY;
372                         SET_VARSIZE(tmpsign, len);
373
374                         makesign(GETSIGN(tmpsign), sign);
375                         tmpsign->size = sizebitvec(GETSIGN(tmpsign));
376                         tmpsign->maxrepeat = sign->maxrepeat;
377                         sign = tmpsign;
378                 }
379
380                 retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
381                 gistentryinit(*retval, PointerGetDatum(sign),
382                                                 entry->rel, entry->page,
383                                                 entry->offset, false);
384         }
385         else if ( ISSIGNKEY(DatumGetPointer(entry->key)) &&
386                                 !ISALLTRUE(DatumGetPointer(entry->key)) )
387         {
388                 SmlSign *sign = (SmlSign*)DatumGetPointer(entry->key);
389
390                 Assert( sign->size == sizebitvec(GETSIGN(sign)) );
391
392                 if ( sign->size == SIGLENBIT )
393                 {
394                         int32   len = CALCGTSIZE(SIGNKEY | ALLISTRUE, 0);
395                         int32   maxrepeat = sign->maxrepeat;
396
397                         sign = (SmlSign *) palloc(len);
398                         SET_VARSIZE(sign, len);
399                         sign->flag = SIGNKEY | ALLISTRUE;
400                         sign->size = SIGLENBIT;
401                         sign->maxrepeat = maxrepeat;
402
403                         retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
404
405                         gistentryinit(*retval, PointerGetDatum(sign),
406                                                         entry->rel, entry->page,
407                                                         entry->offset, false);
408                 }
409         }
410
411         PG_RETURN_POINTER(retval);
412 }
413
414 PG_FUNCTION_INFO_V1(gsmlsign_decompress);
415 Datum gsmlsign_decompress(PG_FUNCTION_ARGS);
416 Datum
417 gsmlsign_decompress(PG_FUNCTION_ARGS)
418 {
419         GISTENTRY       *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
420         SmlSign         *key =  (SmlSign*)DatumGetPointer(PG_DETOAST_DATUM(entry->key));
421
422         if (key != (SmlSign *) DatumGetPointer(entry->key))
423         {
424                 GISTENTRY  *retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
425
426                 gistentryinit(*retval, PointerGetDatum(key),
427                                                 entry->rel, entry->page,
428                                                 entry->offset, false);
429
430                 PG_RETURN_POINTER(retval);
431         }
432
433         PG_RETURN_POINTER(entry);
434 }
435
436 /*
437  * Union method
438  */
439 static bool
440 unionkey(BITVECP sbase, SmlSign *add)
441 {
442         int32   i;
443
444         if (ISSIGNKEY(add))
445         {
446                 BITVECP sadd = GETSIGN(add);
447
448                 if (ISALLTRUE(add))
449                         return true;
450
451                 LOOPBYTE
452                         sbase[i] |= sadd[i];
453         }
454         else
455         {
456                 uint32  *ptr = GETARR(add);
457
458                 for (i = 0; i < add->size; i++)
459                         HASH(sbase, ptr[i]);
460         }
461
462         return false;
463 }
464
465 PG_FUNCTION_INFO_V1(gsmlsign_union);
466 Datum gsmlsign_union(PG_FUNCTION_ARGS);
467 Datum
468 gsmlsign_union(PG_FUNCTION_ARGS)
469 {
470         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
471         int                             *size = (int *) PG_GETARG_POINTER(1);
472         BITVEC                  base;
473         int32           i,
474                                 len,
475                                 maxrepeat = 1;
476         int32           flag = 0;
477         SmlSign    *result;
478
479         MemSet((void *) base, 0, sizeof(BITVEC));
480         for (i = 0; i < entryvec->n; i++)
481         {
482                 if (GETENTRY(entryvec, i)->maxrepeat > maxrepeat)
483                         maxrepeat = GETENTRY(entryvec, i)->maxrepeat;
484                 if (unionkey(base, GETENTRY(entryvec, i)))
485                 {
486                         flag = ALLISTRUE;
487                         break;
488                 }
489         }
490
491         flag |= SIGNKEY;
492         len = CALCGTSIZE(flag, 0);
493         result = (SmlSign *) palloc(len);
494         *size = len;
495         SET_VARSIZE(result, len);
496         result->flag = flag;
497         result->maxrepeat = maxrepeat;
498
499         if (!ISALLTRUE(result))
500         {
501                 memcpy((void *) GETSIGN(result), (void *) base, sizeof(BITVEC));
502                 result->size = sizebitvec(GETSIGN(result));
503         }
504         else
505                 result->size = SIGLENBIT;
506
507         PG_RETURN_POINTER(result);
508 }
509
510 /*
511  * Same method
512  */
513
514 PG_FUNCTION_INFO_V1(gsmlsign_same);
515 Datum gsmlsign_same(PG_FUNCTION_ARGS);
516 Datum
517 gsmlsign_same(PG_FUNCTION_ARGS)
518 {
519         SmlSign *a = (SmlSign*)PG_GETARG_POINTER(0);
520         SmlSign *b = (SmlSign*)PG_GETARG_POINTER(1);
521         bool    *result = (bool *) PG_GETARG_POINTER(2);
522
523         if (a->size != b->size)
524         {
525                 *result = false;
526         }
527         else if (ISSIGNKEY(a))
528         {       /* then b also ISSIGNKEY */
529                 if ( ISALLTRUE(a) )
530                 {
531                         /* in this case b is all true too - other cases is catched
532                            upper */
533                         *result = true;
534                 }
535                 else
536                 {
537                         int32   i;
538                         BITVECP sa = GETSIGN(a),
539                                         sb = GETSIGN(b);
540
541                         *result = true;
542
543                         if ( !ISALLTRUE(a) )
544                         {
545                                 LOOPBYTE
546                                 {
547                                         if (sa[i] != sb[i])
548                                         {
549                                                 *result = false;
550                                                 break;
551                                         }
552                                 }
553                         }
554                 }
555         }
556         else
557         {
558                 uint32  *ptra = GETARR(a),
559                                 *ptrb = GETARR(b);
560                 int32   i;
561
562                 *result = true;
563                 for (i = 0; i < a->size; i++)
564                 {
565                         if ( ptra[i] != ptrb[i])
566                         {
567                                 *result = false;
568                                 break;
569                         }
570                 }
571         }
572
573         PG_RETURN_POINTER(result);
574 }
575
576 /*
577  * Penalty method
578  */
579 static int
580 hemdistsign(BITVECP a, BITVECP b)
581 {
582         int     i,
583                 diff,
584                 dist = 0;
585
586         LOOPBYTE
587         {
588                 diff = (unsigned char) (a[i] ^ b[i]);
589                 dist += number_of_ones[diff];
590         }
591         return dist;
592 }
593
594 static int
595 hemdist(SmlSign *a, SmlSign *b)
596 {
597         if (ISALLTRUE(a))
598         {
599                 if (ISALLTRUE(b))
600                         return 0;
601                 else
602                         return SIGLENBIT - b->size;
603         }
604         else if (ISALLTRUE(b))
605                 return SIGLENBIT - b->size;
606
607         return hemdistsign(GETSIGN(a), GETSIGN(b));
608 }
609
610 PG_FUNCTION_INFO_V1(gsmlsign_penalty);
611 Datum gsmlsign_penalty(PG_FUNCTION_ARGS);
612 Datum
613 gsmlsign_penalty(PG_FUNCTION_ARGS)
614 {
615         GISTENTRY       *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
616         GISTENTRY       *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
617         float           *penalty = (float *) PG_GETARG_POINTER(2);
618         SmlSign         *origval = (SmlSign *) DatumGetPointer(origentry->key);
619         SmlSign         *newval = (SmlSign *) DatumGetPointer(newentry->key);
620         BITVECP         orig = GETSIGN(origval);
621
622         *penalty = 0.0;
623
624         if (ISARRKEY(newval))
625         {
626                 BITVEC  sign;
627
628                 makesign(sign, newval);
629
630                 if (ISALLTRUE(origval))
631                         *penalty = ((float) (SIGLENBIT - sizebitvec(sign))) / (float) (SIGLENBIT + 1);
632                 else
633                         *penalty = hemdistsign(sign, orig);
634         }
635         else
636                 *penalty = hemdist(origval, newval);
637
638         PG_RETURN_POINTER(penalty);
639 }
640
/*
 * Picksplit method
 */

/*
 * Per-entry scratch data for gsmlsign_picksplit, filled by fillcache: the
 * entry's key in expanded signature form.
 */
typedef struct
{
	bool	allistrue;	/* key is ALLISTRUE; 'sign' is left unfilled */
	int32	size;		/* popcount for array keys, key->size otherwise */
	BITVEC	sign;		/* expanded signature (only when !allistrue) */
} CACHESIGN;
651
652 static void
653 fillcache(CACHESIGN *item, SmlSign *key)
654 {
655         item->allistrue = false;
656         item->size = key->size;
657
658         if (ISARRKEY(key))
659         {
660                 makesign(item->sign, key);
661                 item->size = sizebitvec( item->sign );
662         }
663         else if (ISALLTRUE(key))
664                 item->allistrue = true;
665         else
666                 memcpy((void *) item->sign, (void *) GETSIGN(key), sizeof(BITVEC));
667 }
668
/*
 * Balance bias for picksplit: -(a - b)^3 * c.  gsmlsign_picksplit adds it
 * to the right-side distance when a = left count and b = right count, which
 * pushes entries toward the side that currently has fewer of them.
 */
#define WISH_F(a,b,c) (double)( -(double)(((a)-(b))*((a)-(b))*((a)-(b)))*(c) )

/* Picksplit ordering record: entry offset and its seed-distance difference. */
typedef struct
{
	OffsetNumber pos;	/* offset of the entry in the entry vector */
	int32		cost;	/* |dist(seed_1) - dist(seed_2)| */
} SPLITCOST;
676
677 static int
678 comparecost(const void *va, const void *vb)
679 {
680         SPLITCOST  *a = (SPLITCOST *) va;
681         SPLITCOST  *b = (SPLITCOST *) vb;
682
683         if (a->cost == b->cost)
684                 return 0;
685         else
686                 return (a->cost > b->cost) ? 1 : -1;
687 }
688
689 static int
690 hemdistcache(CACHESIGN *a, CACHESIGN *b)
691 {
692         if (a->allistrue)
693         {
694                 if (b->allistrue)
695                         return 0;
696                 else
697                         return SIGLENBIT - b->size;
698         }
699         else if (b->allistrue)
700                 return SIGLENBIT - a->size;
701
702         return hemdistsign(a->sign, b->sign);
703 }
704
705 PG_FUNCTION_INFO_V1(gsmlsign_picksplit);
706 Datum gsmlsign_picksplit(PG_FUNCTION_ARGS);
707 Datum
708 gsmlsign_picksplit(PG_FUNCTION_ARGS)
709 {
710         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
711         GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
712         OffsetNumber k,
713                                 j;
714         SmlSign         *datum_l,
715                                 *datum_r;
716         BITVECP         union_l,
717                                 union_r;
718         int32           size_alpha,
719                                 size_beta;
720         int32           size_waste,
721                                 waste = -1;
722         int32           nbytes;
723         OffsetNumber seed_1 = 0,
724                                 seed_2 = 0;
725         OffsetNumber *left,
726                                 *right;
727         OffsetNumber maxoff;
728         BITVECP         ptr;
729         int                     i;
730         CACHESIGN  *cache;
731         SPLITCOST  *costvector;
732
733         maxoff = entryvec->n - 2;
734         nbytes = (maxoff + 2) * sizeof(OffsetNumber);
735         v->spl_left = (OffsetNumber *) palloc(nbytes);
736         v->spl_right = (OffsetNumber *) palloc(nbytes);
737
738         cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 2));
739         fillcache(&cache[FirstOffsetNumber], GETENTRY(entryvec, FirstOffsetNumber));
740
741         for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
742         {
743                 for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
744                 {
745                         if (k == FirstOffsetNumber)
746                                 fillcache(&cache[j], GETENTRY(entryvec, j));
747
748                         size_waste = hemdistcache(&(cache[j]), &(cache[k]));
749                         if (size_waste > waste)
750                         {
751                                 waste = size_waste;
752                                 seed_1 = k;
753                                 seed_2 = j;
754                         }
755                 }
756         }
757
758         left = v->spl_left;
759         v->spl_nleft = 0;
760         right = v->spl_right;
761         v->spl_nright = 0;
762
763         if (seed_1 == 0 || seed_2 == 0)
764         {
765                 seed_1 = 1;
766                 seed_2 = 2;
767         }
768
769         /* form initial .. */
770         if (cache[seed_1].allistrue)
771         {
772                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
773                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
774                 datum_l->flag = SIGNKEY | ALLISTRUE;
775                 datum_l->size = SIGLENBIT;
776         }
777         else
778         {
779                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
780                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY, 0));
781                 datum_l->flag = SIGNKEY;
782                 memcpy((void *) GETSIGN(datum_l), (void *) cache[seed_1].sign, sizeof(BITVEC));
783                 datum_l->size = cache[seed_1].size;
784         }
785         if (cache[seed_2].allistrue)
786         {
787                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
788                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
789                 datum_r->flag = SIGNKEY | ALLISTRUE;
790                 datum_r->size = SIGLENBIT;
791         }
792         else
793         {
794                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
795                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY, 0));
796                 datum_r->flag = SIGNKEY;
797                 memcpy((void *) GETSIGN(datum_r), (void *) cache[seed_2].sign, sizeof(BITVEC));
798                 datum_r->size = cache[seed_2].size;
799         }
800
801         union_l = GETSIGN(datum_l);
802         union_r = GETSIGN(datum_r);
803         maxoff = OffsetNumberNext(maxoff);
804         fillcache(&cache[maxoff], GETENTRY(entryvec, maxoff));
805         /* sort before ... */
806         costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
807         for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
808         {
809                 costvector[j - 1].pos = j;
810                 size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]));
811                 size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]));
812                 costvector[j - 1].cost = Abs(size_alpha - size_beta);
813         }
814         qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
815
816         datum_l->maxrepeat = datum_r->maxrepeat = 1;
817
818         for (k = 0; k < maxoff; k++)
819         {
820                 j = costvector[k].pos;
821                 if (j == seed_1)
822                 {
823                         *left++ = j;
824                         v->spl_nleft++;
825                         continue;
826                 }
827                 else if (j == seed_2)
828                 {
829                         *right++ = j;
830                         v->spl_nright++;
831                         continue;
832                 }
833
834                 if (ISALLTRUE(datum_l) || cache[j].allistrue)
835                 {
836                         if (ISALLTRUE(datum_l) && cache[j].allistrue)
837                                 size_alpha = 0;
838                         else
839                                 size_alpha = SIGLENBIT - (
840                                                                         (cache[j].allistrue) ? datum_l->size : cache[j].size
841                                                         );
842                 }
843                 else
844                         size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l));
845
846                 if (ISALLTRUE(datum_r) || cache[j].allistrue)
847                 {
848                         if (ISALLTRUE(datum_r) && cache[j].allistrue)
849                                 size_beta = 0;
850                         else
851                                 size_beta = SIGLENBIT - (
852                                                                         (cache[j].allistrue) ? datum_r->size : cache[j].size
853                                                         );
854                 }
855                 else
856                         size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r));
857
858                 if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
859                 {
860                         if (ISALLTRUE(datum_l) || cache[j].allistrue)
861                         {
862                                 if (!ISALLTRUE(datum_l))
863                                         MemSet((void *) GETSIGN(datum_l), 0xff, sizeof(BITVEC));
864                                 datum_l->size = SIGLENBIT;
865                         }
866                         else
867                         {
868                                 ptr = cache[j].sign;
869                                 LOOPBYTE
870                                         union_l[i] |= ptr[i];
871                                 datum_l->size = sizebitvec(union_l);
872                         }
873                         *left++ = j;
874                         v->spl_nleft++;
875                 }
876                 else
877                 {
878                         if (ISALLTRUE(datum_r) || cache[j].allistrue)
879                         {
880                                 if (!ISALLTRUE(datum_r))
881                                         MemSet((void *) GETSIGN(datum_r), 0xff, sizeof(BITVEC));
882                                 datum_r->size = SIGLENBIT;
883                         }
884                         else
885                         {
886                                 ptr = cache[j].sign;
887                                 LOOPBYTE
888                                         union_r[i] |= ptr[i];
889                                 datum_r->size = sizebitvec(union_r);
890                         }
891                         *right++ = j;
892                         v->spl_nright++;
893                 }
894         }
895         *right = *left = FirstOffsetNumber;
896         v->spl_ldatum = PointerGetDatum(datum_l);
897         v->spl_rdatum = PointerGetDatum(datum_r);
898
899         Assert( datum_l->size = sizebitvec(GETSIGN(datum_l)) );
900         Assert( datum_r->size = sizebitvec(GETSIGN(datum_r)) );
901
902         PG_RETURN_POINTER(v);
903 }
904
905 static double
906 getIdfMaxLimit(SmlSign *key)
907 {
908         switch( getTFMethod() )
909         {
910                 case TF_CONST:
911                         return 1.0;
912                         break;
913                 case TF_N:
914                         return (double)(key->maxrepeat);
915                         break;
916                 case TF_LOG:
917                         return 1.0 + log( (double)(key->maxrepeat) );
918                         break;
919                 default:
920                         elog(ERROR,"Unknown TF method: %d", getTFMethod());
921         }
922
923         return 0.0;
924 }
925
926 /*
927  * Consistent function
928  */
929 PG_FUNCTION_INFO_V1(gsmlsign_consistent);
930 Datum gsmlsign_consistent(PG_FUNCTION_ARGS);
931 Datum
932 gsmlsign_consistent(PG_FUNCTION_ARGS)
933 {
934         GISTENTRY               *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
935         StrategyNumber  strategy = (StrategyNumber) PG_GETARG_UINT16(2);
936         bool                    *recheck = (bool *) PG_GETARG_POINTER(4);
937         ArrayType               *a;
938         SmlSign                 *key = (SmlSign*)DatumGetPointer(entry->key);
939         int                             res = false;
940         SmlSign                 *query;
941         SimpleArray             *s;
942         int32                   i;
943
944         fcinfo->flinfo->fn_extra = SearchArrayCache(
945                                                                         fcinfo->flinfo->fn_extra,
946                                                                         fcinfo->flinfo->fn_mcxt,
947                                                                         PG_GETARG_DATUM(1), &a, &s, &query);
948
949         *recheck = true;
950
951         if ( ARRISVOID(a) )
952                 PG_RETURN_BOOL(res);
953
954         if ( strategy == SmlarOverlapStrategy )
955         {
956                 if (ISALLTRUE(key))
957                 {
958                         res = true;
959                 }
960                 else if (ISARRKEY(key))
961                 {
962                         uint32  *kptr = GETARR(key),
963                                         *qptr = GETARR(query);
964
965                         while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
966                         {
967                                 if ( *kptr < *qptr )
968                                         kptr++;
969                                 else if ( *kptr > *qptr )
970                                         qptr++;
971                                 else
972                                 {
973                                         res = true;
974                                         break;
975                                 }
976                         }
977                         *recheck = false;
978                 }
979                 else
980                 {
981                         BITVECP sign = GETSIGN(key);
982
983                         fillHashVal(fcinfo->flinfo->fn_extra, s);
984
985                         for(i=0; i<s->nelems; i++)
986                         {
987                                 if ( GETBIT(sign, HASHVAL(s->hash[i])) )
988                                 {
989                                         res = true;
990                                         break;
991                                 }
992                         }
993                 }
994         }
995         /*
996          *  SmlarSimilarityStrategy
997          */
998         else if (ISALLTRUE(key))
999         {
1000                 if ( GIST_LEAF(entry) )
1001                 {
1002                         /*
1003                          * With TF/IDF similarity we cannot say anything useful
1004                          */
1005                         if ( query->size < SIGLENBIT && getSmlType() != ST_TFIDF )
1006                         {
1007                                 double power = ((double)(query->size)) * ((double)(SIGLENBIT));
1008
1009                                 if ( ((double)(query->size)) / sqrt(power) >= GetSmlarLimit() ) 
1010                                         res = true;
1011                         }
1012                         else
1013                         {
1014                                 res = true;     
1015                         }
1016                 }
1017                 else
1018                         res = true;
1019         }
1020         else if (ISARRKEY(key))
1021         {
1022                 uint32  *kptr = GETARR(key),
1023                                 *qptr = GETARR(query);
1024
1025                 Assert( GIST_LEAF(entry) );
1026
1027                 switch(getSmlType())
1028                 {
1029                         case ST_TFIDF:
1030                                 {
1031                                         StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1032                                         double          sumU = 0.0,
1033                                                                 sumQ = 0.0,
1034                                                                 sumK = 0.0;
1035                                         double          maxKTF = getIdfMaxLimit(key);
1036                                         HashedElem  *h;
1037
1038                                         Assert( s->df );
1039                                         fillHashVal(fcinfo->flinfo->fn_extra, s);
1040                                         if ( stat->info != s->info )
1041                                                 elog(ERROR,"Statistic and actual argument have different type");
1042
1043                                         for(i=0;i<s->nelems;i++)
1044                                         {
1045                                                 sumQ += s->df[i] * s->df[i];
1046
1047                                                 h = getHashedElemIdf(stat, s->hash[i], NULL);
1048                                                 if ( h && hasHashedElem(key, s->hash[i]) )
1049                                                 {
1050                                                         sumK += h->idfMin * h->idfMin;
1051                                                         sumU += h->idfMax * maxKTF * s->df[i];
1052                                                 } 
1053                                         }
1054
1055                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1056                                         {
1057                                                 /* 
1058                                                  * More precisely calculate sumK
1059                                                  */
1060                                                 h = NULL;
1061                                                 sumK = 0.0;
1062
1063                                                 for(i=0;i<key->size;i++)
1064                                                 {
1065                                                         h = getHashedElemIdf(stat, GETARR(key)[i], h);                                  
1066                                                         if (h)
1067                                                                 sumK += h->idfMin * h->idfMin;
1068                                                 }
1069
1070                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1071                                                         res = true;
1072                                         }
1073                                 }
1074                                 break;
1075                         case ST_COSINE:
1076                                 {
1077                                         double                  power;
1078                                         power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1079
1080                                         if (  ((double)Min(key->size, s->nelems)) / power >= GetSmlarLimit() )
1081                                         {
1082                                                 int  cnt = 0;
1083
1084                                                 while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
1085                                                 {
1086                                                         if ( *kptr < *qptr )
1087                                                                 kptr++;
1088                                                         else if ( *kptr > *qptr )
1089                                                                 qptr++;
1090                                                         else
1091                                                         {
1092                                                                 cnt++;
1093                                                                 kptr++;
1094                                                                 qptr++;
1095                                                         }
1096                                                 }
1097
1098                                                 if ( ((double)cnt) / power >= GetSmlarLimit() )
1099                                                         res = true;
1100                                         }
1101                                 }
1102                                 break;
1103                         default:
1104                                 elog(ERROR,"GiST doesn't support current formula type of similarity");
1105                 }
1106         }
1107         else
1108         {       /* signature */
1109                 BITVECP sign = GETSIGN(key);
1110                 int32   count = 0;
1111
1112                 fillHashVal(fcinfo->flinfo->fn_extra, s);
1113
1114                 if ( GIST_LEAF(entry) )
1115                 {
1116                         switch(getSmlType())
1117                         {
1118                                 case ST_TFIDF:
1119                                         {
1120                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1121                                                 double          sumU = 0.0,
1122                                                                         sumQ = 0.0,
1123                                                                         sumK = 0.0;
1124                                                 double          maxKTF = getIdfMaxLimit(key);
1125
1126                                                 Assert( s->df );
1127                                                 if ( stat->info != s->info )
1128                                                         elog(ERROR,"Statistic and actual argument have different type");
1129
1130                                                 for(i=0;i<s->nelems;i++)
1131                                                 {
1132                                                         int32           hbit = HASHVAL(s->hash[i]);
1133
1134                                                         sumQ += s->df[i] * s->df[i];
1135                                                         if ( GETBIT(sign, hbit) )
1136                                                         {
1137                                                                 sumK += stat->selems[ hbit ].idfMin * stat->selems[ hbit ].idfMin;
1138                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1139                                                         }
1140                                                 }
1141
1142                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1143                                                 {
1144                                                         /* 
1145                                                          * More precisely calculate sumK
1146                                                          */
1147                                                         sumK = 0.0;
1148
1149                                                         for(i=0;i<SIGLENBIT;i++)
1150                                                                 if ( GETBIT(sign,i) )
1151                                                                         sumK += stat->selems[ i ].idfMin * stat->selems[ i ].idfMin;
1152
1153                                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1154                                                                 res = true;
1155                                                 }
1156                                         }
1157                                         break;
1158                                 case ST_COSINE:
1159                                         {
1160                                                 double  power;
1161
1162                                                 power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1163
1164                                                 for(i=0; i<s->nelems; i++)
1165                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1166
1167                                                 if ( ((double)count) / power >= GetSmlarLimit() )
1168                                                         res = true;
1169                                         }
1170                                         break;
1171                                 default:
1172                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1173                         }
1174                 }
1175                 else /* non-leaf page */
1176                 {
1177                         switch(getSmlType())
1178                         {
1179                                 case ST_TFIDF:
1180                                         {
1181                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1182                                                 double          sumU = 0.0,
1183                                                                         sumQ = 0.0,
1184                                                                         minK = -1.0;
1185                                                 double          maxKTF = getIdfMaxLimit(key);
1186
1187                                                 Assert( s->df );
1188                                                 if ( stat->info != s->info )
1189                                                         elog(ERROR,"Statistic and actual argument have different type");
1190
1191                                                 for(i=0;i<s->nelems;i++)
1192                                                 {
1193                                                         int32           hbit = HASHVAL(s->hash[i]);
1194
1195                                                         sumQ += s->df[i] * s->df[i];
1196                                                         if ( GETBIT(sign, hbit) )
1197                                                         {
1198                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1199                                                                 if ( minK > stat->selems[ hbit ].idfMin  || minK < 0.0 )
1200                                                                         minK = stat->selems[ hbit ].idfMin;
1201                                                         }
1202                                                 }
1203
1204                                                 if ( sumQ > 0.0 && minK > 0.0 && sumU / sqrt( sumQ * minK ) >= GetSmlarLimit() )
1205                                                         res = true;
1206                                         }
1207                                         break;
1208                                 case ST_COSINE:
1209                                         {
1210                                                 for(i=0; i<s->nelems; i++)
1211                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1212
1213                                                 if ( s->nelems == count  || sqrt(((double)count) / ((double)(s->nelems))) >= GetSmlarLimit() )
1214                                                         res = true;
1215                                         }
1216                                         break;
1217                                 default:
1218                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1219                         }
1220                 }
1221         }
1222
1223 #if 0
1224         {
1225                 static int nnres = 0;
1226                 static int nres = 0;
1227                 if ( GIST_LEAF(entry) ) {
1228                         if ( res )
1229                                 nres++;
1230                         else
1231                                 nnres++;
1232                         elog(NOTICE,"%s nn:%d n:%d", (ISARRKEY(key)) ? "ARR" : ( (ISALLTRUE(key)) ? "TRUE" : "SIGN" ), nnres, nres  );
1233                 }
1234         }
1235 #endif
1236
1237         PG_RETURN_BOOL(res);
1238 }