/*
 * smlar.c  [smlar.git]
 * blob 13c8169ba49a2b6a7e6432146baea0e639095100
 */
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/genam.h"
5 #include "access/heapam.h"
6 #include "access/htup_details.h"
7 #include "access/nbtree.h"
8 #include "catalog/indexing.h"
9 #include "catalog/pg_am.h"
10 #include "catalog/pg_amproc.h"
11 #include "catalog/pg_cast.h"
12 #include "catalog/pg_opclass.h"
13 #include "catalog/pg_type.h"
14 #include "executor/spi.h"
15 #include "utils/catcache.h"
16 #include "utils/fmgroids.h"
17 #include "utils/lsyscache.h"
18 #include "utils/memutils.h"
19 #include "utils/tqual.h"
20 #include "utils/syscache.h"
21 #include "utils/typcache.h"
22
23 PG_MODULE_MAGIC;
24
/*
 * Snapshot argument handed to systable_beginscan() in this file:
 * NULL when building against PostgreSQL 9.4 or later, SnapshotNow on
 * older releases.
 */
#if (PG_VERSION_NUM >= 90400)
#define SNAPSHOT NULL
#else
#define SNAPSHOT SnapshotNow
#endif
30
31 static Oid
32 getDefaultOpclass(Oid amoid, Oid typid)
33 {
34         ScanKeyData     skey;
35         SysScanDesc     scan;
36         HeapTuple       tuple;
37         Relation        heapRel;
38         Oid                     opclassOid = InvalidOid;
39
40         heapRel = heap_open(OperatorClassRelationId, AccessShareLock);
41
42         ScanKeyInit(&skey,
43                                 Anum_pg_opclass_opcmethod,
44                                 BTEqualStrategyNumber,  F_OIDEQ,
45                                 ObjectIdGetDatum(amoid));
46
47         scan = systable_beginscan(heapRel,
48                                                                 OpclassAmNameNspIndexId, true,
49                                                                 SNAPSHOT, 1, &skey);
50
51         while (HeapTupleIsValid((tuple = systable_getnext(scan))))
52         {
53                 Form_pg_opclass opclass = (Form_pg_opclass)GETSTRUCT(tuple);
54
55                 if ( opclass->opcintype == typid && opclass->opcdefault )
56                 {
57                         if ( OidIsValid(opclassOid) )
58                                 elog(ERROR, "Ambiguous opclass for type %u (access method %u)", typid, amoid); 
59                         opclassOid = HeapTupleGetOid(tuple);
60                 }
61         }
62
63         systable_endscan(scan);
64         heap_close(heapRel, AccessShareLock);
65
66         return opclassOid;
67 }
68
69 static Oid
70 getAMProc(Oid amoid, Oid typid)
71 {
72         Oid             opclassOid = getDefaultOpclass(amoid, typid);
73         Oid             procOid = InvalidOid;
74         Oid             opfamilyOid;
75         ScanKeyData     skey[4];
76         SysScanDesc     scan;
77         HeapTuple       tuple;
78         Relation        heapRel;
79
80         if ( !OidIsValid(opclassOid) )
81         {
82                 typid = getBaseType(typid);
83                 opclassOid = getDefaultOpclass(amoid, typid);
84         }
85
86         if ( !OidIsValid(opclassOid) )
87         {
88                 CatCList        *catlist;
89                 int                     i;
90
91                 /*
92                  * Search binary-coercible type
93                  */
94 #ifdef SearchSysCacheList1
95                 catlist = SearchSysCacheList1(CASTSOURCETARGET,
96                                                                           ObjectIdGetDatum(typid));
97 #else
98                 catlist = SearchSysCacheList(CASTSOURCETARGET, 1,
99                                                                                 ObjectIdGetDatum(typid),
100                                                                                 0, 0, 0);
101 #endif
102
103                 for (i = 0; i < catlist->n_members; i++)
104                 {
105                         HeapTuple               tuple = &catlist->members[i]->tuple;
106                         Form_pg_cast    castForm = (Form_pg_cast)GETSTRUCT(tuple);
107
108                         if ( castForm->castfunc == InvalidOid && castForm->castcontext == COERCION_CODE_IMPLICIT )
109                         {
110                                 typid = castForm->casttarget;
111                                 opclassOid = getDefaultOpclass(amoid, typid);
112                                 if( OidIsValid(opclassOid) )
113                                         break;
114                         }
115                 }
116
117                 ReleaseSysCacheList(catlist);
118         }
119
120         if ( !OidIsValid(opclassOid) )
121                 return InvalidOid;
122
123         opfamilyOid = get_opclass_family(opclassOid);
124
125         heapRel = heap_open(AccessMethodProcedureRelationId, AccessShareLock);
126         ScanKeyInit(&skey[0],
127                                 Anum_pg_amproc_amprocfamily,
128                                 BTEqualStrategyNumber, F_OIDEQ,
129                                 ObjectIdGetDatum(opfamilyOid));
130         ScanKeyInit(&skey[1],
131                                 Anum_pg_amproc_amproclefttype,
132                                 BTEqualStrategyNumber, F_OIDEQ,
133                                 ObjectIdGetDatum(typid));
134         ScanKeyInit(&skey[2],
135                                 Anum_pg_amproc_amprocrighttype,
136                                 BTEqualStrategyNumber, F_OIDEQ,
137                                 ObjectIdGetDatum(typid));
138 #if PG_VERSION_NUM >= 90200
139         ScanKeyInit(&skey[3],
140                                 Anum_pg_amproc_amprocnum,
141                                 BTEqualStrategyNumber, F_OIDEQ,
142                                 Int32GetDatum(BTORDER_PROC));
143 #endif
144
145         scan = systable_beginscan(heapRel, AccessMethodProcedureIndexId, true,
146                                                                 SNAPSHOT,
147 #if PG_VERSION_NUM >= 90200
148                                                                 4,
149 #else
150                                                                 3,
151 #endif
152                                                                 skey);
153         while (HeapTupleIsValid(tuple = systable_getnext(scan)))
154         {
155                 Form_pg_amproc amprocform = (Form_pg_amproc) GETSTRUCT(tuple);
156
157                 switch(amoid)
158                 {
159                         case BTREE_AM_OID:
160                         case HASH_AM_OID:
161                                 if ( OidIsValid(procOid) )
162                                         elog(ERROR,"Ambiguous support function for type %u (opclass %u)", typid, opfamilyOid);
163                                 procOid = amprocform->amproc;
164                                 break;
165                         default:
166                                 elog(ERROR,"Unsupported access method");
167                 }
168         }
169
170         systable_endscan(scan);
171         heap_close(heapRel, AccessShareLock);
172
173         return procOid;
174 }
175
176 static ProcTypeInfo *cacheProcs = NULL;
177 static int nCacheProcs = 0;
178
179 #ifndef TupleDescAttr
180 #define TupleDescAttr(tupdesc, i)       ((tupdesc)->attrs[(i)])
181 #endif
182
183 static ProcTypeInfo
184 fillProcs(Oid typid)
185 {
186         ProcTypeInfo    info = malloc(sizeof(ProcTypeInfoData));
187
188         if (!info)
189                 elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfoData));
190
191         info->typid = typid;
192         info->typtype = get_typtype(typid);
193
194         if (info->typtype == 'c')
195         {
196                 /* composite type */
197                 TupleDesc               tupdesc;
198                 MemoryContext   oldcontext;
199
200                 tupdesc = lookup_rowtype_tupdesc(typid, -1);
201
202                 if (tupdesc->natts != 2)
203                         elog(ERROR,"Composite type has wrong number of fields");
204                 if (TupleDescAttr(tupdesc, 1)->atttypid != FLOAT4OID)
205                         elog(ERROR,"Second field of composite type is not float4");
206
207                 oldcontext = MemoryContextSwitchTo(TopMemoryContext);
208                 info->tupDesc = CreateTupleDescCopyConstr(tupdesc);
209                 MemoryContextSwitchTo(oldcontext);
210
211                 ReleaseTupleDesc(tupdesc);
212
213                 info->cmpFuncOid = getAMProc(BTREE_AM_OID,
214                                                                          TupleDescAttr(info->tupDesc, 0)->atttypid);
215                 info->hashFuncOid = getAMProc(HASH_AM_OID,
216                                                                           TupleDescAttr(info->tupDesc, 0)->atttypid);
217         }
218         else
219         {
220                 info->tupDesc = NULL;
221
222                 /* plain type */
223                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, typid);
224                 info->hashFuncOid = getAMProc(HASH_AM_OID, typid);
225         }
226
227         get_typlenbyvalalign(typid, &info->typlen, &info->typbyval, &info->typalign);
228         info->hashFuncInited = info->cmpFuncInited = false;
229
230
231         return info;
232 }
233
234 void
235 getFmgrInfoCmp(ProcTypeInfo info)
236 {
237         if ( info->cmpFuncInited == false )
238         {
239                 if ( !OidIsValid(info->cmpFuncOid) )
240                         elog(ERROR, "Could not find cmp function for type %u", info->typid);
241
242                 fmgr_info_cxt( info->cmpFuncOid, &info->cmpFunc, TopMemoryContext );
243                 info->cmpFuncInited = true;
244         }
245 }
246
247 void
248 getFmgrInfoHash(ProcTypeInfo info)
249 {
250         if ( info->hashFuncInited == false )
251         {
252                 if ( !OidIsValid(info->hashFuncOid) )
253                         elog(ERROR, "Could not find hash function for type %u", info->typid);
254
255                 fmgr_info_cxt( info->hashFuncOid, &info->hashFunc, TopMemoryContext );
256                 info->hashFuncInited = true;
257         }
258 }
259
260 static int
261 cmpProcTypeInfo(const void *a, const void *b)
262 {
263         ProcTypeInfo av = *(ProcTypeInfo*)a;
264         ProcTypeInfo bv = *(ProcTypeInfo*)b;
265
266         Assert( av->typid != bv->typid );
267
268         return ( av->typid > bv->typid ) ? 1 : -1;
269 }
270
271 ProcTypeInfo
272 findProcs(Oid typid)
273 {
274         ProcTypeInfo    info = NULL;
275
276         if ( nCacheProcs == 1 )
277         {
278                 if ( cacheProcs[0]->typid == typid )
279                 {
280                         /*cacheProcs[0]->hashFuncInited = cacheProcs[0]->cmpFuncInited = false;*/
281                         return cacheProcs[0];
282                 }
283         }
284         else if ( nCacheProcs > 1 )
285         {
286                 ProcTypeInfo    *StopMiddle;
287                 ProcTypeInfo    *StopLow = cacheProcs,
288                                                 *StopHigh = cacheProcs + nCacheProcs;
289
290                 while (StopLow < StopHigh) {
291                         StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
292                         info = *StopMiddle;
293
294                         if ( info->typid == typid )
295                         {
296                                 /* info->hashFuncInited = info->cmpFuncInited = false; */
297                                 return info;
298                         }
299                         else if ( info->typid < typid )
300                                 StopLow = StopMiddle + 1;
301                         else
302                                 StopHigh = StopMiddle;
303                 }
304
305                 /* not found */
306         } 
307
308         info = fillProcs(typid);
309         if ( nCacheProcs == 0 )
310         {
311                 cacheProcs = malloc(sizeof(ProcTypeInfo));
312
313                 if (!cacheProcs)
314                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo));
315                 else
316                 {
317                         nCacheProcs = 1;
318                         cacheProcs[0] = info;
319                 }
320         }
321         else
322         {
323                 ProcTypeInfo    *cacheProcsTmp = realloc(cacheProcs, (nCacheProcs+1) * sizeof(ProcTypeInfo));
324
325                 if (!cacheProcsTmp)
326                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo) * (nCacheProcs+1));
327                 else
328                 {
329                         cacheProcs = cacheProcsTmp;
330                         cacheProcs[ nCacheProcs ] = info;
331                         nCacheProcs++;
332                         qsort(cacheProcs, nCacheProcs, sizeof(ProcTypeInfo), cmpProcTypeInfo);
333                 }
334         }
335
336         /* info->hashFuncInited = info->cmpFuncInited = false; */
337
338         return info;
339 }
340
341 /*
342  * WARNING. Array2SimpleArray* doesn't copy Datum!
343  */
344 SimpleArray * 
345 Array2SimpleArray(ProcTypeInfo info, ArrayType *a)
346 {
347         SimpleArray     *s = palloc(sizeof(SimpleArray));
348
349         CHECKARRVALID(a);
350
351         if ( info == NULL )
352                 info = findProcs(ARR_ELEMTYPE(a));
353
354         s->info = info;
355         s->df = NULL;
356         s->hash = NULL;
357
358         deconstruct_array(a, info->typid,
359                                                 info->typlen, info->typbyval, info->typalign,
360                                                 &s->elems, NULL, &s->nelems);
361
362         return s;
363 }
364
365 static Datum
366 deconstructCompositeType(ProcTypeInfo info, Datum in, double *weight)
367 {
368         HeapTupleHeader rec = DatumGetHeapTupleHeader(in);
369         HeapTupleData   tuple;
370         Datum                   values[2];
371         bool                    nulls[2];
372
373         /* Build a temporary HeapTuple control structure */
374         tuple.t_len = HeapTupleHeaderGetDatumLength(rec);
375         ItemPointerSetInvalid(&(tuple.t_self));
376         tuple.t_tableOid = InvalidOid;
377         tuple.t_data = rec;
378
379         heap_deform_tuple(&tuple, info->tupDesc, values, nulls);
380         if (nulls[0] || nulls[1])
381                 elog(ERROR, "Both fields in composite type could not be NULL");
382
383         if (weight)
384                 *weight = DatumGetFloat4(values[1]);
385         return values[0];
386 }
387
388 static int
389 cmpArrayElem(const void *a, const void *b, void *arg)
390 {
391         ProcTypeInfo    info = (ProcTypeInfo)arg;
392
393         if (info->tupDesc)
394                 /* composite type */
395                 return DatumGetInt32( FCall2( &info->cmpFunc,
396                                                 deconstructCompositeType(info, *(Datum*)a, NULL),
397                                                 deconstructCompositeType(info, *(Datum*)b, NULL) ) );
398
399         return DatumGetInt32( FCall2( &info->cmpFunc,
400                                                         *(Datum*)a, *(Datum*)b ) );
401 }
402
403 SimpleArray *
404 Array2SimpleArrayS(ProcTypeInfo info, ArrayType *a)
405 {
406         SimpleArray     *s = Array2SimpleArray(info, a);
407
408         if ( s->nelems > 1 )
409         {
410                 getFmgrInfoCmp(s->info);
411
412                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElem, s->info);
413         }
414
415         return s;
416 }
417
418 typedef struct cmpArrayElemData {
419         ProcTypeInfo    info;
420         bool                    hasDuplicate;
421
422 } cmpArrayElemData;
423
424 static int
425 cmpArrayElemArg(const void *a, const void *b, void *arg)
426 {
427         cmpArrayElemData        *data = (cmpArrayElemData*)arg;
428         int                                     res;
429
430         if (data->info->tupDesc)
431                 res =  DatumGetInt32( FCall2( &data->info->cmpFunc,
432                                         deconstructCompositeType(data->info, *(Datum*)a, NULL),
433                                         deconstructCompositeType(data->info, *(Datum*)b, NULL) ) );
434         else
435                 res = DatumGetInt32( FCall2( &data->info->cmpFunc,
436                                                                 *(Datum*)a, *(Datum*)b ) );
437
438         if ( res == 0 )
439                 data->hasDuplicate = true;
440
441         return res;
442 }
443
444 /*
445  * Uniquefy array and calculate TF. Although 
446  * result doesn't depend on normalization, we
447  * normalize TF by length array to have possiblity
448  * to limit estimation for index support.
449  *
450  * Cache signals of needing of TF caclulation
451  */
452
453 SimpleArray *
454 Array2SimpleArrayU(ProcTypeInfo info, ArrayType *a, void *cache)
455 {
456         SimpleArray     *s = Array2SimpleArray(info, a);
457         StatElem        *stat = NULL;
458
459         if ( s->nelems > 0 && cache )
460         {
461                 s->df = palloc(sizeof(double) * s->nelems);
462                 s->df[0] = 1.0; /* init */
463         }
464
465         if ( s->nelems > 1 )
466         {
467                 cmpArrayElemData        data;
468                 int                                     i;
469
470                 getFmgrInfoCmp(s->info);
471                 data.info = s->info;
472                 data.hasDuplicate = false;
473
474                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElemArg, &data);
475
476                 if ( data.hasDuplicate )
477                 {
478                         Datum   *tmp,
479                                         *dr,
480                                         *data;
481                         int             num = s->nelems,
482                                         cmp;
483
484                         data = tmp = dr = s->elems;
485
486                         while (tmp - data < num)
487                         {
488                                 cmp = (tmp == dr) ? 0 : cmpArrayElem(tmp, dr, s->info);
489                                 if ( cmp != 0 )
490                                 {
491                                         *(++dr) = *tmp++;
492                                         if ( cache ) 
493                                                 s->df[ dr - data ] = 1.0;
494                                 }
495                                 else
496                                 {
497                                         if ( cache )
498                                                 s->df[ dr - data ] += 1.0;
499                                         tmp++;
500                                 }
501                         }
502
503                         s->nelems = dr + 1 - s->elems;
504
505                         if ( cache )
506                         {
507                                 int tfm = getTFMethod();
508
509                                 for(i=0;i<s->nelems;i++)
510                                 {
511                                         stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
512                                         if ( stat )
513                                         {
514                                                 switch(tfm)
515                                                 {
516                                                         case TF_LOG:
517                                                                 s->df[i] = (1.0 + log( s->df[i] ));
518                                                         case TF_N:
519                                                                 s->df[i] *= stat->idf;
520                                                                 break;
521                                                         case TF_CONST:
522                                                                 s->df[i] = stat->idf;
523                                                                 break;
524                                                         default:
525                                                                 elog(ERROR,"Unknown TF method: %d", tfm);
526                                                 }
527                                         }
528                                         else
529                                         {
530                                                 s->df[i] = 0.0; /* unknown word */
531                                         }
532                                 }
533                         }
534                 }
535                 else if ( cache )
536                 {
537                         for(i=0;i<s->nelems;i++)
538                         {
539                                 stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
540                                 if ( stat )
541                                         s->df[i] = stat->idf;
542                                 else
543                                         s->df[i] = 0.0;
544                         }
545                 }
546         }
547         else if (s->nelems > 0 && cache)
548         {
549                 stat = fingArrayStat(cache, s->info->typid, s->elems[0], stat);
550                 if ( stat )
551                         s->df[0] = stat->idf;
552                 else
553                         s->df[0] = 0.0;
554         }
555
556         return s;
557 }
558
559 static int
560 numOfIntersect(SimpleArray *a, SimpleArray *b)
561 {
562         int                             cnt = 0,
563                                         cmp;
564         Datum                   *aptr = a->elems,
565                                         *bptr = b->elems;
566         ProcTypeInfo    info = a->info;
567
568         Assert( a->info->typid == b->info->typid );
569
570         getFmgrInfoCmp(info);
571
572         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
573         {
574                 cmp = cmpArrayElem(aptr, bptr, info);
575                 if ( cmp < 0 )
576                         aptr++;
577                 else if ( cmp > 0 )
578                         bptr++;
579                 else
580                 {
581                         cnt++;
582                         aptr++;
583                         bptr++;
584                 }
585         }
586
587         return cnt;
588 }
589
590 static double
591 TFIDFSml(SimpleArray *a, SimpleArray *b)
592 {
593         int                             cmp;
594         Datum                   *aptr = a->elems,
595                                         *bptr = b->elems;
596         ProcTypeInfo    info = a->info;
597         double                  res = 0.0;
598         double                  suma = 0.0, sumb = 0.0;
599
600         Assert( a->info->typid == b->info->typid );
601         Assert( a->df );
602         Assert( b->df );
603
604         getFmgrInfoCmp(info);
605
606         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
607         {
608                 cmp = cmpArrayElem(aptr, bptr, info);
609                 if ( cmp < 0 )
610                 {
611                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
612                         aptr++;
613                 }
614                 else if ( cmp > 0 )
615                 {
616                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
617                         bptr++;
618                 }
619                 else
620                 {
621                         res += a->df[ aptr - a->elems ] * b->df[ bptr - b->elems ];
622                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
623                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
624                         aptr++;
625                         bptr++;
626                 }
627         }
628
629         /*
630          * Compute last elements
631          */
632         while( aptr - a->elems < a->nelems )
633         {
634                 suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
635                 aptr++;
636         }
637
638         while( bptr - b->elems < b->nelems )
639         {
640                 sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
641                 bptr++;
642         }
643
644         if ( suma > 0.0 && sumb > 0.0 )
645                 res = res / sqrt( suma * sumb );
646         else
647                 res = 0.0;
648
649         return res;
650 }
651
652
653 PG_FUNCTION_INFO_V1(arraysml);
654 Datum   arraysml(PG_FUNCTION_ARGS);
655 Datum
656 arraysml(PG_FUNCTION_ARGS)
657 {
658         ArrayType               *a, *b;
659         SimpleArray             *sa, *sb;
660
661         fcinfo->flinfo->fn_extra = SearchArrayCache(
662                                                         fcinfo->flinfo->fn_extra,
663                                                         fcinfo->flinfo->fn_mcxt,
664                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
665         fcinfo->flinfo->fn_extra = SearchArrayCache(
666                                                         fcinfo->flinfo->fn_extra,
667                                                         fcinfo->flinfo->fn_mcxt,
668                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
669
670         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
671                 elog(ERROR,"Arguments array are not the same type!");
672
673         if (ARRISVOID(a) || ARRISVOID(b))
674                  PG_RETURN_FLOAT4(0.0);
675
676         switch(getSmlType())
677         {
678                 case ST_TFIDF:
679                         PG_RETURN_FLOAT4( TFIDFSml(sa, sb) );
680                         break;
681                 case ST_COSINE:
682                         {
683                                 int                             cnt;
684                                 double                  power;
685
686                                 power = ((double)(sa->nelems)) * ((double)(sb->nelems));
687                                 cnt = numOfIntersect(sa, sb);
688
689                                 PG_RETURN_FLOAT4(  ((double)cnt) / sqrt( power ) );
690                         }
691                         break;
692                 case ST_OVERLAP:
693                         {
694                                 float4 res = (float4)numOfIntersect(sa, sb);
695
696                                 PG_RETURN_FLOAT4(res);
697                         }
698                         break;
699                 default:
700                         elog(ERROR,"Unsupported formula type of similarity");
701         }
702
703         PG_RETURN_FLOAT4(0.0); /* keep compiler quiet */
704 }
705
706 PG_FUNCTION_INFO_V1(arraysmlw);
707 Datum   arraysmlw(PG_FUNCTION_ARGS);
708 Datum
709 arraysmlw(PG_FUNCTION_ARGS)
710 {
711         ArrayType               *a, *b;
712         SimpleArray             *sa, *sb;
713         bool                    useIntersect = PG_GETARG_BOOL(2);
714         double                  numerator = 0.0;
715         double                  denominatorA = 0.0,
716                                         denominatorB = 0.0,
717                                         tmpA, tmpB;
718         int                             cmp;
719         ProcTypeInfo    info;
720         int                             ai = 0, bi = 0;
721
722         fcinfo->flinfo->fn_extra = SearchArrayCache(
723                                                         fcinfo->flinfo->fn_extra,
724                                                         fcinfo->flinfo->fn_mcxt,
725                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
726         fcinfo->flinfo->fn_extra = SearchArrayCache(
727                                                         fcinfo->flinfo->fn_extra,
728                                                         fcinfo->flinfo->fn_mcxt,
729                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
730
731         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
732                 elog(ERROR,"Arguments array are not the same type!");
733
734         if (ARRISVOID(a) || ARRISVOID(b))
735                  PG_RETURN_FLOAT4(0.0);
736
737         info = sa->info;
738         if (info->tupDesc == NULL)
739                 elog(ERROR, "Only weigthed (composite) types should be used");
740         getFmgrInfoCmp(info);
741
742         while(ai < sa->nelems && bi < sb->nelems)
743         {
744                 Datum   ad = deconstructCompositeType(info, sa->elems[ai], &tmpA),
745                                 bd = deconstructCompositeType(info, sb->elems[bi], &tmpB);
746
747                 cmp = DatumGetInt32(FCall2(&info->cmpFunc, ad, bd));
748
749                 if ( cmp < 0 ) {
750                         if (useIntersect == false)
751                                 denominatorA += tmpA * tmpA;
752                         ai++;
753                 } else if ( cmp > 0 ) {
754                         if (useIntersect == false)
755                                 denominatorB += tmpB * tmpB;
756                         bi++;
757                 } else {
758                         denominatorA += tmpA * tmpA;
759                         denominatorB += tmpB * tmpB;
760                         numerator += tmpA * tmpB;
761                         ai++;
762                         bi++;
763                 }
764         }
765
766         if (useIntersect == false) {
767                 while(ai < sa->nelems) {
768                         deconstructCompositeType(info, sa->elems[ai], &tmpA);
769                         denominatorA += tmpA * tmpA;
770                         ai++;
771                 }
772
773                 while(bi < sb->nelems) {
774                         deconstructCompositeType(info, sb->elems[bi], &tmpB);
775                         denominatorB += tmpB * tmpB;
776                         bi++;
777                 }
778         }
779
780         if (numerator != 0.0) {
781                 numerator = numerator / sqrt( denominatorA * denominatorB );
782         }
783
784         PG_RETURN_FLOAT4(numerator);
785 }
786
787 PG_FUNCTION_INFO_V1(arraysml_op);
788 Datum   arraysml_op(PG_FUNCTION_ARGS);
789 Datum
790 arraysml_op(PG_FUNCTION_ARGS)
791 {
792         ArrayType               *a, *b;
793         SimpleArray             *sa, *sb;
794         double                  power = 0.0;
795
796         fcinfo->flinfo->fn_extra = SearchArrayCache(
797                                                         fcinfo->flinfo->fn_extra,
798                                                         fcinfo->flinfo->fn_mcxt,
799                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
800         fcinfo->flinfo->fn_extra = SearchArrayCache(
801                                                         fcinfo->flinfo->fn_extra,
802                                                         fcinfo->flinfo->fn_mcxt,
803                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
804
805         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
806                 elog(ERROR,"Arguments array are not the same type!");
807
808         if (ARRISVOID(a) || ARRISVOID(b))
809                  PG_RETURN_BOOL(false);
810
811         switch(getSmlType())
812         {
813                 case ST_TFIDF:
814                         power = TFIDFSml(sa, sb);
815                         break;
816                 case ST_COSINE:
817                         {
818                                 int                             cnt;
819
820                                 power = sqrt( ((double)(sa->nelems)) * ((double)(sb->nelems)) );
821
822                                 if (  ((double)Min(sa->nelems, sb->nelems)) / power < GetSmlarLimit()  )
823                                         PG_RETURN_BOOL(false);
824
825                                 cnt = numOfIntersect(sa, sb);
826                                 power = ((double)cnt) / power;
827                         }
828                         break;
829                 case ST_OVERLAP:
830                         power = (double)numOfIntersect(sa, sb);
831                         break;
832                 default:
833                         elog(ERROR,"Unsupported formula type of similarity");
834         }
835
836         PG_RETURN_BOOL(power >= GetSmlarLimit());
837 }
838
839 #define QBSIZE          8192
840 static char cachedFormula[QBSIZE];
841 static int      cachedLen  = 0;
842 static void     *cachedPlan = NULL;
843
844 PG_FUNCTION_INFO_V1(arraysml_func);
845 Datum   arraysml_func(PG_FUNCTION_ARGS);
846 Datum
847 arraysml_func(PG_FUNCTION_ARGS)
848 {
849         ArrayType               *a, *b;
850         SimpleArray             *sa, *sb;
851         int                             cnt;
852         float4                  result = -1.0;
853         Oid                             arg[] = {INT4OID, INT4OID, INT4OID};
854         Datum                   pars[3];
855         bool                    isnull;
856         void                    *plan;
857         int                             stat;
858         text                    *formula = PG_GETARG_TEXT_P(2);
859
860         fcinfo->flinfo->fn_extra = SearchArrayCache(
861                                                         fcinfo->flinfo->fn_extra,
862                                                         fcinfo->flinfo->fn_mcxt,
863                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
864         fcinfo->flinfo->fn_extra = SearchArrayCache(
865                                                         fcinfo->flinfo->fn_extra,
866                                                         fcinfo->flinfo->fn_mcxt,
867                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
868
869         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
870                 elog(ERROR,"Arguments array are not the same type!");
871
872         if (ARRISVOID(a) || ARRISVOID(b))
873                  PG_RETURN_BOOL(false);
874
875         cnt = numOfIntersect(sa, sb);
876
877         if ( VARSIZE(formula) - VARHDRSZ > QBSIZE - 1024 )
878                 elog(ERROR,"Formula is too long");
879
880         SPI_connect();
881
882         if ( cachedPlan == NULL || cachedLen != VARSIZE(formula) - VARHDRSZ ||
883                                 memcmp( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ) != 0 )
884         {
885                 char                    *ptr, buf[QBSIZE];
886
887                 *cachedFormula = '\0';
888                 if ( cachedPlan )
889                         SPI_freeplan(cachedPlan);
890                 cachedPlan = NULL;
891                 cachedLen = 0;
892
893                 ptr = stpcpy( buf, "SELECT (" );
894                 memcpy( ptr, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
895                 ptr += VARSIZE(formula) - VARHDRSZ;
896                 ptr = stpcpy( ptr, ")::float4 FROM");
897                 ptr = stpcpy( ptr, " (SELECT $1 ::float8 AS i, $2 ::float8 AS a, $3 ::float8 AS b) AS N;");
898                 *ptr = '\0';
899
900                 plan = SPI_prepare(buf, 3, arg);
901                 if (!plan)
902                         elog(ERROR, "SPI_prepare() failed");
903
904                 cachedPlan = SPI_saveplan(plan);
905                 if (!cachedPlan)
906                         elog(ERROR, "SPI_saveplan() failed");
907
908                 SPI_freeplan(plan);
909                 cachedLen = VARSIZE(formula) - VARHDRSZ;
910                 memcpy( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
911         }
912
913         plan = cachedPlan;
914
915
916         pars[0] = Int32GetDatum( cnt );
917         pars[1] = Int32GetDatum( sa->nelems );
918         pars[2] = Int32GetDatum( sb->nelems );
919
920         stat = SPI_execute_plan(plan, pars, NULL, true, 3);
921         if (stat < 0)
922                 elog(ERROR, "SPI_execute_plan() returns %d", stat);
923
924         if ( SPI_processed > 0)
925                 result = DatumGetFloat4(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
926
927         SPI_finish();
928
929         PG_RETURN_FLOAT4(result);
930 }
931
932 PG_FUNCTION_INFO_V1(array_unique);
933 Datum   array_unique(PG_FUNCTION_ARGS);
934 Datum
935 array_unique(PG_FUNCTION_ARGS)
936 {
937         ArrayType               *a = PG_GETARG_ARRAYTYPE_P(0);
938         ArrayType               *res;
939         SimpleArray             *sa;
940
941         sa = Array2SimpleArrayU(NULL, a, NULL);
942
943         res = construct_array(  sa->elems, 
944                                                         sa->nelems,
945                                                         sa->info->typid,
946                                                         sa->info->typlen,
947                                                         sa->info->typbyval,
948                                                         sa->info->typalign);
949
950         pfree(sa->elems);
951         pfree(sa);
952         PG_FREE_IF_COPY(a, 0);
953
954         PG_RETURN_ARRAYTYPE_P(res);
955 }
956
957 PG_FUNCTION_INFO_V1(inarray);
958 Datum   inarray(PG_FUNCTION_ARGS);
959 Datum
960 inarray(PG_FUNCTION_ARGS)
961 {
962         ArrayType               *a;
963         SimpleArray             *sa;
964         Datum                   query = PG_GETARG_DATUM(1);
965         Oid                             queryTypeOid;
966         Datum                   *StopLow,
967                                         *StopHigh,
968                                         *StopMiddle;
969         int                             cmp;
970
971         fcinfo->flinfo->fn_extra = SearchArrayCache(
972                                                         fcinfo->flinfo->fn_extra,
973                                                         fcinfo->flinfo->fn_mcxt,
974                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
975
976         queryTypeOid = get_fn_expr_argtype(fcinfo->flinfo, 1);
977
978         if ( queryTypeOid == InvalidOid )
979                 elog(ERROR,"inarray: could not determine actual argument type");
980
981         if ( queryTypeOid != sa->info->typid )
982                 elog(ERROR,"inarray: Type of array's element and type of argument are not the same");
983
984         getFmgrInfoCmp(sa->info);
985         StopLow = sa->elems;
986         StopHigh = sa->elems + sa->nelems;
987
988         while (StopLow < StopHigh)
989         {
990                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
991                 cmp = cmpArrayElem(StopMiddle, &query, sa->info);
992
993                 if ( cmp == 0 )
994                 {
995                         /* found */
996                         if ( PG_NARGS() >= 3 )
997                                 PG_RETURN_DATUM(PG_GETARG_DATUM(2));
998                         PG_RETURN_FLOAT4(1.0);
999                 }
1000                 else if (cmp < 0)
1001                         StopLow = StopMiddle + 1;
1002                 else
1003                         StopHigh = StopMiddle;
1004         }
1005
1006         if ( PG_NARGS() >= 4 )
1007                 PG_RETURN_DATUM(PG_GETARG_DATUM(3));
1008         PG_RETURN_FLOAT4(0.0);
1009 }