ce7e2ad70b8b0a1a3086c846454341ec08adeec1
[smlar.git] / smlar.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/genam.h"
5 #include "access/heapam.h"
6 #include "access/htup_details.h"
7 #include "access/nbtree.h"
8 #include "catalog/indexing.h"
9 #include "catalog/pg_am.h"
10 #include "catalog/pg_amproc.h"
11 #include "catalog/pg_cast.h"
12 #include "catalog/pg_opclass.h"
13 #include "catalog/pg_type.h"
14 #include "executor/spi.h"
15 #include "utils/catcache.h"
16 #include "utils/fmgroids.h"
17 #include "utils/lsyscache.h"
18 #include "utils/memutils.h"
19 #if (PG_VERSION_NUM < 120000)
20 #include "utils/tqual.h"
21 #endif
22 #include "utils/syscache.h"
23 #include "utils/typcache.h"
24
PG_MODULE_MAGIC;

/*
 * Snapshot argument handed to systable_beginscan() by the catalog scans in
 * this file.  SnapshotNow is only used on servers older than 9.4; from 9.4
 * on NULL is passed, leaving the choice of catalog snapshot to
 * systable_beginscan() itself.
 */
#if PG_VERSION_NUM >= 90400
#define SNAPSHOT	NULL
#else
#define SNAPSHOT	SnapshotNow
#endif
32
33
34 static Oid
35 getDefaultOpclass(Oid amoid, Oid typid)
36 {
37         ScanKeyData     skey;
38         SysScanDesc     scan;
39         HeapTuple       tuple;
40         Relation        heapRel;
41         Oid                     opclassOid = InvalidOid;
42
43         heapRel = heap_open(OperatorClassRelationId, AccessShareLock);
44
45         ScanKeyInit(&skey,
46                                 Anum_pg_opclass_opcmethod,
47                                 BTEqualStrategyNumber,  F_OIDEQ,
48                                 ObjectIdGetDatum(amoid));
49
50         scan = systable_beginscan(heapRel,
51                                                                 OpclassAmNameNspIndexId, true,
52                                                                 SNAPSHOT, 1, &skey);
53
54         while (HeapTupleIsValid((tuple = systable_getnext(scan))))
55         {
56                 Form_pg_opclass opclass = (Form_pg_opclass)GETSTRUCT(tuple);
57
58                 if ( opclass->opcintype == typid && opclass->opcdefault )
59                 {
60                         if ( OidIsValid(opclassOid) )
61                                 elog(ERROR, "Ambiguous opclass for type %u (access method %u)", typid, amoid); 
62 #if (PG_VERSION_NUM >= 120000)
63                         opclassOid = opclass->oid;
64 #else
65                         opclassOid = HeapTupleGetOid(tuple);
66 #endif
67                 }
68         }
69
70         systable_endscan(scan);
71         heap_close(heapRel, AccessShareLock);
72
73         return opclassOid;
74 }
75
76 static Oid
77 getAMProc(Oid amoid, Oid typid)
78 {
79         Oid             opclassOid = getDefaultOpclass(amoid, typid);
80         Oid             procOid = InvalidOid;
81         Oid             opfamilyOid;
82         ScanKeyData     skey[4];
83         SysScanDesc     scan;
84         HeapTuple       tuple;
85         Relation        heapRel;
86
87         if ( !OidIsValid(opclassOid) )
88         {
89                 typid = getBaseType(typid);
90                 opclassOid = getDefaultOpclass(amoid, typid);
91         }
92
93         if ( !OidIsValid(opclassOid) )
94         {
95                 CatCList        *catlist;
96                 int                     i;
97
98                 /*
99                  * Search binary-coercible type
100                  */
101 #ifdef SearchSysCacheList1
102                 catlist = SearchSysCacheList1(CASTSOURCETARGET,
103                                                                           ObjectIdGetDatum(typid));
104 #else
105                 catlist = SearchSysCacheList(CASTSOURCETARGET, 1,
106                                                                                 ObjectIdGetDatum(typid),
107                                                                                 0, 0, 0);
108 #endif
109
110                 for (i = 0; i < catlist->n_members; i++)
111                 {
112                         HeapTuple               tuple = &catlist->members[i]->tuple;
113                         Form_pg_cast    castForm = (Form_pg_cast)GETSTRUCT(tuple);
114
115                         if ( castForm->castfunc == InvalidOid && castForm->castcontext == COERCION_CODE_IMPLICIT )
116                         {
117                                 typid = castForm->casttarget;
118                                 opclassOid = getDefaultOpclass(amoid, typid);
119                                 if( OidIsValid(opclassOid) )
120                                         break;
121                         }
122                 }
123
124                 ReleaseSysCacheList(catlist);
125         }
126
127         if ( !OidIsValid(opclassOid) )
128                 return InvalidOid;
129
130         opfamilyOid = get_opclass_family(opclassOid);
131
132         heapRel = heap_open(AccessMethodProcedureRelationId, AccessShareLock);
133         ScanKeyInit(&skey[0],
134                                 Anum_pg_amproc_amprocfamily,
135                                 BTEqualStrategyNumber, F_OIDEQ,
136                                 ObjectIdGetDatum(opfamilyOid));
137         ScanKeyInit(&skey[1],
138                                 Anum_pg_amproc_amproclefttype,
139                                 BTEqualStrategyNumber, F_OIDEQ,
140                                 ObjectIdGetDatum(typid));
141         ScanKeyInit(&skey[2],
142                                 Anum_pg_amproc_amprocrighttype,
143                                 BTEqualStrategyNumber, F_OIDEQ,
144                                 ObjectIdGetDatum(typid));
145 #if PG_VERSION_NUM >= 90200
146         ScanKeyInit(&skey[3],
147                                 Anum_pg_amproc_amprocnum,
148                                 BTEqualStrategyNumber, F_OIDEQ,
149                                 Int32GetDatum(BTORDER_PROC));
150 #endif
151
152         scan = systable_beginscan(heapRel, AccessMethodProcedureIndexId, true,
153                                                                 SNAPSHOT,
154 #if PG_VERSION_NUM >= 90200
155                                                                 4,
156 #else
157                                                                 3,
158 #endif
159                                                                 skey);
160         while (HeapTupleIsValid(tuple = systable_getnext(scan)))
161         {
162                 Form_pg_amproc amprocform = (Form_pg_amproc) GETSTRUCT(tuple);
163
164                 switch(amoid)
165                 {
166                         case BTREE_AM_OID:
167                         case HASH_AM_OID:
168                                 if ( OidIsValid(procOid) )
169                                         elog(ERROR,"Ambiguous support function for type %u (opclass %u)", typid, opfamilyOid);
170                                 procOid = amprocform->amproc;
171                                 break;
172                         default:
173                                 elog(ERROR,"Unsupported access method");
174                 }
175         }
176
177         systable_endscan(scan);
178         heap_close(heapRel, AccessShareLock);
179
180         return procOid;
181 }
182
183 static ProcTypeInfo *cacheProcs = NULL;
184 static int nCacheProcs = 0;
185
186 #ifndef TupleDescAttr
187 #define TupleDescAttr(tupdesc, i)       ((tupdesc)->attrs[(i)])
188 #endif
189
190 static ProcTypeInfo
191 fillProcs(Oid typid)
192 {
193         ProcTypeInfo    info = malloc(sizeof(ProcTypeInfoData));
194
195         if (!info)
196                 elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfoData));
197
198         info->typid = typid;
199         info->typtype = get_typtype(typid);
200
201         if (info->typtype == 'c')
202         {
203                 /* composite type */
204                 TupleDesc               tupdesc;
205                 MemoryContext   oldcontext;
206
207                 tupdesc = lookup_rowtype_tupdesc(typid, -1);
208
209                 if (tupdesc->natts != 2)
210                         elog(ERROR,"Composite type has wrong number of fields");
211                 if (TupleDescAttr(tupdesc, 1)->atttypid != FLOAT4OID)
212                         elog(ERROR,"Second field of composite type is not float4");
213
214                 oldcontext = MemoryContextSwitchTo(TopMemoryContext);
215                 info->tupDesc = CreateTupleDescCopyConstr(tupdesc);
216                 MemoryContextSwitchTo(oldcontext);
217
218                 ReleaseTupleDesc(tupdesc);
219
220                 info->cmpFuncOid = getAMProc(BTREE_AM_OID,
221                                                                          TupleDescAttr(info->tupDesc, 0)->atttypid);
222                 info->hashFuncOid = getAMProc(HASH_AM_OID,
223                                                                           TupleDescAttr(info->tupDesc, 0)->atttypid);
224         }
225         else
226         {
227                 info->tupDesc = NULL;
228
229                 /* plain type */
230                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, typid);
231                 info->hashFuncOid = getAMProc(HASH_AM_OID, typid);
232         }
233
234         get_typlenbyvalalign(typid, &info->typlen, &info->typbyval, &info->typalign);
235         info->hashFuncInited = info->cmpFuncInited = false;
236
237
238         return info;
239 }
240
241 void
242 getFmgrInfoCmp(ProcTypeInfo info)
243 {
244         if ( info->cmpFuncInited == false )
245         {
246                 if ( !OidIsValid(info->cmpFuncOid) )
247                         elog(ERROR, "Could not find cmp function for type %u", info->typid);
248
249                 fmgr_info_cxt( info->cmpFuncOid, &info->cmpFunc, TopMemoryContext );
250                 info->cmpFuncInited = true;
251         }
252 }
253
254 void
255 getFmgrInfoHash(ProcTypeInfo info)
256 {
257         if ( info->hashFuncInited == false )
258         {
259                 if ( !OidIsValid(info->hashFuncOid) )
260                         elog(ERROR, "Could not find hash function for type %u", info->typid);
261
262                 fmgr_info_cxt( info->hashFuncOid, &info->hashFunc, TopMemoryContext );
263                 info->hashFuncInited = true;
264         }
265 }
266
267 static int
268 cmpProcTypeInfo(const void *a, const void *b)
269 {
270         ProcTypeInfo av = *(ProcTypeInfo*)a;
271         ProcTypeInfo bv = *(ProcTypeInfo*)b;
272
273         Assert( av->typid != bv->typid );
274
275         return ( av->typid > bv->typid ) ? 1 : -1;
276 }
277
278 ProcTypeInfo
279 findProcs(Oid typid)
280 {
281         ProcTypeInfo    info = NULL;
282
283         if ( nCacheProcs == 1 )
284         {
285                 if ( cacheProcs[0]->typid == typid )
286                 {
287                         /*cacheProcs[0]->hashFuncInited = cacheProcs[0]->cmpFuncInited = false;*/
288                         return cacheProcs[0];
289                 }
290         }
291         else if ( nCacheProcs > 1 )
292         {
293                 ProcTypeInfo    *StopMiddle;
294                 ProcTypeInfo    *StopLow = cacheProcs,
295                                                 *StopHigh = cacheProcs + nCacheProcs;
296
297                 while (StopLow < StopHigh) {
298                         StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
299                         info = *StopMiddle;
300
301                         if ( info->typid == typid )
302                         {
303                                 /* info->hashFuncInited = info->cmpFuncInited = false; */
304                                 return info;
305                         }
306                         else if ( info->typid < typid )
307                                 StopLow = StopMiddle + 1;
308                         else
309                                 StopHigh = StopMiddle;
310                 }
311
312                 /* not found */
313         } 
314
315         info = fillProcs(typid);
316         if ( nCacheProcs == 0 )
317         {
318                 cacheProcs = malloc(sizeof(ProcTypeInfo));
319
320                 if (!cacheProcs)
321                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo));
322                 else
323                 {
324                         nCacheProcs = 1;
325                         cacheProcs[0] = info;
326                 }
327         }
328         else
329         {
330                 ProcTypeInfo    *cacheProcsTmp = realloc(cacheProcs, (nCacheProcs+1) * sizeof(ProcTypeInfo));
331
332                 if (!cacheProcsTmp)
333                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo) * (nCacheProcs+1));
334                 else
335                 {
336                         cacheProcs = cacheProcsTmp;
337                         cacheProcs[ nCacheProcs ] = info;
338                         nCacheProcs++;
339                         qsort(cacheProcs, nCacheProcs, sizeof(ProcTypeInfo), cmpProcTypeInfo);
340                 }
341         }
342
343         /* info->hashFuncInited = info->cmpFuncInited = false; */
344
345         return info;
346 }
347
348 /*
349  * WARNING. Array2SimpleArray* doesn't copy Datum!
350  */
351 SimpleArray * 
352 Array2SimpleArray(ProcTypeInfo info, ArrayType *a)
353 {
354         SimpleArray     *s = palloc(sizeof(SimpleArray));
355
356         CHECKARRVALID(a);
357
358         if ( info == NULL )
359                 info = findProcs(ARR_ELEMTYPE(a));
360
361         s->info = info;
362         s->df = NULL;
363         s->hash = NULL;
364
365         deconstruct_array(a, info->typid,
366                                                 info->typlen, info->typbyval, info->typalign,
367                                                 &s->elems, NULL, &s->nelems);
368
369         return s;
370 }
371
372 static Datum
373 deconstructCompositeType(ProcTypeInfo info, Datum in, double *weight)
374 {
375         HeapTupleHeader rec = DatumGetHeapTupleHeader(in);
376         HeapTupleData   tuple;
377         Datum                   values[2];
378         bool                    nulls[2];
379
380         /* Build a temporary HeapTuple control structure */
381         tuple.t_len = HeapTupleHeaderGetDatumLength(rec);
382         ItemPointerSetInvalid(&(tuple.t_self));
383         tuple.t_tableOid = InvalidOid;
384         tuple.t_data = rec;
385
386         heap_deform_tuple(&tuple, info->tupDesc, values, nulls);
387         if (nulls[0] || nulls[1])
388                 elog(ERROR, "Both fields in composite type could not be NULL");
389
390         if (weight)
391                 *weight = DatumGetFloat4(values[1]);
392         return values[0];
393 }
394
395 static int
396 cmpArrayElem(const void *a, const void *b, void *arg)
397 {
398         ProcTypeInfo    info = (ProcTypeInfo)arg;
399
400         if (info->tupDesc)
401                 /* composite type */
402                 return DatumGetInt32( FCall2( &info->cmpFunc,
403                                                 deconstructCompositeType(info, *(Datum*)a, NULL),
404                                                 deconstructCompositeType(info, *(Datum*)b, NULL) ) );
405
406         return DatumGetInt32( FCall2( &info->cmpFunc,
407                                                         *(Datum*)a, *(Datum*)b ) );
408 }
409
410 SimpleArray *
411 Array2SimpleArrayS(ProcTypeInfo info, ArrayType *a)
412 {
413         SimpleArray     *s = Array2SimpleArray(info, a);
414
415         if ( s->nelems > 1 )
416         {
417                 getFmgrInfoCmp(s->info);
418
419                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElem, s->info);
420         }
421
422         return s;
423 }
424
425 typedef struct cmpArrayElemData {
426         ProcTypeInfo    info;
427         bool                    hasDuplicate;
428
429 } cmpArrayElemData;
430
431 static int
432 cmpArrayElemArg(const void *a, const void *b, void *arg)
433 {
434         cmpArrayElemData        *data = (cmpArrayElemData*)arg;
435         int                                     res;
436
437         if (data->info->tupDesc)
438                 res =  DatumGetInt32( FCall2( &data->info->cmpFunc,
439                                         deconstructCompositeType(data->info, *(Datum*)a, NULL),
440                                         deconstructCompositeType(data->info, *(Datum*)b, NULL) ) );
441         else
442                 res = DatumGetInt32( FCall2( &data->info->cmpFunc,
443                                                                 *(Datum*)a, *(Datum*)b ) );
444
445         if ( res == 0 )
446                 data->hasDuplicate = true;
447
448         return res;
449 }
450
451 /*
452  * Uniquefy array and calculate TF. Although 
453  * result doesn't depend on normalization, we
454  * normalize TF by length array to have possiblity
455  * to limit estimation for index support.
456  *
457  * Cache signals of needing of TF caclulation
458  */
459
460 SimpleArray *
461 Array2SimpleArrayU(ProcTypeInfo info, ArrayType *a, void *cache)
462 {
463         SimpleArray     *s = Array2SimpleArray(info, a);
464         StatElem        *stat = NULL;
465
466         if ( s->nelems > 0 && cache )
467         {
468                 s->df = palloc(sizeof(double) * s->nelems);
469                 s->df[0] = 1.0; /* init */
470         }
471
472         if ( s->nelems > 1 )
473         {
474                 cmpArrayElemData        data;
475                 int                                     i;
476
477                 getFmgrInfoCmp(s->info);
478                 data.info = s->info;
479                 data.hasDuplicate = false;
480
481                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElemArg, &data);
482
483                 if ( data.hasDuplicate )
484                 {
485                         Datum   *tmp,
486                                         *dr,
487                                         *data;
488                         int             num = s->nelems,
489                                         cmp;
490
491                         data = tmp = dr = s->elems;
492
493                         while (tmp - data < num)
494                         {
495                                 cmp = (tmp == dr) ? 0 : cmpArrayElem(tmp, dr, s->info);
496                                 if ( cmp != 0 )
497                                 {
498                                         *(++dr) = *tmp++;
499                                         if ( cache ) 
500                                                 s->df[ dr - data ] = 1.0;
501                                 }
502                                 else
503                                 {
504                                         if ( cache )
505                                                 s->df[ dr - data ] += 1.0;
506                                         tmp++;
507                                 }
508                         }
509
510                         s->nelems = dr + 1 - s->elems;
511
512                         if ( cache )
513                         {
514                                 int tfm = getTFMethod();
515
516                                 for(i=0;i<s->nelems;i++)
517                                 {
518                                         stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
519                                         if ( stat )
520                                         {
521                                                 switch(tfm)
522                                                 {
523                                                         case TF_LOG:
524                                                                 s->df[i] = (1.0 + log( s->df[i] ));
525                                                         case TF_N:
526                                                                 s->df[i] *= stat->idf;
527                                                                 break;
528                                                         case TF_CONST:
529                                                                 s->df[i] = stat->idf;
530                                                                 break;
531                                                         default:
532                                                                 elog(ERROR,"Unknown TF method: %d", tfm);
533                                                 }
534                                         }
535                                         else
536                                         {
537                                                 s->df[i] = 0.0; /* unknown word */
538                                         }
539                                 }
540                         }
541                 }
542                 else if ( cache )
543                 {
544                         for(i=0;i<s->nelems;i++)
545                         {
546                                 stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
547                                 if ( stat )
548                                         s->df[i] = stat->idf;
549                                 else
550                                         s->df[i] = 0.0;
551                         }
552                 }
553         }
554         else if (s->nelems > 0 && cache)
555         {
556                 stat = fingArrayStat(cache, s->info->typid, s->elems[0], stat);
557                 if ( stat )
558                         s->df[0] = stat->idf;
559                 else
560                         s->df[0] = 0.0;
561         }
562
563         return s;
564 }
565
566 static int
567 numOfIntersect(SimpleArray *a, SimpleArray *b)
568 {
569         int                             cnt = 0,
570                                         cmp;
571         Datum                   *aptr = a->elems,
572                                         *bptr = b->elems;
573         ProcTypeInfo    info = a->info;
574
575         Assert( a->info->typid == b->info->typid );
576
577         getFmgrInfoCmp(info);
578
579         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
580         {
581                 cmp = cmpArrayElem(aptr, bptr, info);
582                 if ( cmp < 0 )
583                         aptr++;
584                 else if ( cmp > 0 )
585                         bptr++;
586                 else
587                 {
588                         cnt++;
589                         aptr++;
590                         bptr++;
591                 }
592         }
593
594         return cnt;
595 }
596
597 static double
598 TFIDFSml(SimpleArray *a, SimpleArray *b)
599 {
600         int                             cmp;
601         Datum                   *aptr = a->elems,
602                                         *bptr = b->elems;
603         ProcTypeInfo    info = a->info;
604         double                  res = 0.0;
605         double                  suma = 0.0, sumb = 0.0;
606
607         Assert( a->info->typid == b->info->typid );
608         Assert( a->df );
609         Assert( b->df );
610
611         getFmgrInfoCmp(info);
612
613         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
614         {
615                 cmp = cmpArrayElem(aptr, bptr, info);
616                 if ( cmp < 0 )
617                 {
618                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
619                         aptr++;
620                 }
621                 else if ( cmp > 0 )
622                 {
623                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
624                         bptr++;
625                 }
626                 else
627                 {
628                         res += a->df[ aptr - a->elems ] * b->df[ bptr - b->elems ];
629                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
630                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
631                         aptr++;
632                         bptr++;
633                 }
634         }
635
636         /*
637          * Compute last elements
638          */
639         while( aptr - a->elems < a->nelems )
640         {
641                 suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
642                 aptr++;
643         }
644
645         while( bptr - b->elems < b->nelems )
646         {
647                 sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
648                 bptr++;
649         }
650
651         if ( suma > 0.0 && sumb > 0.0 )
652                 res = res / sqrt( suma * sumb );
653         else
654                 res = 0.0;
655
656         return res;
657 }
658
659
660 PG_FUNCTION_INFO_V1(arraysml);
661 Datum   arraysml(PG_FUNCTION_ARGS);
662 Datum
663 arraysml(PG_FUNCTION_ARGS)
664 {
665         ArrayType               *a, *b;
666         SimpleArray             *sa, *sb;
667
668         fcinfo->flinfo->fn_extra = SearchArrayCache(
669                                                         fcinfo->flinfo->fn_extra,
670                                                         fcinfo->flinfo->fn_mcxt,
671                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
672         fcinfo->flinfo->fn_extra = SearchArrayCache(
673                                                         fcinfo->flinfo->fn_extra,
674                                                         fcinfo->flinfo->fn_mcxt,
675                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
676
677         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
678                 elog(ERROR,"Arguments array are not the same type!");
679
680         if (ARRISVOID(a) || ARRISVOID(b))
681                  PG_RETURN_FLOAT4(0.0);
682
683         switch(getSmlType())
684         {
685                 case ST_TFIDF:
686                         PG_RETURN_FLOAT4( TFIDFSml(sa, sb) );
687                         break;
688                 case ST_COSINE:
689                         {
690                                 int                             cnt;
691                                 double                  power;
692
693                                 power = ((double)(sa->nelems)) * ((double)(sb->nelems));
694                                 cnt = numOfIntersect(sa, sb);
695
696                                 PG_RETURN_FLOAT4(  ((double)cnt) / sqrt( power ) );
697                         }
698                         break;
699                 case ST_OVERLAP:
700                         {
701                                 float4 res = (float4)numOfIntersect(sa, sb);
702
703                                 PG_RETURN_FLOAT4(res);
704                         }
705                         break;
706                 default:
707                         elog(ERROR,"Unsupported formula type of similarity");
708         }
709
710         PG_RETURN_FLOAT4(0.0); /* keep compiler quiet */
711 }
712
713 PG_FUNCTION_INFO_V1(arraysmlw);
714 Datum   arraysmlw(PG_FUNCTION_ARGS);
715 Datum
716 arraysmlw(PG_FUNCTION_ARGS)
717 {
718         ArrayType               *a, *b;
719         SimpleArray             *sa, *sb;
720         bool                    useIntersect = PG_GETARG_BOOL(2);
721         double                  numerator = 0.0;
722         double                  denominatorA = 0.0,
723                                         denominatorB = 0.0,
724                                         tmpA, tmpB;
725         int                             cmp;
726         ProcTypeInfo    info;
727         int                             ai = 0, bi = 0;
728
729         fcinfo->flinfo->fn_extra = SearchArrayCache(
730                                                         fcinfo->flinfo->fn_extra,
731                                                         fcinfo->flinfo->fn_mcxt,
732                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
733         fcinfo->flinfo->fn_extra = SearchArrayCache(
734                                                         fcinfo->flinfo->fn_extra,
735                                                         fcinfo->flinfo->fn_mcxt,
736                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
737
738         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
739                 elog(ERROR,"Arguments array are not the same type!");
740
741         if (ARRISVOID(a) || ARRISVOID(b))
742                  PG_RETURN_FLOAT4(0.0);
743
744         info = sa->info;
745         if (info->tupDesc == NULL)
746                 elog(ERROR, "Only weigthed (composite) types should be used");
747         getFmgrInfoCmp(info);
748
749         while(ai < sa->nelems && bi < sb->nelems)
750         {
751                 Datum   ad = deconstructCompositeType(info, sa->elems[ai], &tmpA),
752                                 bd = deconstructCompositeType(info, sb->elems[bi], &tmpB);
753
754                 cmp = DatumGetInt32(FCall2(&info->cmpFunc, ad, bd));
755
756                 if ( cmp < 0 ) {
757                         if (useIntersect == false)
758                                 denominatorA += tmpA * tmpA;
759                         ai++;
760                 } else if ( cmp > 0 ) {
761                         if (useIntersect == false)
762                                 denominatorB += tmpB * tmpB;
763                         bi++;
764                 } else {
765                         denominatorA += tmpA * tmpA;
766                         denominatorB += tmpB * tmpB;
767                         numerator += tmpA * tmpB;
768                         ai++;
769                         bi++;
770                 }
771         }
772
773         if (useIntersect == false) {
774                 while(ai < sa->nelems) {
775                         deconstructCompositeType(info, sa->elems[ai], &tmpA);
776                         denominatorA += tmpA * tmpA;
777                         ai++;
778                 }
779
780                 while(bi < sb->nelems) {
781                         deconstructCompositeType(info, sb->elems[bi], &tmpB);
782                         denominatorB += tmpB * tmpB;
783                         bi++;
784                 }
785         }
786
787         if (numerator != 0.0) {
788                 numerator = numerator / sqrt( denominatorA * denominatorB );
789         }
790
791         PG_RETURN_FLOAT4(numerator);
792 }
793
794 PG_FUNCTION_INFO_V1(arraysml_op);
795 Datum   arraysml_op(PG_FUNCTION_ARGS);
796 Datum
797 arraysml_op(PG_FUNCTION_ARGS)
798 {
799         ArrayType               *a, *b;
800         SimpleArray             *sa, *sb;
801         double                  power = 0.0;
802
803         fcinfo->flinfo->fn_extra = SearchArrayCache(
804                                                         fcinfo->flinfo->fn_extra,
805                                                         fcinfo->flinfo->fn_mcxt,
806                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
807         fcinfo->flinfo->fn_extra = SearchArrayCache(
808                                                         fcinfo->flinfo->fn_extra,
809                                                         fcinfo->flinfo->fn_mcxt,
810                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
811
812         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
813                 elog(ERROR,"Arguments array are not the same type!");
814
815         if (ARRISVOID(a) || ARRISVOID(b))
816                  PG_RETURN_BOOL(false);
817
818         switch(getSmlType())
819         {
820                 case ST_TFIDF:
821                         power = TFIDFSml(sa, sb);
822                         break;
823                 case ST_COSINE:
824                         {
825                                 int                             cnt;
826
827                                 power = sqrt( ((double)(sa->nelems)) * ((double)(sb->nelems)) );
828
829                                 if (  ((double)Min(sa->nelems, sb->nelems)) / power < GetSmlarLimit()  )
830                                         PG_RETURN_BOOL(false);
831
832                                 cnt = numOfIntersect(sa, sb);
833                                 power = ((double)cnt) / power;
834                         }
835                         break;
836                 case ST_OVERLAP:
837                         power = (double)numOfIntersect(sa, sb);
838                         break;
839                 default:
840                         elog(ERROR,"Unsupported formula type of similarity");
841         }
842
843         PG_RETURN_BOOL(power >= GetSmlarLimit());
844 }
845
846 #define QBSIZE          8192
847 static char cachedFormula[QBSIZE];
848 static int      cachedLen  = 0;
849 static void     *cachedPlan = NULL;
850
851 PG_FUNCTION_INFO_V1(arraysml_func);
852 Datum   arraysml_func(PG_FUNCTION_ARGS);
853 Datum
854 arraysml_func(PG_FUNCTION_ARGS)
855 {
856         ArrayType               *a, *b;
857         SimpleArray             *sa, *sb;
858         int                             cnt;
859         float4                  result = -1.0;
860         Oid                             arg[] = {INT4OID, INT4OID, INT4OID};
861         Datum                   pars[3];
862         bool                    isnull;
863         void                    *plan;
864         int                             stat;
865         text                    *formula = PG_GETARG_TEXT_P(2);
866
867         fcinfo->flinfo->fn_extra = SearchArrayCache(
868                                                         fcinfo->flinfo->fn_extra,
869                                                         fcinfo->flinfo->fn_mcxt,
870                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
871         fcinfo->flinfo->fn_extra = SearchArrayCache(
872                                                         fcinfo->flinfo->fn_extra,
873                                                         fcinfo->flinfo->fn_mcxt,
874                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
875
876         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
877                 elog(ERROR,"Arguments array are not the same type!");
878
879         if (ARRISVOID(a) || ARRISVOID(b))
880                  PG_RETURN_BOOL(false);
881
882         cnt = numOfIntersect(sa, sb);
883
884         if ( VARSIZE(formula) - VARHDRSZ > QBSIZE - 1024 )
885                 elog(ERROR,"Formula is too long");
886
887         SPI_connect();
888
889         if ( cachedPlan == NULL || cachedLen != VARSIZE(formula) - VARHDRSZ ||
890                                 memcmp( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ) != 0 )
891         {
892                 char                    *ptr, buf[QBSIZE];
893
894                 *cachedFormula = '\0';
895                 if ( cachedPlan )
896                         SPI_freeplan(cachedPlan);
897                 cachedPlan = NULL;
898                 cachedLen = 0;
899
900                 ptr = stpcpy( buf, "SELECT (" );
901                 memcpy( ptr, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
902                 ptr += VARSIZE(formula) - VARHDRSZ;
903                 ptr = stpcpy( ptr, ")::float4 FROM");
904                 ptr = stpcpy( ptr, " (SELECT $1 ::float8 AS i, $2 ::float8 AS a, $3 ::float8 AS b) AS N;");
905                 *ptr = '\0';
906
907                 plan = SPI_prepare(buf, 3, arg);
908                 if (!plan)
909                         elog(ERROR, "SPI_prepare() failed");
910
911                 cachedPlan = SPI_saveplan(plan);
912                 if (!cachedPlan)
913                         elog(ERROR, "SPI_saveplan() failed");
914
915                 SPI_freeplan(plan);
916                 cachedLen = VARSIZE(formula) - VARHDRSZ;
917                 memcpy( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
918         }
919
920         plan = cachedPlan;
921
922
923         pars[0] = Int32GetDatum( cnt );
924         pars[1] = Int32GetDatum( sa->nelems );
925         pars[2] = Int32GetDatum( sb->nelems );
926
927         stat = SPI_execute_plan(plan, pars, NULL, true, 3);
928         if (stat < 0)
929                 elog(ERROR, "SPI_execute_plan() returns %d", stat);
930
931         if ( SPI_processed > 0)
932                 result = DatumGetFloat4(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
933
934         SPI_finish();
935
936         PG_RETURN_FLOAT4(result);
937 }
938
939 PG_FUNCTION_INFO_V1(array_unique);
940 Datum   array_unique(PG_FUNCTION_ARGS);
941 Datum
942 array_unique(PG_FUNCTION_ARGS)
943 {
944         ArrayType               *a = PG_GETARG_ARRAYTYPE_P(0);
945         ArrayType               *res;
946         SimpleArray             *sa;
947
948         sa = Array2SimpleArrayU(NULL, a, NULL);
949
950         res = construct_array(  sa->elems, 
951                                                         sa->nelems,
952                                                         sa->info->typid,
953                                                         sa->info->typlen,
954                                                         sa->info->typbyval,
955                                                         sa->info->typalign);
956
957         pfree(sa->elems);
958         pfree(sa);
959         PG_FREE_IF_COPY(a, 0);
960
961         PG_RETURN_ARRAYTYPE_P(res);
962 }
963
964 PG_FUNCTION_INFO_V1(inarray);
965 Datum   inarray(PG_FUNCTION_ARGS);
966 Datum
967 inarray(PG_FUNCTION_ARGS)
968 {
969         ArrayType               *a;
970         SimpleArray             *sa;
971         Datum                   query = PG_GETARG_DATUM(1);
972         Oid                             queryTypeOid;
973         Datum                   *StopLow,
974                                         *StopHigh,
975                                         *StopMiddle;
976         int                             cmp;
977
978         fcinfo->flinfo->fn_extra = SearchArrayCache(
979                                                         fcinfo->flinfo->fn_extra,
980                                                         fcinfo->flinfo->fn_mcxt,
981                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
982
983         queryTypeOid = get_fn_expr_argtype(fcinfo->flinfo, 1);
984
985         if ( queryTypeOid == InvalidOid )
986                 elog(ERROR,"inarray: could not determine actual argument type");
987
988         if ( queryTypeOid != sa->info->typid )
989                 elog(ERROR,"inarray: Type of array's element and type of argument are not the same");
990
991         getFmgrInfoCmp(sa->info);
992         StopLow = sa->elems;
993         StopHigh = sa->elems + sa->nelems;
994
995         while (StopLow < StopHigh)
996         {
997                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
998                 cmp = cmpArrayElem(StopMiddle, &query, sa->info);
999
1000                 if ( cmp == 0 )
1001                 {
1002                         /* found */
1003                         if ( PG_NARGS() >= 3 )
1004                                 PG_RETURN_DATUM(PG_GETARG_DATUM(2));
1005                         PG_RETURN_FLOAT4(1.0);
1006                 }
1007                 else if (cmp < 0)
1008                         StopLow = StopMiddle + 1;
1009                 else
1010                         StopHigh = StopMiddle;
1011         }
1012
1013         if ( PG_NARGS() >= 4 )
1014                 PG_RETURN_DATUM(PG_GETARG_DATUM(3));
1015         PG_RETURN_FLOAT4(0.0);
1016 }