389dff82268e5e062f7eede9cccc9a2ff1fb06e3
[wildspeed.git] / wildspeed.c
1 #include "postgres.h"
2
3 #include "catalog/pg_type.h"
4 #include "mb/pg_wchar.h"
5 #include "utils/array.h"
6
/*
 * MARK_SIGN terminates the original string inside every permuted key.  It
 * has to be a byte that never occurs inside a text value; the zero byte is
 * used because PostgreSQL text values cannot contain it.  A zero byte cannot
 * be displayed, so permute() substitutes MARK_SIGN_SHOW for it when showing
 * keys.  MARK_SIGN_SHOW itself cannot serve as the mark, because it may
 * legitimately appear inside a string.
 */
#define MARK_SIGN		'\0'
#define MARK_SIGN_SHOW	'$'

/*
 * Position flags stored in partial-match query keys (see setFlagOfText()).
 */
#define WC_BEGIN		(1 << 0)	/* part must be at the beginning of the string */
#define WC_MIDDLE		(1 << 1)	/* part must be in the middle of the string */
#define WC_END			(1 << 2)	/* part must be at the end of the string */

PG_MODULE_MAGIC;
25
26 static text*
27 appendStrToText( text *src, char *str, int32 len, int32 maxlen )
28 {
29         int32   curlen;
30
31         if (src == NULL )
32         {
33                 Assert( maxlen >= 0 );
34                 src = (text*)palloc( VARHDRSZ + sizeof(char*) * maxlen );
35                 SET_VARSIZE(src, 0 + VARHDRSZ);
36         }
37
38         curlen = VARSIZE(src) - VARHDRSZ;
39
40         if (len>0)
41                 memcpy( VARDATA(src) + curlen, str, len );
42
43         SET_VARSIZE(src, curlen + len + VARHDRSZ);
44
45         return src;
46 }
47
48 static text*
49 appendMarkToText( text *src, int32 maxlen )
50 {
51         char sign = MARK_SIGN;
52
53         return appendStrToText( src, &sign, 1, maxlen );
54 }
55
56 static text*
57 setFlagOfText( char flag, int32 maxlen )
58 {
59         char flagstruct[2];
60
61         Assert( maxlen > 0 );
62         /*
63          * Mark text by setting first byte to MARK_SIGN to indicate
64          * that text has flags. It's a safe for non empty string, 
65          * because first character can not be a MARK_SIGN (see
66          * gin_extract_permuted() )
67          */
68
69         flagstruct[0] = MARK_SIGN;
70         flagstruct[1] = flag;
71         
72         return appendStrToText(NULL, flagstruct, 2, maxlen );
73 }
74
75 PG_FUNCTION_INFO_V1(gin_extract_permuted);
76 Datum           gin_extract_permuted(PG_FUNCTION_ARGS);
77 Datum
78 gin_extract_permuted(PG_FUNCTION_ARGS)
79 {
80         text    *src = PG_GETARG_TEXT_P(0);
81         int32   *nentries = (int32 *) PG_GETARG_POINTER(1);
82         Datum   *entries = NULL;
83         int32   srclen = pg_mbstrlen_with_len(VARDATA(src), VARSIZE(src) - VARHDRSZ);
84
85         *nentries = srclen;
86
87         if ( srclen == 0 )
88         {
89                 /*
90                  * Empty string is encoded by alone MARK_SIGN character
91                  */
92                 *nentries = 1;
93                 entries = (Datum*) palloc(sizeof(Datum));
94                 entries[0] = PointerGetDatum( appendMarkToText( NULL, 1 ) );
95         }
96         else
97         {
98                 text    *dst;
99                 int32   i, 
100                                 offset=0; /* offset to current position in src in bytes */ 
101                 int32   nbytes = VARSIZE(src) - VARHDRSZ;
102                 char    *srcptr = VARDATA(src);
103
104                 /*
105                  * Permutation: hello will be permuted to hello$, ello$h, llo$he, lo$hel, o$hell.
106                  * So, number of entries is equial to number of characters (not a bytes)
107                  */
108
109                 entries = (Datum*)palloc(sizeof(char*) * nbytes );
110                 for(i=0; i<srclen;i++) {
111                 
112                         /*
113                          * Copy first part. For llo$he it will be 'llo'
114                          */
115                         dst = appendStrToText( NULL, srcptr + offset, nbytes - offset, nbytes + 1 ); 
116
117                         /*
118                          * Set mark sign ($)
119                          */
120                         dst = appendMarkToText( dst, -1 );
121
122                         /*
123                          * Copy rest of string (in example above 'he')
124                          */
125                         dst = appendStrToText( dst, srcptr, offset, -1 );
126
127                         entries[i] = PointerGetDatum(dst);
128                         offset += pg_mblen( srcptr + offset );
129                 }
130         }
131
132         PG_FREE_IF_COPY(src,0);
133         PG_RETURN_POINTER(entries);
134 }
135
136 PG_FUNCTION_INFO_V1(wildcmp);
137 Datum       wildcmp(PG_FUNCTION_ARGS);
138 Datum
139 wildcmp(PG_FUNCTION_ARGS)
140 {
141         text    *a = PG_GETARG_TEXT_P(0);
142         text    *b = PG_GETARG_TEXT_P(1);
143         bool    partialMatch = PG_GETARG_BOOL(2);
144         int32   cmp;
145         int             lena,
146                         lenb;
147         char    *ptra = VARDATA(a),
148                         *ptrb = VARDATA(b);
149         char    flag = 0;
150
151         lena = VARSIZE(a) - VARHDRSZ;
152         lenb = VARSIZE(b) - VARHDRSZ;
153
154         /*
155          * sets correct pointers and lengths in case of flags
156          * presence
157          */
158         if ( lena > 2 && *ptra == MARK_SIGN )
159         {
160                 flag = *(ptra+1);
161                 ptra+=2;
162                 lena-=2;
163
164                 if ( lenb > 2 && *ptrb == MARK_SIGN )
165                 {
166                         /*
167                          * If they have different flags then they can not be equal, this 
168                          * place works only during check of equality of keys
169                          * to search
170                          */
171                         if ( flag != *(ptrb+1) )
172                                 return 1;
173                         ptrb+=2;
174                         lenb-=2;
175
176                         /* b can not be a product of gin_extract_wildcard for partial match mode */
177                         Assert( partialMatch == false );
178                 }
179         } 
180         else  if ( lenb > 2 && *ptrb == MARK_SIGN )
181         {
182                 /* b can not be a product of gin_extract_wildcard for partial match mode */
183                 Assert( partialMatch == false );
184
185                 ptrb+=2;
186                 lenb-=2;
187         }
188
189         if ( lena == 0 )
190         {
191                 if ( partialMatch )
192                         cmp = 0; /* full scan for partialMatch*/
193                 else
194                         cmp = (lenb>0) ? -1 : 0;
195         }
196         else
197         {
198                 /*
199                  * We couldn't use strcmp because of MARK_SIGN
200                  */
201                 cmp = memcmp(ptra, ptrb, Min(lena, lenb));
202
203                 if ( partialMatch )
204                 {
205                         if ( cmp == 0 )
206                         {
207                                 if ( lena > lenb )
208                                 {
209                                         /*
210                                          * b argument is not beginning with argument a
211                                          */
212                                         cmp = 1;
213                                 }
214                                 else if ( flag > 0 && lenb>lena /* be safe */ )
215                                 { /* there is some flags to check */
216                                         char    actualFlag;
217
218                                         if ( ptrb[ lenb - 1 ] == MARK_SIGN )
219                                                 actualFlag = WC_BEGIN;  
220                                         else if ( ptrb[ lena ] == MARK_SIGN )
221                                                 actualFlag = WC_END;
222                                         else
223                                                 actualFlag = WC_MIDDLE;
224
225                                         if ( (flag & actualFlag) == 0 )
226                                         {
227                                                 /* 
228                                                  * Prefix are matched but this prefix s not placed as needed.
229                                                  * so we should give a smoke signal to GIN that we don't want
230                                                  * this match, but wish to continue scan 
231                                                  */
232                                                 cmp = -1;
233                                         }
234                                 }
235                         } 
236                         else if (cmp < 0)
237                         {
238                                 cmp = 1; /* prevent continue scan */
239                         }
240                 } 
241                 else if ( (cmp == 0) && (lena != lenb) )
242                 {
243                         cmp = (lena < lenb) ? -1 : 1;
244                 }
245         }
246
247         PG_FREE_IF_COPY(a,0);
248         PG_FREE_IF_COPY(b,1);
249         PG_RETURN_INT32( cmp ); 
250 }
251
#ifdef OPTIMIZE_WILDCARD_QUERY

/* Per-key summary used by optimize_wildcard_search() */
typedef struct
{
	Datum		entry;			/* the query key itself */
	int32		len;			/* payload length, excluding a flag header */
	char		flag;			/* WC_* flags, explicit or inferred from the mark */
} OptItem;


/*
 * Drop the shortest query keys.  Short keys match a lot of index entries,
 * so skipping them speeds up the index search.
 *
 * Keys containing both a beginning and an end part (Y$X) are always kept.
 * Keys without WC_MIDDLE are kept when 3 * len > maxlen; all other keys are
 * kept when 2 * len > maxlen.  entries[] is compacted in place and *nentries
 * is updated.
 *
 * NOTE(review): the caller's partialmatch[] array is not compacted to match.
 * That is harmless only while every key of a multi-key query is a
 * partial-match key, as gin_extract_wildcard currently produces.
 */
static void
optimize_wildcard_search( Datum *entries, int32 *nentries )
{
	int32		maxlen = 0;
	OptItem    *items;
	int			i,
				nitems = *nentries;
	char	   *ptr;

	items = (OptItem *) palloc(sizeof(OptItem) * nitems);
	for (i = 0; i < nitems; i++)
	{
		items[i].entry = entries[i];
		items[i].len = VARSIZE(DatumGetPointer(entries[i])) - VARHDRSZ;
		ptr = VARDATA(DatumGetPointer(entries[i]));

		if (items[i].len > 2 && *ptr == MARK_SIGN)
		{
			/* explicit flag header: { MARK_SIGN, flag } */
			items[i].len -= 2;
			items[i].flag = ptr[1];
		}
		else
		{
			char	   *mark = NULL;

			items[i].flag = 0;

			/*
			 * Varlena payloads are not NUL-terminated, so search only within
			 * the payload; strchr() would read past the end of a key that
			 * holds no mark.
			 */
			if (items[i].len > 1)
				mark = (char *) memchr(ptr, MARK_SIGN, items[i].len);

			if (mark != NULL)
			{
				if (mark == ptr + items[i].len - 1)
					items[i].flag = WC_BEGIN;
				else
					items[i].flag = WC_BEGIN | WC_END;
			}
		}

		if (items[i].len > maxlen)
			maxlen = items[i].len;
	}

	*nentries = 0;

	for (i = 0; i < nitems; i++)
	{
		if ((items[i].flag & WC_BEGIN) && (items[i].flag & WC_END))
		{
			/* X$Y: always use */
			entries[*nentries] = items[i].entry;
			(*nentries)++;
		}
		else if ((items[i].flag & WC_MIDDLE) == 0)
		{
			/*
			 * begin-only or end-only keys get a lower length limit than
			 * the other variants
			 */
			if (3 * items[i].len > maxlen)
			{
				entries[*nentries] = items[i].entry;
				(*nentries)++;
			}
		}
		else if (2 * items[i].len > maxlen)
		{
			/* use only the longest keys */
			entries[*nentries] = items[i].entry;
			(*nentries)++;
		}
	}

	pfree(items);

	Assert(*nentries > 0);
}
#endif
338
339 typedef struct 
340 {
341         bool    iswildcard;
342         int32   len;
343         char    *ptr;
344 } WildItem;
345
346 PG_FUNCTION_INFO_V1(gin_extract_wildcard);
347 Datum           gin_extract_wildcard(PG_FUNCTION_ARGS);
348 Datum
349 gin_extract_wildcard(PG_FUNCTION_ARGS)
350 {
351         text                    *q = PG_GETARG_TEXT_P(0);
352         int32                   lenq = VARSIZE(q) - VARHDRSZ;
353         int32                   *nentries = (int32 *) PG_GETARG_POINTER(1);
354 #ifdef NOT_USED
355         StrategyNumber  strategy = PG_GETARG_UINT16(2);
356 #endif
357         bool                    *partialmatch, 
358                                         **ptr_partialmatch = (bool**) PG_GETARG_POINTER(3);
359         Datum                   *entries = NULL;
360         char                    *qptr = VARDATA(q);
361         int                             clen,
362                                         splitqlen = 0,
363                                         i;
364         WildItem                *items;
365         text                    *entry;
366
367         *nentries = 0;
368
369         if ( lenq == 0 )
370         {
371                 partialmatch = *ptr_partialmatch = (bool*)palloc0(sizeof(bool));
372                 *nentries = 1;
373                 entries = (Datum*) palloc(sizeof(Datum));
374                 entries[0] = PointerGetDatum( appendMarkToText( NULL, 1 ) );
375
376                 PG_RETURN_POINTER(entries);
377         }
378
379         partialmatch = *ptr_partialmatch = (bool*)palloc0(sizeof(bool) * lenq);
380         entries = (Datum*) palloc(sizeof(Datum) * lenq);
381         items=(WildItem*) palloc0( sizeof(WildItem) * lenq );
382
383
384         /*
385          * Parse expression to the list of constant parts and
386          * wildcards
387          */
388         while( qptr - VARDATA(q) < lenq )
389         {
390                 clen = pg_mblen(qptr);
391
392                 if ( clen==1 && (*qptr == '_' || *qptr == '%' ) )
393                 {
394                         if ( splitqlen == 0 )
395                         {
396                                 items[ splitqlen ].iswildcard = true;
397                                 splitqlen++;
398                         } 
399                         else if ( items[ splitqlen-1 ].iswildcard == false )
400                         {
401                                 items[ splitqlen-1 ].len = qptr - items[ splitqlen-1 ].ptr;
402                                 items[ splitqlen ].iswildcard = true;
403                                 splitqlen++;
404                         }
405                         /*
406                          * ignore wildcard, because we don't make difference beetween
407                          * %, _ or a combination of its
408                          */
409                 }
410                 else
411                 {
412                         if ( splitqlen == 0 || items[ splitqlen-1 ].iswildcard == true )
413                         {
414                                 items[ splitqlen ].ptr = qptr;
415                                 splitqlen++;
416                         }
417                 }
418                 qptr += clen;
419         }
420
421         Assert( splitqlen >= 1 );
422         if ( items[ splitqlen-1 ].iswildcard == false )
423                 items[ splitqlen-1 ].len = qptr - items[ splitqlen-1 ].ptr;
424
425         if ( items[ 0 ].iswildcard == false )
426         {
427                 /* X... */
428                 if ( splitqlen == 1 )
429                 {
430                         /*   X => X$, exact match */
431                         *nentries = 1;
432                         entry = appendStrToText(NULL, items[ 0 ].ptr, items[ 0 ].len, lenq+1);
433                         entry = appendMarkToText( entry, -1 );
434                         entries[0] = PointerGetDatum( entry );
435                 } 
436                 else if ( items[ splitqlen-1 ].iswildcard == false ) 
437                 {
438                         /*   X * [X1 * [] ] ] Y => Y$X* [ + X1* [] ] */
439
440                         *nentries = 1;
441                         entry = appendStrToText(NULL, items[ splitqlen-1 ].ptr, items[ splitqlen-1 ].len, lenq+1);
442                         entry = appendMarkToText( entry, -1 );
443                         entry = appendStrToText(entry, items[ 0 ].ptr, items[ 0 ].len, -1);
444                         partialmatch[0] = true;
445                         entries[0] = PointerGetDatum( entry );
446
447                         for(i=1; i<splitqlen-1; i++)
448                         {
449                                 if ( items[ i ].iswildcard )
450                                         continue;
451                                 entry = setFlagOfText( WC_MIDDLE, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ ); 
452                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, -1 );
453                                 partialmatch[ *nentries ] = true;
454                                 entries[ *nentries ] =  PointerGetDatum( entry );
455                                 (*nentries)++;
456                         }
457                 }
458                 else
459                 {
460                         /*   X * [ X1 * [] ]  => X*$ [ + X1* [] ] */
461                 
462                         entry = setFlagOfText( WC_BEGIN, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
463                         entry = appendStrToText(entry, items[ 0 ].ptr, items[ 0 ].len, -1);
464                         *nentries = 1;
465                         partialmatch[ 0 ] = true;
466                         entries[0] = PointerGetDatum( entry );
467
468                         for(i=2; i<splitqlen-1; i++)
469                         {
470                                 if ( items[ i ].iswildcard )
471                                         continue;
472                                 entry = setFlagOfText( (i==splitqlen-2) ? (WC_MIDDLE | WC_END) : WC_MIDDLE, 
473                                                                                 lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
474                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, -1);
475                                 partialmatch[ *nentries ] = true;
476                                 entries[ *nentries ] =  PointerGetDatum( entry );
477                                 (*nentries)++;
478                         }
479                 }
480         } 
481         else
482         {
483                 /* *...  */
484
485                 if ( splitqlen == 1 )
486                 {
487                         /* any word => full scan */
488                         *nentries = 1;
489                         entry = appendStrToText(NULL, "", 0, lenq+1);
490                         partialmatch[0] = true;
491                         entries[0] = PointerGetDatum( entry );
492                 }
493                 else if ( items[ splitqlen-1 ].iswildcard == false )
494                 {
495                         /*     * [ X1 * [] ] X  => X$* [ + X1* [] ]  */
496                         *nentries = 1;
497                         entry = appendStrToText(NULL, items[ splitqlen-1 ].ptr, items[ splitqlen-1 ].len, lenq+1);
498                         entry = appendMarkToText( entry, -1 );
499                         partialmatch[0] = true;
500                         entries[0] = PointerGetDatum( entry );
501
502                         for(i=1; i<splitqlen-1; i++)
503                         {
504                                 if ( items[ i ].iswildcard )
505                                         continue;
506                                 entry = setFlagOfText( (i==1) ? (WC_MIDDLE | WC_BEGIN) : WC_MIDDLE, 
507                                                                                 lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
508                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, -1);
509                                 partialmatch[ *nentries ] = true;
510                                 entries[ *nentries ] =  PointerGetDatum( entry );
511                                 (*nentries)++;
512                         }
513                 }
514                 else
515                 {
516                         /* * X [ * X1 [] ] * => X* [ + X1* [] ] */
517                         for(i=1; i<splitqlen-1; i++)
518                         {
519                                 if ( items[ i ].iswildcard )
520                                         continue;
521
522                                 if ( splitqlen > 3 )
523                                 {
524                                         if ( i==1 )
525                                                 entry = setFlagOfText( WC_MIDDLE | WC_BEGIN, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
526                                         else if ( i == splitqlen-2 )
527                                                 entry = setFlagOfText( WC_MIDDLE | WC_END, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ );
528                                         else
529                                                 entry = setFlagOfText( WC_MIDDLE, lenq + 1 /* MARK_SIGN */ + 2 /* flag */ ); 
530                                 }
531                                 else
532                                         entry = NULL;
533                                 entry = appendStrToText(entry, items[ i ].ptr, items[ i ].len, lenq+1);
534                                 partialmatch[ *nentries ] = true;
535                                 entries[ *nentries ] =  PointerGetDatum( entry );
536                                 (*nentries)++;
537                         }
538                 }
539         }
540
541         PG_FREE_IF_COPY(q,0);
542
543 #ifdef OPTIMIZE_WILDCARD_QUERY
544         if ( *nentries > 1 )
545                 optimize_wildcard_search( entries, nentries );
546 #endif
547
548         PG_RETURN_POINTER(entries);
549 }
550
551
552 PG_FUNCTION_INFO_V1(gin_consistent_wildcard);
553 Datum       gin_consistent_wildcard(PG_FUNCTION_ARGS);
554 Datum
555 gin_consistent_wildcard(PG_FUNCTION_ARGS)
556 {
557         bool            *check = (bool *) PG_GETARG_POINTER(0);
558         bool        res = true;
559         int         i;
560         int32           nentries;
561
562         if ( fcinfo->flinfo->fn_extra == NULL )
563         {
564                 bool    *pmatch;
565
566                 /*
567                  * we need to get nentries, we'll get it by regular way
568                  * and store it in function context
569                  */
570
571                 fcinfo->flinfo->fn_extra = MemoryContextAlloc(fcinfo->flinfo->fn_mcxt,
572                                                                                                                 sizeof(int32));
573
574                 DirectFunctionCall4(
575                                         gin_extract_wildcard,
576                                         PG_GETARG_DATUM(2),  /* query */
577                                         PointerGetDatum( fcinfo->flinfo->fn_extra ), /* &nentries */
578                                         PG_GETARG_DATUM(1),  /* strategy */
579                                         PointerGetDatum( &pmatch )
580                 );
581         }
582
583         nentries = *(int32*) fcinfo->flinfo->fn_extra;
584
585         for (i = 0; res && i < nentries; i++)
586                 if (check[i] == false)
587                         res = false;
588
589         PG_RETURN_BOOL(res);
590 }
591
592 /*
593  * Mostly debug fuction
594  */
595 PG_FUNCTION_INFO_V1(permute);
596 Datum       permute(PG_FUNCTION_ARGS);
597 Datum
598 permute(PG_FUNCTION_ARGS)
599 {
600         Datum           src = PG_GETARG_DATUM(0);
601         int32           nentries = 0;
602         Datum           *entries;
603         ArrayType       *res;
604         int             i;
605
606         /*
607          * Get permuted values by gin_extract_permuted()
608          */
609         entries = (Datum*) DatumGetPointer(DirectFunctionCall2(
610                                         gin_extract_permuted, src, PointerGetDatum(&nentries)
611                         ));
612
613         /*
614          * We need to replace MARK_SIGN to MARK_SIGN_SHOW.
615          * See comments above near definition of MARK_SIGN and MARK_SIGN_SHOW.
616          */
617         if ( nentries == 1 && VARSIZE(entries[0]) == VARHDRSZ + 1)
618         {
619                 *(VARDATA(entries[0])) = MARK_SIGN_SHOW;                
620         }
621         else
622         {
623                 int32   offset = 0; /* offset of MARK_SIGN */
624                 char    *ptr;
625
626                 /*
627                  * We scan array from the end because it allows simple calculation
628                  * of MARK_SIGN position: on every iteration it's moved one 
629                  * character to the end.
630                  */
631                 for(i=nentries-1;i>=0;i--) 
632                 {
633                         ptr = VARDATA(entries[i]);
634
635                         offset += pg_mblen(ptr);
636                         Assert( *(ptr + offset) == MARK_SIGN );
637                         *(ptr + offset) = MARK_SIGN_SHOW;
638                 }
639         }
640
641         res = construct_array(
642                                         entries,
643                                         nentries,
644                                         TEXTOID,
645                                         -1,
646                                         false,
647                                         'i'
648                         );
649
650         PG_RETURN_POINTER(res);
651 }