support 9.5
[smlar.git] / smlar.c
1 #include "smlar.h"
2
3 #include "fmgr.h"
4 #include "access/genam.h"
5 #include "access/heapam.h"
6 #include "access/htup_details.h"
7 #include "access/nbtree.h"
8 #include "catalog/indexing.h"
9 #include "catalog/pg_am.h"
10 #include "catalog/pg_amproc.h"
11 #include "catalog/pg_cast.h"
12 #include "catalog/pg_opclass.h"
13 #include "catalog/pg_type.h"
14 #include "executor/spi.h"
15 #include "utils/catcache.h"
16 #include "utils/fmgroids.h"
17 #include "utils/lsyscache.h"
18 #include "utils/memutils.h"
19 #include "utils/tqual.h"
20 #include "utils/syscache.h"
21 #include "utils/typcache.h"
22
23 PG_MODULE_MAGIC;
24
/*
 * Snapshot argument handed to systable_beginscan() by the catalog scans in
 * this file: SnapshotNow on servers older than 9.4, NULL on 9.4 and later.
 */
#if PG_VERSION_NUM < 90400
#define SNAPSHOT SnapshotNow
#else
#define SNAPSHOT NULL
#endif
30
31 static Oid
32 getDefaultOpclass(Oid amoid, Oid typid) 
33 {
34         ScanKeyData     skey;
35         SysScanDesc     scan;
36         HeapTuple       tuple;
37         Relation        heapRel;
38         Oid                     opclassOid = InvalidOid;
39
40         heapRel = heap_open(OperatorClassRelationId, AccessShareLock);  
41
42         ScanKeyInit(&skey,
43                                 Anum_pg_opclass_opcmethod,
44                                 BTEqualStrategyNumber,  F_OIDEQ,
45                                 ObjectIdGetDatum(amoid));
46
47         scan = systable_beginscan(heapRel, 
48                                                                 OpclassAmNameNspIndexId, true, 
49                                                                 SNAPSHOT, 1, &skey);
50
51         while (HeapTupleIsValid((tuple = systable_getnext(scan))))
52         {
53                 Form_pg_opclass opclass = (Form_pg_opclass)GETSTRUCT(tuple);
54
55                 if ( opclass->opcintype == typid && opclass->opcdefault )
56                 {
57                         if ( OidIsValid(opclassOid) )
58                                 elog(ERROR, "Ambiguous opclass for type %u (access method %u)", typid, amoid); 
59                         opclassOid = HeapTupleGetOid(tuple);
60                 }
61         }
62
63         systable_endscan(scan);
64         heap_close(heapRel, AccessShareLock);
65
66         return opclassOid;
67 }
68
69 static Oid
70 getAMProc(Oid amoid, Oid typid) 
71 {
72         Oid             opclassOid = getDefaultOpclass(amoid, typid);
73         Oid             procOid = InvalidOid;
74         Oid             opfamilyOid;
75         ScanKeyData     skey[4];
76         SysScanDesc     scan;
77         HeapTuple       tuple;
78         Relation        heapRel;
79
80         if ( !OidIsValid(opclassOid) )
81         {
82                 typid = getBaseType(typid);
83                 opclassOid = getDefaultOpclass(amoid, typid);
84         }
85
86         if ( !OidIsValid(opclassOid) )
87         {
88                 CatCList        *catlist;
89                 int                     i;
90
91                 /*
92                  * Search binary-coercible type
93                  */
94                 catlist = SearchSysCacheList(CASTSOURCETARGET, 1,
95                                                                                 ObjectIdGetDatum(typid),
96                                                                                 0, 0, 0);
97
98                 for (i = 0; i < catlist->n_members; i++)
99                 {
100                         HeapTuple       tuple = &catlist->members[i]->tuple;
101                         Form_pg_cast    castForm = (Form_pg_cast)GETSTRUCT(tuple);
102
103                         if ( castForm->castfunc == InvalidOid && castForm->castcontext == COERCION_CODE_IMPLICIT )
104                         {
105                                 typid = castForm->casttarget;
106                                 opclassOid = getDefaultOpclass(amoid, typid);
107                                 if( OidIsValid(opclassOid) )
108                                         break;
109                         }
110                 }
111
112                 ReleaseSysCacheList(catlist);
113         }
114         
115         if ( !OidIsValid(opclassOid) )
116                 return InvalidOid;
117
118         opfamilyOid = get_opclass_family(opclassOid);
119
120         heapRel = heap_open(AccessMethodProcedureRelationId, AccessShareLock);
121         ScanKeyInit(&skey[0],
122                                 Anum_pg_amproc_amprocfamily,
123                                 BTEqualStrategyNumber, F_OIDEQ,
124                                 ObjectIdGetDatum(opfamilyOid));
125         ScanKeyInit(&skey[1],
126                                 Anum_pg_amproc_amproclefttype,
127                                 BTEqualStrategyNumber, F_OIDEQ,
128                                 ObjectIdGetDatum(typid));
129         ScanKeyInit(&skey[2],
130                                 Anum_pg_amproc_amprocrighttype,
131                                 BTEqualStrategyNumber, F_OIDEQ,
132                                 ObjectIdGetDatum(typid));
133 #if PG_VERSION_NUM >= 90200
134         ScanKeyInit(&skey[3],
135                                 Anum_pg_amproc_amprocnum,
136                                 BTEqualStrategyNumber, F_OIDEQ,
137                                 Int32GetDatum(BTORDER_PROC));
138 #endif
139
140         scan = systable_beginscan(heapRel, AccessMethodProcedureIndexId, true,
141                                                                 SNAPSHOT, 
142 #if PG_VERSION_NUM >= 90200
143                                                                 4,
144 #else
145                                                                 3,
146 #endif
147                                                                 skey);
148         while (HeapTupleIsValid(tuple = systable_getnext(scan)))
149         {
150                 Form_pg_amproc amprocform = (Form_pg_amproc) GETSTRUCT(tuple);
151
152                 switch(amoid)
153                 {
154                         case BTREE_AM_OID:
155                         case HASH_AM_OID:
156                                 if ( OidIsValid(procOid) )
157                                         elog(ERROR,"Ambiguous support function for type %u (opclass %u)", typid, opfamilyOid);
158                                 procOid = amprocform->amproc;
159                                 break;
160                         default:
161                                 elog(ERROR,"Unsupported access method");
162                 }
163         }
164
165         systable_endscan(scan);
166         heap_close(heapRel, AccessShareLock);
167
168         return procOid;
169 }
170
/*
 * File-level cache of per-type descriptors, filled by findProcs().
 * cacheProcs is a malloc'ed array holding nCacheProcs entries; findProcs()
 * keeps it sorted by typid so it can be binary-searched.
 */
static ProcTypeInfo *cacheProcs = NULL;
static int nCacheProcs = 0;
173
174 static ProcTypeInfo
175 fillProcs(Oid typid)
176 {
177         ProcTypeInfo    info = malloc(sizeof(ProcTypeInfoData));
178
179         if (!info)
180                 elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfoData));
181
182         info->typid = typid;
183         info->typtype = get_typtype(typid);
184
185         if (info->typtype == 'c')
186         {
187                 /* composite type */
188                 TupleDesc               tupdesc;
189                 MemoryContext   oldcontext;
190
191                 tupdesc = lookup_rowtype_tupdesc(typid, -1);
192
193                 if (tupdesc->natts != 2)
194                         elog(ERROR,"Composite type has wrong number of fields");
195                 if (tupdesc->attrs[1]->atttypid != FLOAT4OID)
196                         elog(ERROR,"Second field of composite type is not float4");
197
198                 oldcontext = MemoryContextSwitchTo(TopMemoryContext);
199                 info->tupDesc = CreateTupleDescCopyConstr(tupdesc);
200                 MemoryContextSwitchTo(oldcontext);
201
202                 ReleaseTupleDesc(tupdesc);
203
204                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, info->tupDesc->attrs[0]->atttypid);
205                 info->hashFuncOid = getAMProc(HASH_AM_OID, info->tupDesc->attrs[0]->atttypid);
206         } 
207         else 
208         {
209                 info->tupDesc = NULL;
210
211                 /* plain type */
212                 info->cmpFuncOid = getAMProc(BTREE_AM_OID, typid);
213                 info->hashFuncOid = getAMProc(HASH_AM_OID, typid);
214         }
215
216         get_typlenbyvalalign(typid, &info->typlen, &info->typbyval, &info->typalign);
217         info->hashFuncInited = info->cmpFuncInited = false;
218
219
220         return info;
221 }
222
223 void
224 getFmgrInfoCmp(ProcTypeInfo info)
225 {
226         if ( info->cmpFuncInited == false )
227         {
228                 if ( !OidIsValid(info->cmpFuncOid) )
229                         elog(ERROR, "Could not find cmp function for type %u", info->typid);
230
231                 fmgr_info_cxt( info->cmpFuncOid, &info->cmpFunc, TopMemoryContext );
232                 info->cmpFuncInited = true;
233         }
234 }
235
236 void
237 getFmgrInfoHash(ProcTypeInfo info)
238 {
239         if ( info->hashFuncInited == false )
240         {
241                 if ( !OidIsValid(info->hashFuncOid) )
242                         elog(ERROR, "Could not find hash function for type %u", info->typid);
243
244                 fmgr_info_cxt( info->hashFuncOid, &info->hashFunc, TopMemoryContext );
245                 info->hashFuncInited = true;
246         }
247 }
248
249 static int
250 cmpProcTypeInfo(const void *a, const void *b)
251 {
252         ProcTypeInfo av = *(ProcTypeInfo*)a;
253         ProcTypeInfo bv = *(ProcTypeInfo*)b;
254
255         Assert( av->typid != bv->typid );
256
257         return ( av->typid > bv->typid ) ? 1 : -1; 
258 }
259
260 ProcTypeInfo
261 findProcs(Oid typid)
262 {
263         ProcTypeInfo    info = NULL;
264
265         if ( nCacheProcs == 1 )
266         {
267                 if ( cacheProcs[0]->typid == typid ) 
268                 {
269                         /*cacheProcs[0]->hashFuncInited = cacheProcs[0]->cmpFuncInited = false;*/
270                         return cacheProcs[0];
271                 }
272         }
273         else if ( nCacheProcs > 1 )
274         {
275                 ProcTypeInfo    *StopMiddle;
276                 ProcTypeInfo    *StopLow = cacheProcs, 
277                                                 *StopHigh = cacheProcs + nCacheProcs;
278
279                 while (StopLow < StopHigh) {
280                         StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
281                         info = *StopMiddle;
282
283                         if ( info->typid == typid )
284                         {
285                                 /* info->hashFuncInited = info->cmpFuncInited = false; */
286                                 return info;
287                         }
288                         else if ( info->typid < typid )
289                                 StopLow = StopMiddle + 1;
290                         else
291                                 StopHigh = StopMiddle;
292                 }
293
294                 /* not found */
295         } 
296
297         info = fillProcs(typid);
298         if ( nCacheProcs == 0 ) 
299         {
300                 cacheProcs = malloc(sizeof(ProcTypeInfo));
301
302                 if (!cacheProcs)
303                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo));
304                 else
305                 {
306                         nCacheProcs = 1;
307                         cacheProcs[0] = info;
308                 }
309         }
310         else
311         {
312                 ProcTypeInfo    *cacheProcsTmp = realloc(cacheProcs, (nCacheProcs+1) * sizeof(ProcTypeInfo));
313
314                 if (!cacheProcsTmp)
315                         elog(ERROR, "Can't allocate %u memory", (uint32)sizeof(ProcTypeInfo) * (nCacheProcs+1));
316                 else
317                 {
318                         cacheProcs = cacheProcsTmp;
319                         cacheProcs[ nCacheProcs ] = info;
320                         nCacheProcs++;
321                         qsort(cacheProcs, nCacheProcs, sizeof(ProcTypeInfo), cmpProcTypeInfo);
322                 }
323         }
324
325         /* info->hashFuncInited = info->cmpFuncInited = false; */
326
327         return info;
328 }
329
330 /*
331  * WARNING. Array2SimpleArray* doesn't copy Datum!
332  */
333 SimpleArray * 
334 Array2SimpleArray(ProcTypeInfo info, ArrayType *a)
335 {
336         SimpleArray     *s = palloc(sizeof(SimpleArray));
337
338         CHECKARRVALID(a);
339
340         if ( info == NULL )
341                 info = findProcs(ARR_ELEMTYPE(a));
342
343         s->info = info;
344         s->df = NULL;
345         s->hash = NULL;
346
347         deconstruct_array(a, info->typid,
348                                                 info->typlen, info->typbyval, info->typalign,
349                                                 &s->elems, NULL, &s->nelems);
350
351         return s;
352 }
353
354 static Datum
355 deconstructCompositeType(ProcTypeInfo info, Datum in, double *weight)
356 {
357         HeapTupleHeader rec = DatumGetHeapTupleHeader(in);
358         HeapTupleData   tuple;
359         Datum                   values[2];
360         bool                    nulls[2];
361
362         /* Build a temporary HeapTuple control structure */
363         tuple.t_len = HeapTupleHeaderGetDatumLength(rec);
364         ItemPointerSetInvalid(&(tuple.t_self));
365         tuple.t_tableOid = InvalidOid;
366         tuple.t_data = rec;
367
368         heap_deform_tuple(&tuple, info->tupDesc, values, nulls);
369         if (nulls[0] || nulls[1])
370                 elog(ERROR, "Both fields in composite type could not be NULL");
371
372         if (weight)
373                 *weight = DatumGetFloat4(values[1]);
374         return values[0];
375 }
376
377 static int
378 cmpArrayElem(const void *a, const void *b, void *arg)
379 {
380         ProcTypeInfo    info = (ProcTypeInfo)arg;
381
382         if (info->tupDesc) 
383                 /* composite type */
384                 return DatumGetInt32( FCall2( &info->cmpFunc,
385                                                                                         deconstructCompositeType(info, *(Datum*)a, NULL),
386                                                                                         deconstructCompositeType(info, *(Datum*)b, NULL) ) );
387
388         return DatumGetInt32( FCall2( &info->cmpFunc,
389                                                         *(Datum*)a, *(Datum*)b ) );
390 }
391
392 SimpleArray *
393 Array2SimpleArrayS(ProcTypeInfo info, ArrayType *a)
394 {
395         SimpleArray     *s = Array2SimpleArray(info, a);
396
397         if ( s->nelems > 1 )
398         {
399                 getFmgrInfoCmp(s->info);
400
401                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElem, s->info);
402         }
403
404         return s;
405 }
406
/*
 * Argument block handed through qsort_arg() to cmpArrayElemArg().
 */
typedef struct cmpArrayElemData {
	ProcTypeInfo	info;			/* descriptor supplying the comparison function */
	bool			hasDuplicate;	/* set by the comparator when two elements compare equal */

} cmpArrayElemData;
412
413 static int
414 cmpArrayElemArg(const void *a, const void *b, void *arg)
415 {
416     cmpArrayElemData    *data = (cmpArrayElemData*)arg;
417         int res;
418
419         if (data->info->tupDesc)
420                 res =  DatumGetInt32( FCall2( &data->info->cmpFunc,
421                                                                                         deconstructCompositeType(data->info, *(Datum*)a, NULL),
422                                                                                         deconstructCompositeType(data->info, *(Datum*)b, NULL) ) );
423         else
424                 res = DatumGetInt32( FCall2( &data->info->cmpFunc,
425                                                                 *(Datum*)a, *(Datum*)b ) );
426
427         if ( res == 0 )
428                 data->hasDuplicate = true;
429
430         return res;
431 }
432
433 /*
434  * Uniquefy array and calculate TF. Although 
435  * result doesn't depend on normalization, we
436  * normalize TF by length array to have possiblity
437  * to limit estimation for index support.
438  *
439  * Cache signals of needing of TF caclulation
440  */
441
442 SimpleArray *
443 Array2SimpleArrayU(ProcTypeInfo info, ArrayType *a, void *cache)
444 {
445         SimpleArray     *s = Array2SimpleArray(info, a);
446         StatElem        *stat = NULL;
447
448         if ( s->nelems > 0 && cache )
449         {
450                 s->df = palloc(sizeof(double) * s->nelems);
451                 s->df[0] = 1.0; /* init */
452         }
453
454         if ( s->nelems > 1 )
455         {
456                 cmpArrayElemData        data;
457                 int                             i;
458
459                 getFmgrInfoCmp(s->info);
460                 data.info = s->info;
461                 data.hasDuplicate = false;
462
463                 qsort_arg(s->elems, s->nelems, sizeof(Datum), cmpArrayElemArg, &data);
464
465                 if ( data.hasDuplicate )
466                 {
467                         Datum   *tmp,
468                                         *dr,
469                                         *data;
470                         int             num = s->nelems,
471                                         cmp;
472
473                         data = tmp = dr = s->elems;
474
475                         while (tmp - data < num)
476                         {
477                                 cmp = (tmp == dr) ? 0 : cmpArrayElem(tmp, dr, s->info);
478                                 if ( cmp != 0 )
479                                 {
480                                         *(++dr) = *tmp++;
481                                         if ( cache ) 
482                                                 s->df[ dr - data ] = 1.0;
483                                 }
484                                 else
485                                 {
486                                         if ( cache )
487                                                 s->df[ dr - data ] += 1.0;
488                                         tmp++;
489                                 }
490                         }
491
492                         s->nelems = dr + 1 - s->elems;
493
494                         if ( cache )
495                         {
496                                 int tfm = getTFMethod();
497
498                                 for(i=0;i<s->nelems;i++)
499                                 {
500                                         stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
501                                         if ( stat )
502                                         {
503                                                 switch(tfm)
504                                                 {
505                                                         case TF_LOG:
506                                                                 s->df[i] = (1.0 + log( s->df[i] ));
507                                                         case TF_N:
508                                                                 s->df[i] *= stat->idf;
509                                                                 break;
510                                                         case TF_CONST:
511                                                                 s->df[i] = stat->idf;
512                                                                 break;
513                                                         default:
514                                                                 elog(ERROR,"Unknown TF method: %d", tfm);
515                                                 }       
516                                         }
517                                         else
518                                         {
519                                                 s->df[i] = 0.0; /* unknown word */
520                                         }
521                                 }       
522                         }
523                 }
524                 else if ( cache )
525                 {
526                         for(i=0;i<s->nelems;i++)
527                         {
528                                 stat = fingArrayStat(cache, s->info->typid, s->elems[i], stat);
529                                 if ( stat )
530                                         s->df[i] = stat->idf;
531                                 else
532                                         s->df[i] = 0.0;
533                         }
534                 }
535         }
536         else if (s->nelems > 0 && cache) 
537         {
538                 stat = fingArrayStat(cache, s->info->typid, s->elems[0], stat);
539                 if ( stat )
540                         s->df[0] = stat->idf;
541                 else
542                         s->df[0] = 0.0;
543         }
544
545         return s;
546 }
547
548 static int
549 numOfIntersect(SimpleArray *a, SimpleArray *b)
550 {
551         int                             cnt = 0,
552                                         cmp;
553         Datum                   *aptr = a->elems,
554                                         *bptr = b->elems;
555         ProcTypeInfo    info = a->info;
556
557         Assert( a->info->typid == b->info->typid );
558
559         getFmgrInfoCmp(info);
560
561         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
562         {
563                 cmp = cmpArrayElem(aptr, bptr, info);
564                 if ( cmp < 0 )
565                         aptr++;
566                 else if ( cmp > 0 )
567                         bptr++;
568                 else
569                 {
570                         cnt++;
571                         aptr++;
572                         bptr++;
573                 }
574         }
575
576         return cnt;
577 }
578
579 static double
580 TFIDFSml(SimpleArray *a, SimpleArray *b)
581 {
582         int                             cmp;
583         Datum                   *aptr = a->elems,
584                                         *bptr = b->elems;
585         ProcTypeInfo    info = a->info;
586         double                  res = 0.0;
587         double                  suma = 0.0, sumb = 0.0;
588
589         Assert( a->info->typid == b->info->typid );
590         Assert( a->df );
591         Assert( b->df );
592
593         getFmgrInfoCmp(info);
594
595         while( aptr - a->elems < a->nelems && bptr - b->elems < b->nelems )
596         {
597                 cmp = cmpArrayElem(aptr, bptr, info);
598                 if ( cmp < 0 )
599                 {
600                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
601                         aptr++;
602                 }
603                 else if ( cmp > 0 )
604                 {
605                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
606                         bptr++;
607                 }
608                 else
609                 {
610                         res += a->df[ aptr - a->elems ] * b->df[ bptr - b->elems ];
611                         suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
612                         sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
613                         aptr++;
614                         bptr++;
615                 }
616         }
617
618         /*
619          * Compute last elements
620          */
621         while( aptr - a->elems < a->nelems )
622         {
623                 suma += a->df[ aptr - a->elems ] * a->df[ aptr - a->elems ];
624                 aptr++;
625         }
626
627         while( bptr - b->elems < b->nelems )
628         {
629                 sumb += b->df[ bptr - b->elems ] * b->df[ bptr - b->elems ];
630                 bptr++;
631         }
632
633         if ( suma > 0.0 && sumb > 0.0 )
634                 res = res / sqrt( suma * sumb );
635         else
636                 res = 0.0;
637
638         return res;
639 }
640
641
642 PG_FUNCTION_INFO_V1(arraysml);
643 Datum   arraysml(PG_FUNCTION_ARGS);
644 Datum
645 arraysml(PG_FUNCTION_ARGS)
646 {
647         ArrayType               *a, *b;
648         SimpleArray             *sa, *sb;
649         
650         fcinfo->flinfo->fn_extra = SearchArrayCache( 
651                                                         fcinfo->flinfo->fn_extra,
652                                                         fcinfo->flinfo->fn_mcxt,
653                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
654         fcinfo->flinfo->fn_extra = SearchArrayCache( 
655                                                         fcinfo->flinfo->fn_extra,
656                                                         fcinfo->flinfo->fn_mcxt,
657                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
658
659         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
660                 elog(ERROR,"Arguments array are not the same type!");
661
662         if (ARRISVOID(a) || ARRISVOID(b))
663                  PG_RETURN_FLOAT4(0.0);
664
665         switch(getSmlType())
666         {
667                 case    ST_TFIDF:
668                         PG_RETURN_FLOAT4( TFIDFSml(sa, sb) );
669                         break;
670                 case    ST_COSINE:
671                         {
672                                 int                             cnt;
673                                 double                  power;
674
675                                 power = ((double)(sa->nelems)) * ((double)(sb->nelems));
676                                 cnt = numOfIntersect(sa, sb);
677
678                                 PG_RETURN_FLOAT4(  ((double)cnt) / sqrt( power ) );
679                         }
680                         break;
681                 case    ST_OVERLAP:
682                         {
683                                 float4 res = (float4)numOfIntersect(sa, sb);
684
685                                 PG_RETURN_FLOAT4(res);
686                         }
687                         break;
688                 default:
689                         elog(ERROR,"Unsupported formula type of similarity");
690         }
691
692         PG_RETURN_FLOAT4(0.0); /* keep compiler quiet */ 
693 }
694
695 PG_FUNCTION_INFO_V1(arraysmlw);
696 Datum   arraysmlw(PG_FUNCTION_ARGS);
697 Datum
698 arraysmlw(PG_FUNCTION_ARGS)
699 {
700         ArrayType               *a, *b;
701         SimpleArray             *sa, *sb;
702         bool                    useIntersect = PG_GETARG_BOOL(2);
703         double                  numerator = 0.0;
704         double                  denominatorA = 0.0,
705                                         denominatorB = 0.0,
706                                         tmpA, tmpB;
707         int                             cmp;
708         ProcTypeInfo    info;
709         int                             ai = 0, bi = 0;
710
711         
712         fcinfo->flinfo->fn_extra = SearchArrayCache( 
713                                                         fcinfo->flinfo->fn_extra,
714                                                         fcinfo->flinfo->fn_mcxt,
715                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
716         fcinfo->flinfo->fn_extra = SearchArrayCache( 
717                                                         fcinfo->flinfo->fn_extra,
718                                                         fcinfo->flinfo->fn_mcxt,
719                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
720
721         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
722                 elog(ERROR,"Arguments array are not the same type!");
723
724         if (ARRISVOID(a) || ARRISVOID(b))
725                  PG_RETURN_FLOAT4(0.0);
726
727         info = sa->info;
728         if (info->tupDesc == NULL)
729                 elog(ERROR, "Only weigthed (composite) types should be used");
730         getFmgrInfoCmp(info);
731
732         while(ai < sa->nelems && bi < sb->nelems)
733         {
734                 Datum   ad = deconstructCompositeType(info, sa->elems[ai], &tmpA),
735                                 bd = deconstructCompositeType(info, sb->elems[bi], &tmpB);
736
737                 cmp = DatumGetInt32(FCall2(&info->cmpFunc, ad, bd));
738
739                 if ( cmp < 0 ) {
740                         if (useIntersect == false)
741                                 denominatorA += tmpA * tmpA;
742                         ai++;
743                 } else if ( cmp > 0 ) {
744                         if (useIntersect == false)
745                                 denominatorB += tmpB * tmpB;
746                         bi++;
747                 } else {
748                         denominatorA += tmpA * tmpA;
749                         denominatorB += tmpB * tmpB;
750                         numerator += tmpA * tmpB;
751                         ai++;
752                         bi++;
753                 }
754         }
755
756         if (useIntersect == false) {
757                 while(ai < sa->nelems) {
758                         deconstructCompositeType(info, sa->elems[ai], &tmpA);
759                         denominatorA += tmpA * tmpA;
760                         ai++;
761                 }
762
763                 while(bi < sb->nelems) {
764                         deconstructCompositeType(info, sb->elems[bi], &tmpB);
765                         denominatorB += tmpB * tmpB;
766                         bi++;
767                 }
768         }
769
770         if (numerator != 0.0) {
771                 numerator = numerator / sqrt( denominatorA * denominatorB );
772         }
773
774         PG_RETURN_FLOAT4(numerator);
775 }
776
777 PG_FUNCTION_INFO_V1(arraysml_op);
778 Datum   arraysml_op(PG_FUNCTION_ARGS);
779 Datum
780 arraysml_op(PG_FUNCTION_ARGS)
781 {
782         ArrayType               *a, *b;
783         SimpleArray             *sa, *sb;
784         double                  power = 0.0;
785         
786         fcinfo->flinfo->fn_extra = SearchArrayCache( 
787                                                         fcinfo->flinfo->fn_extra,
788                                                         fcinfo->flinfo->fn_mcxt,
789                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
790         fcinfo->flinfo->fn_extra = SearchArrayCache( 
791                                                         fcinfo->flinfo->fn_extra,
792                                                         fcinfo->flinfo->fn_mcxt,
793                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
794
795         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
796                 elog(ERROR,"Arguments array are not the same type!");
797
798         if (ARRISVOID(a) || ARRISVOID(b))
799                  PG_RETURN_BOOL(false);
800
801         switch(getSmlType())
802         {
803                 case    ST_TFIDF:
804                         power = TFIDFSml(sa, sb);
805                         break;
806                 case    ST_COSINE:
807                         {
808                                 int                             cnt;
809
810                                 power = sqrt( ((double)(sa->nelems)) * ((double)(sb->nelems)) );
811
812                                 if (  ((double)Min(sa->nelems, sb->nelems)) / power < GetSmlarLimit()  )
813                                         PG_RETURN_BOOL(false);
814
815                                 cnt = numOfIntersect(sa, sb);
816                                 power = ((double)cnt) / power;
817                         }
818                         break;
819                 case    ST_OVERLAP:
820                         power = (double)numOfIntersect(sa, sb);
821                         break;
822                 default:
823                         elog(ERROR,"Unsupported formula type of similarity");
824         }
825
826         PG_RETURN_BOOL(power >= GetSmlarLimit());
827 }
828
829 #define QBSIZE          8192
830 static char cachedFormula[QBSIZE];
831 static int      cachedLen  = 0;
832 static void     *cachedPlan = NULL;
833
834 PG_FUNCTION_INFO_V1(arraysml_func);
835 Datum   arraysml_func(PG_FUNCTION_ARGS);
836 Datum
837 arraysml_func(PG_FUNCTION_ARGS)
838 {
839         ArrayType               *a, *b;
840         SimpleArray             *sa, *sb;
841         int                             cnt;
842         float4                  result = -1.0;
843         Oid                             arg[] = {INT4OID, INT4OID, INT4OID};
844         Datum                   pars[3];
845         bool                    isnull;
846         void                    *plan;
847         int                             stat;
848         text                    *formula = PG_GETARG_TEXT_P(2); 
849         
850         fcinfo->flinfo->fn_extra = SearchArrayCache( 
851                                                         fcinfo->flinfo->fn_extra,
852                                                         fcinfo->flinfo->fn_mcxt,
853                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
854         fcinfo->flinfo->fn_extra = SearchArrayCache( 
855                                                         fcinfo->flinfo->fn_extra,
856                                                         fcinfo->flinfo->fn_mcxt,
857                                                         PG_GETARG_DATUM(1), &b, &sb, NULL);
858
859         if ( ARR_ELEMTYPE(a) != ARR_ELEMTYPE(b) )
860                 elog(ERROR,"Arguments array are not the same type!");
861
862         if (ARRISVOID(a) || ARRISVOID(b))
863                  PG_RETURN_BOOL(false);
864
865         cnt = numOfIntersect(sa, sb);
866
867         if ( VARSIZE(formula) - VARHDRSZ > QBSIZE - 1024 )
868                 elog(ERROR,"Formula is too long");
869
870         SPI_connect();
871
872         
873         if ( cachedPlan == NULL || cachedLen != VARSIZE(formula) - VARHDRSZ ||
874                                 memcmp( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ) != 0 )
875         {
876                 char                    *ptr, buf[QBSIZE];
877
878                 *cachedFormula = '\0';
879                 if ( cachedPlan )
880                         SPI_freeplan(cachedPlan);
881                 cachedPlan = NULL;
882                 cachedLen = 0;
883
884                 ptr = stpcpy( buf, "SELECT (" );
885                 memcpy( ptr, VARDATA(formula), VARSIZE(formula) - VARHDRSZ );
886                 ptr += VARSIZE(formula) - VARHDRSZ;
887                 ptr = stpcpy( ptr, ")::float4 FROM");
888                 ptr = stpcpy( ptr, " (SELECT $1 ::float8 AS i, $2 ::float8 AS a, $3 ::float8 AS b) AS N;");
889                 *ptr = '\0';
890
891                 plan = SPI_prepare(buf, 3, arg);
892                 if (!plan)
893                         elog(ERROR, "SPI_prepare() failed");
894
895                 cachedPlan = SPI_saveplan(plan);
896                 if (!cachedPlan)
897                         elog(ERROR, "SPI_saveplan() failed");
898
899                 SPI_freeplan(plan);
900                 cachedLen = VARSIZE(formula) - VARHDRSZ;
901                 memcpy( cachedFormula, VARDATA(formula), VARSIZE(formula) - VARHDRSZ ); 
902         }
903
904         plan = cachedPlan;
905
906
907         pars[0] = Int32GetDatum( cnt );
908         pars[1] = Int32GetDatum( sa->nelems );
909         pars[2] = Int32GetDatum( sb->nelems );
910
911         stat = SPI_execute_plan(plan, pars, NULL, true, 3);
912         if (stat < 0)
913                 elog(ERROR, "SPI_execute_plan() returns %d", stat);
914
915         if ( SPI_processed > 0)
916                 result = DatumGetFloat4(SPI_getbinval(SPI_tuptable->vals[0], SPI_tuptable->tupdesc, 1, &isnull));
917
918         SPI_finish();
919
920         PG_RETURN_FLOAT4(result);
921 }
922
923 PG_FUNCTION_INFO_V1(array_unique);
924 Datum   array_unique(PG_FUNCTION_ARGS);
925 Datum
926 array_unique(PG_FUNCTION_ARGS)
927 {
928         ArrayType               *a = PG_GETARG_ARRAYTYPE_P(0);
929         ArrayType               *res;
930         SimpleArray             *sa;
931
932         sa = Array2SimpleArrayU(NULL, a, NULL);
933
934         res = construct_array(  sa->elems, 
935                                                         sa->nelems,
936                                                         sa->info->typid,
937                                                         sa->info->typlen,
938                                                         sa->info->typbyval,
939                                                         sa->info->typalign);
940
941         pfree(sa->elems);
942         pfree(sa);
943         PG_FREE_IF_COPY(a, 0);
944
945         PG_RETURN_ARRAYTYPE_P(res);
946 }
947
948 PG_FUNCTION_INFO_V1(inarray);
949 Datum   inarray(PG_FUNCTION_ARGS);
950 Datum
951 inarray(PG_FUNCTION_ARGS)
952 {
953         ArrayType               *a;
954         SimpleArray             *sa;
955         Datum                   query = PG_GETARG_DATUM(1); 
956         Oid                             queryTypeOid;
957         Datum                   *StopLow,
958                                         *StopHigh,
959                                         *StopMiddle;
960         int                             cmp;
961
962         fcinfo->flinfo->fn_extra = SearchArrayCache( 
963                                                         fcinfo->flinfo->fn_extra,
964                                                         fcinfo->flinfo->fn_mcxt,
965                                                         PG_GETARG_DATUM(0), &a, &sa, NULL);
966
967         queryTypeOid = get_fn_expr_argtype(fcinfo->flinfo, 1);
968
969         if ( queryTypeOid == InvalidOid )
970                 elog(ERROR,"inarray: could not determine actual argument type");
971
972         if ( queryTypeOid != sa->info->typid )
973                 elog(ERROR,"inarray: Type of array's element and type of argument are not the same");
974
975         getFmgrInfoCmp(sa->info);
976         StopLow = sa->elems;
977         StopHigh = sa->elems + sa->nelems;
978
979         while (StopLow < StopHigh) 
980         {
981                 StopMiddle = StopLow + ((StopHigh - StopLow) >> 1);
982                 cmp = cmpArrayElem(StopMiddle, &query, sa->info);
983
984                 if ( cmp == 0 )
985                 {
986                         /* found */
987                         if ( PG_NARGS() >= 3 )
988                                 PG_RETURN_DATUM(PG_GETARG_DATUM(2));
989                         PG_RETURN_FLOAT4(1.0);
990                 }
991                 else if (cmp < 0)
992                         StopLow = StopMiddle + 1;
993                 else
994                         StopHigh = StopMiddle;
995         }
996
997         if ( PG_NARGS() >= 4 )
998                 PG_RETURN_DATUM(PG_GETARG_DATUM(3));
999         PG_RETURN_FLOAT4(0.0);
1000 }