b1d907393c1f79ca1e280928d2600b48e1887473
[smlar.git] / smlar_gist.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/gist.h"
5 #include "access/skey.h"
6 #include "access/tuptoaster.h"
7 #include "utils/memutils.h"
8
9 typedef struct SmlSign {
10         int32   vl_len_; /* varlena header (do not touch directly!) */
11         int32   flag:8,
12                         size:24;
13         int32   maxrepeat;
14         char    data[1];
15 } SmlSign;
16
#define SMLSIGNHDRSZ	(offsetof(SmlSign, data))

/*
 * Signature geometry.  A signature is SIGLEN bytes long.  HASHVAL maps hashes
 * onto bits 0 .. SIGLENBIT - 1; the top bit (index SIGLENBIT) is reserved and
 * is set unconditionally by makesign.
 */
#define BITBYTE		8
#define SIGLENINT	61
#define SIGLEN		( sizeof(int) * SIGLENINT )
#define SIGLENBIT	(SIGLEN * BITBYTE - 1)	/* see makesign */
typedef char BITVEC[SIGLEN];
typedef char *BITVECP;

/* Iterate "i" (declared by the caller) over every byte of a signature. */
#define LOOPBYTE \
		for (i = 0; i < SIGLEN; i++)

/*
 * Bit helpers.  Every argument and every expansion is parenthesized so the
 * macros are safe with expression arguments and inside larger expressions.
 */
#define GETBYTE(x,i)	( *( (BITVECP)(x) + (int)( (i) / BITBYTE ) ) )
#define GETBITBYTE(x,i)	( ( ((unsigned char)(x)) >> (i) ) & 0x01 )
#define CLRBIT(x,i)		( GETBYTE(x,i) &= ~( 0x01 << ( (i) % BITBYTE ) ) )
#define SETBIT(x,i)		( GETBYTE(x,i) |=  ( 0x01 << ( (i) % BITBYTE ) ) )
#define GETBIT(x,i)		( (GETBYTE(x,i) >> ( (i) % BITBYTE )) & 0x01 )

/* Signature bit for a hash value; never the reserved bit SIGLENBIT. */
#define HASHVAL(val)	(((unsigned int)(val)) % SIGLENBIT)
#define HASH(sign, val)	SETBIT((sign), HASHVAL(val))

/* SmlSign.flag bits */
#define ARRKEY			0x01
#define SIGNKEY			0x02
#define ALLISTRUE		0x04

#define ISARRKEY(x)		( ((SmlSign *) (x))->flag & ARRKEY )
#define ISSIGNKEY(x)	( ((SmlSign *) (x))->flag & SIGNKEY )
#define ISALLTRUE(x)	( ((SmlSign *) (x))->flag & ALLISTRUE )

/* Total key size: header plus hash array, full signature, or nothing. */
#define CALCGTSIZE(flag, len)	( SMLSIGNHDRSZ + ( ( (flag) & ARRKEY ) ? ((len) * sizeof(uint32)) : (((flag) & ALLISTRUE) ? 0 : SIGLEN) ) )
#define GETSIGN(x)				( (BITVECP) ( (char *) (x) + SMLSIGNHDRSZ ) )
#define GETARR(x)				( (uint32 *) ( (char *) (x) + SMLSIGNHDRSZ ) )

#define GETENTRY(vec,pos) ((SmlSign *) DatumGetPointer((vec)->vector[(pos)].key))
50
51 /*
52  * Fake IO
53  */
54 PG_FUNCTION_INFO_V1(gsmlsign_in);
55 Datum   gsmlsign_in(PG_FUNCTION_ARGS);
56 Datum
57 gsmlsign_in(PG_FUNCTION_ARGS)
58 {
59         elog(ERROR, "not implemented");
60         PG_RETURN_DATUM(0);
61 }
62
63 PG_FUNCTION_INFO_V1(gsmlsign_out);
64 Datum   gsmlsign_out(PG_FUNCTION_ARGS);
65 Datum
66 gsmlsign_out(PG_FUNCTION_ARGS)
67 {
68         elog(ERROR, "not implemented");
69         PG_RETURN_DATUM(0);
70 }
71
72 /*
73  * Compress/decompress
74  */
75
76 /* Number of one-bits in an unsigned byte */
77 static const uint8 number_of_ones[256] = {
78         0, 1, 1, 2, 1, 2, 2, 3, 1, 2, 2, 3, 2, 3, 3, 4,
79         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
80         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
81         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
82         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
83         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
84         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
85         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
86         1, 2, 2, 3, 2, 3, 3, 4, 2, 3, 3, 4, 3, 4, 4, 5,
87         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
88         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
89         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
90         2, 3, 3, 4, 3, 4, 4, 5, 3, 4, 4, 5, 4, 5, 5, 6,
91         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
92         3, 4, 4, 5, 4, 5, 5, 6, 4, 5, 5, 6, 5, 6, 6, 7,
93         4, 5, 5, 6, 5, 6, 6, 7, 5, 6, 6, 7, 6, 7, 7, 8
94 };
95
96 static int
97 compareint(const void *va, const void *vb)
98 {
99         uint32  a = *((uint32 *) va);
100         uint32  b = *((uint32 *) vb);
101
102         if (a == b)
103                 return 0;
104         return (a > b) ? 1 : -1;
105 }
106
107 /*
108  * Removes duplicates from an array of int32. 'l' is
109  * size of the input array. Returns the new size of the array.
110  */
111 static int
112 uniqueint(uint32 *a, int32 l, int32 *max)
113 {
114         uint32          *ptr,
115                                 *res;
116         int32           cnt = 0;
117
118         *max = 1;
119
120         if (l <= 1)
121                 return l;
122
123         ptr = res = a;
124
125         qsort((void *) a, l, sizeof(uint32), compareint);
126
127         while (ptr - a < l)
128                 if (*ptr != *res)
129                 {
130                         cnt = 1;
131                         *(++res) = *ptr++;
132                 }
133                 else
134                 {
135                         cnt++;
136                         if ( cnt > *max )
137                                 *max = cnt;
138                         ptr++;
139                 }
140
141         if ( cnt > *max )
142                 *max = cnt;
143
144         return res + 1 - a;
145 }
146
147 SmlSign*
148 Array2HashedArray(ProcTypeInfo info, ArrayType *a)
149 {
150         SimpleArray *s = Array2SimpleArray(info, a);
151         SmlSign         *sign;
152         int32           len, i;
153         uint32          *ptr;
154
155         len = CALCGTSIZE( ARRKEY, s->nelems );
156
157         getFmgrInfoHash(s->info);
158         if (s->info->tupDesc)
159                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
160
161         sign = palloc( len );
162         sign->flag = ARRKEY;
163         sign->size = s->nelems;
164
165         ptr = GETARR(sign);
166         for(i=0;i<s->nelems;i++)
167                 ptr[i] = DatumGetUInt32(FunctionCall1Coll(&s->info->hashFunc,
168                                                                                                   DEFAULT_COLLATION_OID,
169                                                                                                   s->elems[i]));
170
171         /*
172          * there is a collision of hash-function; len is always equal or less than
173          * s->nelems
174          */
175         sign->size = uniqueint( GETARR(sign), sign->size, &sign->maxrepeat );
176         len = CALCGTSIZE( ARRKEY, sign->size );
177         SET_VARSIZE(sign, len);
178
179         return sign;
180 }
181
182 static int
183 HashedElemCmp(const void *va, const void *vb)
184 {
185         uint32  a = ((HashedElem *) va)->hash;
186         uint32  b = ((HashedElem *) vb)->hash;
187
188         if (a == b)
189         {
190                 double  ma = ((HashedElem *) va)->idfMin;
191                 double  mb = ((HashedElem *) va)->idfMin;
192
193                 if (ma == mb)
194                         return 0;
195
196                 return ( ma > mb ) ? 1 : -1;
197         }
198
199         return (a > b) ? 1 : -1;
200 }
201
202 static int
203 uniqueHashedElem(HashedElem *a, int32 l)
204 {
205         HashedElem      *ptr,
206                                 *res;
207
208         if (l <= 1)
209                 return l;
210
211         ptr = res = a;
212
213         qsort(a, l, sizeof(HashedElem), HashedElemCmp);
214
215         while (ptr - a < l)
216                 if (ptr->hash != res->hash)
217                         *(++res) = *ptr++;
218                 else
219                 {
220                         res->idfMax = ptr->idfMax;
221                         ptr++;
222                 }
223
224         return res + 1 - a;
225 }
226
227 static StatCache*
228 getHashedCache(void *cache)
229 {
230         StatCache *stat = getStat(cache, SIGLENBIT);
231
232         if ( stat->nhelems < 0 )
233         {
234                 int i;
235                 /*
236                  * Init
237                  */
238
239                 if (stat->info->tupDesc)
240                         elog(ERROR, "GiST  doesn't support composite (weighted) type");
241                 getFmgrInfoHash(stat->info);
242                 for(i=0;i<stat->nelems;i++)
243                 {
244                         uint32  hash;
245
246                         hash = DatumGetUInt32(FunctionCall1Coll(&stat->info->hashFunc,
247                                                                                                         DEFAULT_COLLATION_OID,
248                                                                                                         stat->elems[i].datum));
249                         int             index = HASHVAL(hash);
250
251                         stat->helems[i].hash = hash;
252                         stat->helems[i].idfMin = stat->helems[i].idfMax = stat->elems[i].idf;   
253                         if ( stat->selems[index].idfMin == 0.0 )
254                                 stat->selems[index].idfMin = stat->selems[index].idfMax = stat->elems[i].idf;
255                         else if ( stat->selems[index].idfMin > stat->elems[i].idf )
256                                 stat->selems[index].idfMin = stat->elems[i].idf;
257                         else if ( stat->selems[index].idfMax < stat->elems[i].idf )
258                                 stat->selems[index].idfMax = stat->elems[i].idf;
259                 }
260
261                 stat->nhelems = uniqueHashedElem( stat->helems, stat->nelems);
262         }
263
264         return stat;
265 }
266
267 static HashedElem*
268 getHashedElemIdf(StatCache *stat, uint32 hash, HashedElem *StopLow)
269 {
270         HashedElem      *StopMiddle,
271                                 *StopHigh = stat->helems + stat->nhelems;
272
273         if ( !StopLow )
274                 StopLow = stat->helems;
275
276         while (StopLow < StopHigh) {
277                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
278
279                 if ( StopMiddle->hash == hash )
280                         return StopMiddle;
281                 else if ( StopMiddle->hash < hash )
282                         StopLow = StopMiddle + 1;
283                 else
284                         StopHigh = StopMiddle;
285         }
286
287         return NULL;
288 }
289
290 static void
291 fillHashVal(void *cache, SimpleArray *a)
292 {
293         int i;
294
295         if (a->hash)
296                 return;
297
298         allocateHash(cache, a);
299
300         if (a->info->tupDesc)
301                 elog(ERROR, "GiST  doesn't support composite (weighted) type");
302         getFmgrInfoHash(a->info);
303
304         for(i=0;i<a->nelems;i++)
305                 a->hash[i] = DatumGetUInt32(FunctionCall1Coll(&a->info->hashFunc,
306                                                                                                           DEFAULT_COLLATION_OID,
307                                                                                                           a->elems[i]));
308 }
309
310
311 static bool
312 hasHashedElem(SmlSign  *a, uint32 h)
313 {
314         uint32  *StopLow = GETARR(a),
315                         *StopHigh = GETARR(a) + a->size,
316                         *StopMiddle;
317
318         while (StopLow < StopHigh) {
319                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
320
321                 if ( *StopMiddle == h )
322                         return true;
323                 else if ( *StopMiddle < h )
324                         StopLow = StopMiddle + 1;
325                 else
326                         StopHigh = StopMiddle;
327         }
328
329         return false;
330 }
331
332 static void
333 makesign(BITVECP sign, SmlSign  *a)
334 {
335         int32   i;
336         uint32  *ptr = GETARR(a);
337
338         MemSet((void *) sign, 0, sizeof(BITVEC));
339         SETBIT(sign, SIGLENBIT);   /* set last unused bit */
340
341         for (i = 0; i < a->size; i++)
342                 HASH(sign, ptr[i]);
343 }
344
345 static int32
346 sizebitvec(BITVECP sign)
347 {
348         int32   size = 0,
349                         i;
350
351         LOOPBYTE
352                 size += number_of_ones[(unsigned char) sign[i]];
353
354         return size;
355 }
356
357 PG_FUNCTION_INFO_V1(gsmlsign_compress);
358 Datum gsmlsign_compress(PG_FUNCTION_ARGS);
359 Datum
360 gsmlsign_compress(PG_FUNCTION_ARGS)
361 {
362         GISTENTRY  *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
363         GISTENTRY  *retval = entry;
364
365         if (entry->leafkey) /* new key */
366         {
367                 SmlSign         *sign;
368                 ArrayType       *a = DatumGetArrayTypeP(entry->key);
369
370                 sign = Array2HashedArray(NULL, a);
371
372                 if ( VARSIZE(sign) > TOAST_INDEX_TARGET )
373                 {       /* make signature due to its big size */
374                         SmlSign *tmpsign;
375                         int             len;
376
377                         len = CALCGTSIZE( SIGNKEY, sign->size );
378                         tmpsign = palloc( len );
379                         tmpsign->flag = SIGNKEY;
380                         SET_VARSIZE(tmpsign, len);
381
382                         makesign(GETSIGN(tmpsign), sign);
383                         tmpsign->size = sizebitvec(GETSIGN(tmpsign));
384                         tmpsign->maxrepeat = sign->maxrepeat;
385                         sign = tmpsign;
386                 }
387
388                 retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
389                 gistentryinit(*retval, PointerGetDatum(sign),
390                                                 entry->rel, entry->page,
391                                                 entry->offset, false);
392         }
393         else if ( ISSIGNKEY(DatumGetPointer(entry->key)) &&
394                                 !ISALLTRUE(DatumGetPointer(entry->key)) )
395         {
396                 SmlSign *sign = (SmlSign*)DatumGetPointer(entry->key);
397
398                 Assert( sign->size == sizebitvec(GETSIGN(sign)) );
399
400                 if ( sign->size == SIGLENBIT )
401                 {
402                         int32   len = CALCGTSIZE(SIGNKEY | ALLISTRUE, 0);
403                         int32   maxrepeat = sign->maxrepeat;
404
405                         sign = (SmlSign *) palloc(len);
406                         SET_VARSIZE(sign, len);
407                         sign->flag = SIGNKEY | ALLISTRUE;
408                         sign->size = SIGLENBIT;
409                         sign->maxrepeat = maxrepeat;
410
411                         retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
412
413                         gistentryinit(*retval, PointerGetDatum(sign),
414                                                         entry->rel, entry->page,
415                                                         entry->offset, false);
416                 }
417         }
418
419         PG_RETURN_POINTER(retval);
420 }
421
422 PG_FUNCTION_INFO_V1(gsmlsign_decompress);
423 Datum gsmlsign_decompress(PG_FUNCTION_ARGS);
424 Datum
425 gsmlsign_decompress(PG_FUNCTION_ARGS)
426 {
427         GISTENTRY       *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
428         SmlSign         *key =  (SmlSign*)DatumGetPointer(PG_DETOAST_DATUM(entry->key));
429
430         if (key != (SmlSign *) DatumGetPointer(entry->key))
431         {
432                 GISTENTRY  *retval = (GISTENTRY *) palloc(sizeof(GISTENTRY));
433
434                 gistentryinit(*retval, PointerGetDatum(key),
435                                                 entry->rel, entry->page,
436                                                 entry->offset, false);
437
438                 PG_RETURN_POINTER(retval);
439         }
440
441         PG_RETURN_POINTER(entry);
442 }
443
444 /*
445  * Union method
446  */
447 static bool
448 unionkey(BITVECP sbase, SmlSign *add)
449 {
450         int32   i;
451
452         if (ISSIGNKEY(add))
453         {
454                 BITVECP sadd = GETSIGN(add);
455
456                 if (ISALLTRUE(add))
457                         return true;
458
459                 LOOPBYTE
460                         sbase[i] |= sadd[i];
461         }
462         else
463         {
464                 uint32  *ptr = GETARR(add);
465
466                 for (i = 0; i < add->size; i++)
467                         HASH(sbase, ptr[i]);
468         }
469
470         return false;
471 }
472
473 PG_FUNCTION_INFO_V1(gsmlsign_union);
474 Datum gsmlsign_union(PG_FUNCTION_ARGS);
475 Datum
476 gsmlsign_union(PG_FUNCTION_ARGS)
477 {
478         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
479         int                             *size = (int *) PG_GETARG_POINTER(1);
480         BITVEC                  base;
481         int32           i,
482                                 len,
483                                 maxrepeat = 1;
484         int32           flag = 0;
485         SmlSign    *result;
486
487         MemSet((void *) base, 0, sizeof(BITVEC));
488         for (i = 0; i < entryvec->n; i++)
489         {
490                 if (GETENTRY(entryvec, i)->maxrepeat > maxrepeat)
491                         maxrepeat = GETENTRY(entryvec, i)->maxrepeat;
492                 if (unionkey(base, GETENTRY(entryvec, i)))
493                 {
494                         flag = ALLISTRUE;
495                         break;
496                 }
497         }
498
499         flag |= SIGNKEY;
500         len = CALCGTSIZE(flag, 0);
501         result = (SmlSign *) palloc(len);
502         *size = len;
503         SET_VARSIZE(result, len);
504         result->flag = flag;
505         result->maxrepeat = maxrepeat;
506
507         if (!ISALLTRUE(result))
508         {
509                 memcpy((void *) GETSIGN(result), (void *) base, sizeof(BITVEC));
510                 result->size = sizebitvec(GETSIGN(result));
511         }
512         else
513                 result->size = SIGLENBIT;
514
515         PG_RETURN_POINTER(result);
516 }
517
518 /*
519  * Same method
520  */
521
522 PG_FUNCTION_INFO_V1(gsmlsign_same);
523 Datum gsmlsign_same(PG_FUNCTION_ARGS);
524 Datum
525 gsmlsign_same(PG_FUNCTION_ARGS)
526 {
527         SmlSign *a = (SmlSign*)PG_GETARG_POINTER(0);
528         SmlSign *b = (SmlSign*)PG_GETARG_POINTER(1);
529         bool    *result = (bool *) PG_GETARG_POINTER(2);
530
531         if (a->size != b->size)
532         {
533                 *result = false;
534         }
535         else if (ISSIGNKEY(a))
536         {       /* then b also ISSIGNKEY */
537                 if ( ISALLTRUE(a) )
538                 {
539                         /* in this case b is all true too - other cases is catched
540                            upper */
541                         *result = true;
542                 }
543                 else
544                 {
545                         int32   i;
546                         BITVECP sa = GETSIGN(a),
547                                         sb = GETSIGN(b);
548
549                         *result = true;
550
551                         if ( !ISALLTRUE(a) )
552                         {
553                                 LOOPBYTE
554                                 {
555                                         if (sa[i] != sb[i])
556                                         {
557                                                 *result = false;
558                                                 break;
559                                         }
560                                 }
561                         }
562                 }
563         }
564         else
565         {
566                 uint32  *ptra = GETARR(a),
567                                 *ptrb = GETARR(b);
568                 int32   i;
569
570                 *result = true;
571                 for (i = 0; i < a->size; i++)
572                 {
573                         if ( ptra[i] != ptrb[i])
574                         {
575                                 *result = false;
576                                 break;
577                         }
578                 }
579         }
580
581         PG_RETURN_POINTER(result);
582 }
583
584 /*
585  * Penalty method
586  */
587 static int
588 hemdistsign(BITVECP a, BITVECP b)
589 {
590         int     i,
591                 diff,
592                 dist = 0;
593
594         LOOPBYTE
595         {
596                 diff = (unsigned char) (a[i] ^ b[i]);
597                 dist += number_of_ones[diff];
598         }
599         return dist;
600 }
601
602 static int
603 hemdist(SmlSign *a, SmlSign *b)
604 {
605         if (ISALLTRUE(a))
606         {
607                 if (ISALLTRUE(b))
608                         return 0;
609                 else
610                         return SIGLENBIT - b->size;
611         }
612         else if (ISALLTRUE(b))
613                 return SIGLENBIT - b->size;
614
615         return hemdistsign(GETSIGN(a), GETSIGN(b));
616 }
617
618 PG_FUNCTION_INFO_V1(gsmlsign_penalty);
619 Datum gsmlsign_penalty(PG_FUNCTION_ARGS);
620 Datum
621 gsmlsign_penalty(PG_FUNCTION_ARGS)
622 {
623         GISTENTRY       *origentry = (GISTENTRY *) PG_GETARG_POINTER(0); /* always ISSIGNKEY */
624         GISTENTRY       *newentry = (GISTENTRY *) PG_GETARG_POINTER(1);
625         float           *penalty = (float *) PG_GETARG_POINTER(2);
626         SmlSign         *origval = (SmlSign *) DatumGetPointer(origentry->key);
627         SmlSign         *newval = (SmlSign *) DatumGetPointer(newentry->key);
628         BITVECP         orig = GETSIGN(origval);
629
630         *penalty = 0.0;
631
632         if (ISARRKEY(newval))
633         {
634                 BITVEC  sign;
635
636                 makesign(sign, newval);
637
638                 if (ISALLTRUE(origval))
639                         *penalty = ((float) (SIGLENBIT - sizebitvec(sign))) / (float) (SIGLENBIT + 1);
640                 else
641                         *penalty = hemdistsign(sign, orig);
642         }
643         else
644                 *penalty = hemdist(origval, newval);
645
646         PG_RETURN_POINTER(penalty);
647 }
648
649 /*
650  * Picksplit method
651  */
652
653 typedef struct
654 {
655         bool    allistrue;
656         int32   size;
657         BITVEC  sign;
658 } CACHESIGN;
659
660 static void
661 fillcache(CACHESIGN *item, SmlSign *key)
662 {
663         item->allistrue = false;
664         item->size = key->size;
665
666         if (ISARRKEY(key))
667         {
668                 makesign(item->sign, key);
669                 item->size = sizebitvec( item->sign );
670         }
671         else if (ISALLTRUE(key))
672                 item->allistrue = true;
673         else
674                 memcpy((void *) item->sign, (void *) GETSIGN(key), sizeof(BITVEC));
675 }
676
/*
 * Picksplit balancing term: -(a - b)^3 * c.  When a (left count) exceeds b
 * (right count) the term is negative, which makes the left side harder to
 * choose, and vice versa.
 */
#define WISH_F(a,b,c) \
	((double) (-(double) (((a) - (b)) * ((a) - (b)) * ((a) - (b))) * (c)))
678
679 typedef struct
680 {
681         OffsetNumber pos;
682         int32           cost;
683 } SPLITCOST;
684
685 static int
686 comparecost(const void *va, const void *vb)
687 {
688         SPLITCOST  *a = (SPLITCOST *) va;
689         SPLITCOST  *b = (SPLITCOST *) vb;
690
691         if (a->cost == b->cost)
692                 return 0;
693         else
694                 return (a->cost > b->cost) ? 1 : -1;
695 }
696
697 static int
698 hemdistcache(CACHESIGN *a, CACHESIGN *b)
699 {
700         if (a->allistrue)
701         {
702                 if (b->allistrue)
703                         return 0;
704                 else
705                         return SIGLENBIT - b->size;
706         }
707         else if (b->allistrue)
708                 return SIGLENBIT - a->size;
709
710         return hemdistsign(a->sign, b->sign);
711 }
712
713 PG_FUNCTION_INFO_V1(gsmlsign_picksplit);
714 Datum gsmlsign_picksplit(PG_FUNCTION_ARGS);
715 Datum
716 gsmlsign_picksplit(PG_FUNCTION_ARGS)
717 {
718         GistEntryVector *entryvec = (GistEntryVector *) PG_GETARG_POINTER(0);
719         GIST_SPLITVEC *v = (GIST_SPLITVEC *) PG_GETARG_POINTER(1);
720         OffsetNumber k,
721                                 j;
722         SmlSign         *datum_l,
723                                 *datum_r;
724         BITVECP         union_l,
725                                 union_r;
726         int32           size_alpha,
727                                 size_beta;
728         int32           size_waste,
729                                 waste = -1;
730         int32           nbytes;
731         OffsetNumber seed_1 = 0,
732                                 seed_2 = 0;
733         OffsetNumber *left,
734                                 *right;
735         OffsetNumber maxoff;
736         BITVECP         ptr;
737         int                     i;
738         CACHESIGN  *cache;
739         SPLITCOST  *costvector;
740
741         maxoff = entryvec->n - 2;
742         nbytes = (maxoff + 2) * sizeof(OffsetNumber);
743         v->spl_left = (OffsetNumber *) palloc(nbytes);
744         v->spl_right = (OffsetNumber *) palloc(nbytes);
745
746         cache = (CACHESIGN *) palloc(sizeof(CACHESIGN) * (maxoff + 2));
747         fillcache(&cache[FirstOffsetNumber], GETENTRY(entryvec, FirstOffsetNumber));
748
749         for (k = FirstOffsetNumber; k < maxoff; k = OffsetNumberNext(k))
750         {
751                 for (j = OffsetNumberNext(k); j <= maxoff; j = OffsetNumberNext(j))
752                 {
753                         if (k == FirstOffsetNumber)
754                                 fillcache(&cache[j], GETENTRY(entryvec, j));
755
756                         size_waste = hemdistcache(&(cache[j]), &(cache[k]));
757                         if (size_waste > waste)
758                         {
759                                 waste = size_waste;
760                                 seed_1 = k;
761                                 seed_2 = j;
762                         }
763                 }
764         }
765
766         left = v->spl_left;
767         v->spl_nleft = 0;
768         right = v->spl_right;
769         v->spl_nright = 0;
770
771         if (seed_1 == 0 || seed_2 == 0)
772         {
773                 seed_1 = 1;
774                 seed_2 = 2;
775         }
776
777         /* form initial .. */
778         if (cache[seed_1].allistrue)
779         {
780                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
781                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
782                 datum_l->flag = SIGNKEY | ALLISTRUE;
783                 datum_l->size = SIGLENBIT;
784         }
785         else
786         {
787                 datum_l = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
788                 SET_VARSIZE(datum_l, CALCGTSIZE(SIGNKEY, 0));
789                 datum_l->flag = SIGNKEY;
790                 memcpy((void *) GETSIGN(datum_l), (void *) cache[seed_1].sign, sizeof(BITVEC));
791                 datum_l->size = cache[seed_1].size;
792         }
793         if (cache[seed_2].allistrue)
794         {
795                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
796                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY | ALLISTRUE, 0));
797                 datum_r->flag = SIGNKEY | ALLISTRUE;
798                 datum_r->size = SIGLENBIT;
799         }
800         else
801         {
802                 datum_r = (SmlSign *) palloc(CALCGTSIZE(SIGNKEY, 0));
803                 SET_VARSIZE(datum_r, CALCGTSIZE(SIGNKEY, 0));
804                 datum_r->flag = SIGNKEY;
805                 memcpy((void *) GETSIGN(datum_r), (void *) cache[seed_2].sign, sizeof(BITVEC));
806                 datum_r->size = cache[seed_2].size;
807         }
808
809         union_l = GETSIGN(datum_l);
810         union_r = GETSIGN(datum_r);
811         maxoff = OffsetNumberNext(maxoff);
812         fillcache(&cache[maxoff], GETENTRY(entryvec, maxoff));
813         /* sort before ... */
814         costvector = (SPLITCOST *) palloc(sizeof(SPLITCOST) * maxoff);
815         for (j = FirstOffsetNumber; j <= maxoff; j = OffsetNumberNext(j))
816         {
817                 costvector[j - 1].pos = j;
818                 size_alpha = hemdistcache(&(cache[seed_1]), &(cache[j]));
819                 size_beta = hemdistcache(&(cache[seed_2]), &(cache[j]));
820                 costvector[j - 1].cost = Abs(size_alpha - size_beta);
821         }
822         qsort((void *) costvector, maxoff, sizeof(SPLITCOST), comparecost);
823
824         datum_l->maxrepeat = datum_r->maxrepeat = 1;
825
826         for (k = 0; k < maxoff; k++)
827         {
828                 j = costvector[k].pos;
829                 if (j == seed_1)
830                 {
831                         *left++ = j;
832                         v->spl_nleft++;
833                         continue;
834                 }
835                 else if (j == seed_2)
836                 {
837                         *right++ = j;
838                         v->spl_nright++;
839                         continue;
840                 }
841
842                 if (ISALLTRUE(datum_l) || cache[j].allistrue)
843                 {
844                         if (ISALLTRUE(datum_l) && cache[j].allistrue)
845                                 size_alpha = 0;
846                         else
847                                 size_alpha = SIGLENBIT - (
848                                                                         (cache[j].allistrue) ? datum_l->size : cache[j].size
849                                                         );
850                 }
851                 else
852                         size_alpha = hemdistsign(cache[j].sign, GETSIGN(datum_l));
853
854                 if (ISALLTRUE(datum_r) || cache[j].allistrue)
855                 {
856                         if (ISALLTRUE(datum_r) && cache[j].allistrue)
857                                 size_beta = 0;
858                         else
859                                 size_beta = SIGLENBIT - (
860                                                                         (cache[j].allistrue) ? datum_r->size : cache[j].size
861                                                         );
862                 }
863                 else
864                         size_beta = hemdistsign(cache[j].sign, GETSIGN(datum_r));
865
866                 if (size_alpha < size_beta + WISH_F(v->spl_nleft, v->spl_nright, 0.1))
867                 {
868                         if (ISALLTRUE(datum_l) || cache[j].allistrue)
869                         {
870                                 if (!ISALLTRUE(datum_l))
871                                         MemSet((void *) GETSIGN(datum_l), 0xff, sizeof(BITVEC));
872                                 datum_l->size = SIGLENBIT;
873                         }
874                         else
875                         {
876                                 ptr = cache[j].sign;
877                                 LOOPBYTE
878                                         union_l[i] |= ptr[i];
879                                 datum_l->size = sizebitvec(union_l);
880                         }
881                         *left++ = j;
882                         v->spl_nleft++;
883                 }
884                 else
885                 {
886                         if (ISALLTRUE(datum_r) || cache[j].allistrue)
887                         {
888                                 if (!ISALLTRUE(datum_r))
889                                         MemSet((void *) GETSIGN(datum_r), 0xff, sizeof(BITVEC));
890                                 datum_r->size = SIGLENBIT;
891                         }
892                         else
893                         {
894                                 ptr = cache[j].sign;
895                                 LOOPBYTE
896                                         union_r[i] |= ptr[i];
897                                 datum_r->size = sizebitvec(union_r);
898                         }
899                         *right++ = j;
900                         v->spl_nright++;
901                 }
902         }
903         *right = *left = FirstOffsetNumber;
904         v->spl_ldatum = PointerGetDatum(datum_l);
905         v->spl_rdatum = PointerGetDatum(datum_r);
906
907         Assert( datum_l->size = sizebitvec(GETSIGN(datum_l)) );
908         Assert( datum_r->size = sizebitvec(GETSIGN(datum_r)) );
909
910         PG_RETURN_POINTER(v);
911 }
912
913 static double
914 getIdfMaxLimit(SmlSign *key)
915 {
916         switch( getTFMethod() )
917         {
918                 case TF_CONST:
919                         return 1.0;
920                         break;
921                 case TF_N:
922                         return (double)(key->maxrepeat);
923                         break;
924                 case TF_LOG:
925                         return 1.0 + log( (double)(key->maxrepeat) );
926                         break;
927                 default:
928                         elog(ERROR,"Unknown TF method: %d", getTFMethod());
929         }
930
931         return 0.0;
932 }
933
934 /*
935  * Consistent function
936  */
937 PG_FUNCTION_INFO_V1(gsmlsign_consistent);
938 Datum gsmlsign_consistent(PG_FUNCTION_ARGS);
939 Datum
940 gsmlsign_consistent(PG_FUNCTION_ARGS)
941 {
942         GISTENTRY               *entry = (GISTENTRY *) PG_GETARG_POINTER(0);
943         StrategyNumber  strategy = (StrategyNumber) PG_GETARG_UINT16(2);
944         bool                    *recheck = (bool *) PG_GETARG_POINTER(4);
945         ArrayType               *a;
946         SmlSign                 *key = (SmlSign*)DatumGetPointer(entry->key);
947         int                             res = false;
948         SmlSign                 *query;
949         SimpleArray             *s;
950         int32                   i;
951
952         fcinfo->flinfo->fn_extra = SearchArrayCache(
953                                                                         fcinfo->flinfo->fn_extra,
954                                                                         fcinfo->flinfo->fn_mcxt,
955                                                                         PG_GETARG_DATUM(1), &a, &s, &query);
956
957         *recheck = true;
958
959         if ( ARRISVOID(a) )
960                 PG_RETURN_BOOL(res);
961
962         if ( strategy == SmlarOverlapStrategy )
963         {
964                 if (ISALLTRUE(key))
965                 {
966                         res = true;
967                 }
968                 else if (ISARRKEY(key))
969                 {
970                         uint32  *kptr = GETARR(key),
971                                         *qptr = GETARR(query);
972
973                         while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
974                         {
975                                 if ( *kptr < *qptr )
976                                         kptr++;
977                                 else if ( *kptr > *qptr )
978                                         qptr++;
979                                 else
980                                 {
981                                         res = true;
982                                         break;
983                                 }
984                         }
985                         *recheck = false;
986                 }
987                 else
988                 {
989                         BITVECP sign = GETSIGN(key);
990
991                         fillHashVal(fcinfo->flinfo->fn_extra, s);
992
993                         for(i=0; i<s->nelems; i++)
994                         {
995                                 if ( GETBIT(sign, HASHVAL(s->hash[i])) )
996                                 {
997                                         res = true;
998                                         break;
999                                 }
1000                         }
1001                 }
1002         }
1003         /*
1004          *  SmlarSimilarityStrategy
1005          */
1006         else if (ISALLTRUE(key))
1007         {
1008                 if ( GIST_LEAF(entry) )
1009                 {
1010                         /*
1011                          * With TF/IDF similarity we cannot say anything useful
1012                          */
1013                         if ( query->size < SIGLENBIT && getSmlType() != ST_TFIDF )
1014                         {
1015                                 double power = ((double)(query->size)) * ((double)(SIGLENBIT));
1016
1017                                 if ( ((double)(query->size)) / sqrt(power) >= GetSmlarLimit() ) 
1018                                         res = true;
1019                         }
1020                         else
1021                         {
1022                                 res = true;     
1023                         }
1024                 }
1025                 else
1026                         res = true;
1027         }
1028         else if (ISARRKEY(key))
1029         {
1030                 uint32  *kptr = GETARR(key),
1031                                 *qptr = GETARR(query);
1032
1033                 Assert( GIST_LEAF(entry) );
1034
1035                 switch(getSmlType())
1036                 {
1037                         case ST_TFIDF:
1038                                 {
1039                                         StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1040                                         double          sumU = 0.0,
1041                                                                 sumQ = 0.0,
1042                                                                 sumK = 0.0;
1043                                         double          maxKTF = getIdfMaxLimit(key);
1044                                         HashedElem  *h;
1045
1046                                         Assert( s->df );
1047                                         fillHashVal(fcinfo->flinfo->fn_extra, s);
1048                                         if ( stat->info != s->info )
1049                                                 elog(ERROR,"Statistic and actual argument have different type");
1050
1051                                         for(i=0;i<s->nelems;i++)
1052                                         {
1053                                                 sumQ += s->df[i] * s->df[i];
1054
1055                                                 h = getHashedElemIdf(stat, s->hash[i], NULL);
1056                                                 if ( h && hasHashedElem(key, s->hash[i]) )
1057                                                 {
1058                                                         sumK += h->idfMin * h->idfMin;
1059                                                         sumU += h->idfMax * maxKTF * s->df[i];
1060                                                 } 
1061                                         }
1062
1063                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1064                                         {
1065                                                 /* 
1066                                                  * More precisely calculate sumK
1067                                                  */
1068                                                 h = NULL;
1069                                                 sumK = 0.0;
1070
1071                                                 for(i=0;i<key->size;i++)
1072                                                 {
1073                                                         h = getHashedElemIdf(stat, GETARR(key)[i], h);                                  
1074                                                         if (h)
1075                                                                 sumK += h->idfMin * h->idfMin;
1076                                                 }
1077
1078                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1079                                                         res = true;
1080                                         }
1081                                 }
1082                                 break;
1083                         case ST_COSINE:
1084                                 {
1085                                         double                  power;
1086                                         power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1087
1088                                         if (  ((double)Min(key->size, s->nelems)) / power >= GetSmlarLimit() )
1089                                         {
1090                                                 int  cnt = 0;
1091
1092                                                 while( kptr - GETARR(key) < key->size && qptr - GETARR(query) < query->size )
1093                                                 {
1094                                                         if ( *kptr < *qptr )
1095                                                                 kptr++;
1096                                                         else if ( *kptr > *qptr )
1097                                                                 qptr++;
1098                                                         else
1099                                                         {
1100                                                                 cnt++;
1101                                                                 kptr++;
1102                                                                 qptr++;
1103                                                         }
1104                                                 }
1105
1106                                                 if ( ((double)cnt) / power >= GetSmlarLimit() )
1107                                                         res = true;
1108                                         }
1109                                 }
1110                                 break;
1111                         default:
1112                                 elog(ERROR,"GiST doesn't support current formula type of similarity");
1113                 }
1114         }
1115         else
1116         {       /* signature */
1117                 BITVECP sign = GETSIGN(key);
1118                 int32   count = 0;
1119
1120                 fillHashVal(fcinfo->flinfo->fn_extra, s);
1121
1122                 if ( GIST_LEAF(entry) )
1123                 {
1124                         switch(getSmlType())
1125                         {
1126                                 case ST_TFIDF:
1127                                         {
1128                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1129                                                 double          sumU = 0.0,
1130                                                                         sumQ = 0.0,
1131                                                                         sumK = 0.0;
1132                                                 double          maxKTF = getIdfMaxLimit(key);
1133
1134                                                 Assert( s->df );
1135                                                 if ( stat->info != s->info )
1136                                                         elog(ERROR,"Statistic and actual argument have different type");
1137
1138                                                 for(i=0;i<s->nelems;i++)
1139                                                 {
1140                                                         int32           hbit = HASHVAL(s->hash[i]);
1141
1142                                                         sumQ += s->df[i] * s->df[i];
1143                                                         if ( GETBIT(sign, hbit) )
1144                                                         {
1145                                                                 sumK += stat->selems[ hbit ].idfMin * stat->selems[ hbit ].idfMin;
1146                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1147                                                         }
1148                                                 }
1149
1150                                                 if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1151                                                 {
1152                                                         /* 
1153                                                          * More precisely calculate sumK
1154                                                          */
1155                                                         sumK = 0.0;
1156
1157                                                         for(i=0;i<SIGLENBIT;i++)
1158                                                                 if ( GETBIT(sign,i) )
1159                                                                         sumK += stat->selems[ i ].idfMin * stat->selems[ i ].idfMin;
1160
1161                                                         if ( sumK > 0.0 && sumQ > 0.0 && sumU / sqrt( sumK * sumQ ) >= GetSmlarLimit() )
1162                                                                 res = true;
1163                                                 }
1164                                         }
1165                                         break;
1166                                 case ST_COSINE:
1167                                         {
1168                                                 double  power;
1169
1170                                                 power = sqrt( ((double)(key->size)) * ((double)(s->nelems)) );
1171
1172                                                 for(i=0; i<s->nelems; i++)
1173                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1174
1175                                                 if ( ((double)count) / power >= GetSmlarLimit() )
1176                                                         res = true;
1177                                         }
1178                                         break;
1179                                 default:
1180                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1181                         }
1182                 }
1183                 else /* non-leaf page */
1184                 {
1185                         switch(getSmlType())
1186                         {
1187                                 case ST_TFIDF:
1188                                         {
1189                                                 StatCache       *stat = getHashedCache(fcinfo->flinfo->fn_extra);
1190                                                 double          sumU = 0.0,
1191                                                                         sumQ = 0.0,
1192                                                                         minK = -1.0;
1193                                                 double          maxKTF = getIdfMaxLimit(key);
1194
1195                                                 Assert( s->df );
1196                                                 if ( stat->info != s->info )
1197                                                         elog(ERROR,"Statistic and actual argument have different type");
1198
1199                                                 for(i=0;i<s->nelems;i++)
1200                                                 {
1201                                                         int32           hbit = HASHVAL(s->hash[i]);
1202
1203                                                         sumQ += s->df[i] * s->df[i];
1204                                                         if ( GETBIT(sign, hbit) )
1205                                                         {
1206                                                                 sumU += stat->selems[ hbit ].idfMax * maxKTF * s->df[i];
1207                                                                 if ( minK > stat->selems[ hbit ].idfMin  || minK < 0.0 )
1208                                                                         minK = stat->selems[ hbit ].idfMin;
1209                                                         }
1210                                                 }
1211
1212                                                 if ( sumQ > 0.0 && minK > 0.0 && sumU / sqrt( sumQ * minK ) >= GetSmlarLimit() )
1213                                                         res = true;
1214                                         }
1215                                         break;
1216                                 case ST_COSINE:
1217                                         {
1218                                                 for(i=0; i<s->nelems; i++)
1219                                                         count += GETBIT(sign, HASHVAL(s->hash[i]));
1220
1221                                                 if ( s->nelems == count  || sqrt(((double)count) / ((double)(s->nelems))) >= GetSmlarLimit() )
1222                                                         res = true;
1223                                         }
1224                                         break;
1225                                 default:
1226                                         elog(ERROR,"GiST doesn't support current formula type of similarity");
1227                         }
1228                 }
1229         }
1230
1231 #if 0
1232         {
1233                 static int nnres = 0;
1234                 static int nres = 0;
1235                 if ( GIST_LEAF(entry) ) {
1236                         if ( res )
1237                                 nres++;
1238                         else
1239                                 nnres++;
1240                         elog(NOTICE,"%s nn:%d n:%d", (ISARRKEY(key)) ? "ARR" : ( (ISALLTRUE(key)) ? "TRUE" : "SIGN" ), nnres, nres  );
1241                 }
1242         }
1243 #endif
1244
1245         PG_RETURN_BOOL(res);
1246 }