Source code for rdc.etl.extra.db.load

# -*- coding: utf-8 -*-
# Copyright 2012-2014 Romain Dorgueil
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
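# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.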
# See the License for the specific language governing permissions and
# limitations under the License.

from copy import copy
from sqlalchemy import MetaData, Table
from rdc.etl.error import ProhibitedOperationError
from rdc.etl.hash import Hash
from rdc.etl.transform import Transform
from rdc.etl.util import now, cached_property

class DatabaseLoad(Transform):
    """Buffered "insert or update" loader for a single database table.

    Incoming hashes are stored in an in-memory buffer. Once the buffer reaches
    ``_max_buffer_size`` entries (and again on :meth:`finalize`), every
    buffered hash is loaded inside one transaction: the target row is looked
    up using the ``discriminant`` columns, then either an ``UPDATE`` or an
    ``INSERT`` is issued. A hash whose load raises is not lost: it is emitted
    on the ``STDERR`` channel wrapped in a ``Hash`` holding ``_input``,
    ``_transform`` and ``_error``.

    Configuration, settable either as class attributes or constructor
    arguments (a truthy constructor argument wins):

    * ``engine``: object providing ``connect()``, also handed to SQLAlchemy
      for table reflection.
    * ``table_name``: name of the target table.
    * ``fetch_columns``: list/tuple of column names, or ``{alias: column}``
      dict, copied from the database row back into the output hash.
    * ``discriminant``: column names used to find an existing row.
    * ``created_at_field`` / ``updated_at_field``: timestamp columns, filled
      in automatically when the table actually has them.
    * ``insert_only_fields``: columns written on INSERT but never on UPDATE.
    * ``allowed_operations``: subset of ``(INSERT, UPDATE)`` that may run.

    NOTE(review): table and column names are interpolated directly into the
    SQL text (values are bound as parameters), so they must come from trusted
    configuration.
    """

    engine = None
    table_name = None
    fetch_columns = None
    insert_only_fields = ()
    discriminant = ('id', )
    created_at_field = 'created_at'
    updated_at_field = 'updated_at'
    allowed_operations = (INSERT, UPDATE, )

    def __init__(self, engine=None, table_name=None, fetch_columns=None, discriminant=None, created_at_field=None,
                 updated_at_field=None, insert_only_fields=None, allowed_operations=None):
        """Configure the loader; every falsy argument falls back to the class-level attribute."""
        super(DatabaseLoad, self).__init__()

        self.engine = engine or self.engine
        self.table_name = table_name or self.table_name

        # Read the class-level default *before* shadowing it with the
        # per-instance dict, so subclasses can declare fetch_columns
        # declaratively like every other option.
        requested_columns = fetch_columns or self.fetch_columns
        self.fetch_columns = {}
        if isinstance(requested_columns, (list, tuple, )):
            self.add_fetch_column(*requested_columns)
        elif isinstance(requested_columns, dict):
            self.add_fetch_column(**requested_columns)

        self.discriminant = discriminant or self.discriminant
        self.created_at_field = created_at_field or self.created_at_field
        self.updated_at_field = updated_at_field or self.updated_at_field
        self.insert_only_fields = insert_only_fields or self.insert_only_fields
        self.allowed_operations = allowed_operations or self.allowed_operations

        self._buffer = []
        self._connection = None
        self._max_buffer_size = 1000
        self._last_duration = None
        self._last_commit_at = None
        self._query_count = 0

    @property
    def connection(self):
        """Database connection, opened through ``engine.connect()`` on first access and then reused."""
        if self._connection is None:
            self._connection = self.engine.connect()
        return self._connection

    def commit(self):
        """Load every buffered hash within one transaction.

        This is a generator: it yields the transformed hash for each
        successful load, and a ``(Hash, STDERR)`` pair describing the failure
        for each load that raised.
        """
        with self.connection.begin():
            while self._buffer:
                pending = self._buffer.pop(0)
                # Only the load itself is guarded, so that an exception thrown
                # into this generator by its consumer is not misreported as a
                # row failure.
                try:
                    loaded = self.do_transform(copy(pending))
                except Exception as e:
                    yield Hash((
                        ('_input', pending, ),
                        ('_transform', self, ),
                        ('_error', e, ),
                    )), STDERR
                else:
                    yield loaded

    def close_connection(self):
        """Close the current connection, if any, so that the next access reopens one."""
        if self._connection is not None:
            self._connection.close()
            self._connection = None

    def get_insert_columns_for(self, hash):
        """List of columns we can use for insert."""
        return self.columns

    def get_update_columns_for(self, hash, row):
        """List of columns we can use for update (table columns minus ``insert_only_fields``)."""
        return [
            column for column in self.columns
            if column not in self.insert_only_fields
        ]

    def get_columns_for(self, hash, row=None):
        """Retrieve list of table column names for which we have a value in given hash.

        When ``row`` is truthy the update column set is used, otherwise the
        insert column set. The result follows the iteration order of ``hash``.
        """
        if row:
            column_names = self.get_update_columns_for(hash, row)
        else:
            column_names = self.get_insert_columns_for(hash)

        return [key for key in hash if key in column_names]

    def find(self, dataset, connection=None):
        """Fetch the first row whose discriminant columns match ``dataset``.

        Executes on ``connection`` when given, otherwise on
        :attr:`connection`, and returns whatever ``fetchone()`` gives back.
        Discriminant keys missing from ``dataset`` are bound as ``None``.
        """
        query = '''SELECT * FROM {table} WHERE {criteria} LIMIT 1'''.format(
            table=self.table_name,
            criteria=' AND '.join([key_atom + ' = %s' for key_atom in self.discriminant]),
        )
        rp = (connection or self.connection).execute(query, [dataset.get(key_atom) for key_atom in self.discriminant])

        # Increment stats
        self._input._special_stats[SELECT] += 1

        return rp.fetchone()

    def initialize(self):
        """Run the parent initialization, then reset the SELECT / INSERT / UPDATE counters to zero."""
        super(DatabaseLoad, self).initialize()
        self._input._special_stats[SELECT] = 0
        self._output._special_stats[INSERT] = 0
        self._output._special_stats[UPDATE] = 0

    def do_transform(self, hash):
        """Actual database load transformation logic, without the buffering / transaction logic.

        Looks the row up, stamps the timestamp fields, runs the UPDATE or
        INSERT, updates the stats counters and finally copies the requested
        ``fetch_columns`` into ``hash``, which is mutated and returned.

        Raises ``ProhibitedOperationError`` when the needed operation is not
        in ``allowed_operations``, and ``ValueError`` when fetch columns are
        requested but the freshly inserted row cannot be found again.
        """
        # find line, if it exist
        row = self.find(hash)

        # Named "timestamp" so the local does not shadow the imported now().
        timestamp = self.now
        column_names = self.table.columns.keys()

        # UpdatedAt field configured ? Let's set the value in source hash
        if self.updated_at_field in column_names:
            hash[self.updated_at_field] = timestamp
        # Otherwise, make sure there is no such field
        elif self.updated_at_field in hash:
            del hash[self.updated_at_field]

        # UPDATE
        if row:
            if UPDATE not in self.allowed_operations:
                raise ProhibitedOperationError('UPDATE operations are not allowed by this transformation.')

            _columns = self.get_columns_for(hash, row)
            # Discriminant columns only appear in the WHERE clause.
            _updated = [_column for _column in _columns if _column not in self.discriminant]

            query = '''UPDATE {table} SET {values} WHERE {criteria}'''.format(
                table=self.table_name,
                values=', '.join('{column} = %s'.format(column=_column) for _column in _updated),
                criteria=' AND '.join('{key} = %s'.format(key=_key) for _key in self.discriminant),
            )
            # Bound values follow placeholder order: SET values first, then WHERE values.
            values = [hash[_column] for _column in _updated] + \
                     [hash[_key] for _key in self.discriminant]

        # INSERT
        else:
            if INSERT not in self.allowed_operations:
                raise ProhibitedOperationError('INSERT operations are not allowed by this transformation.')

            if self.created_at_field in column_names:
                hash[self.created_at_field] = timestamp
            elif self.created_at_field in hash:
                del hash[self.created_at_field]

            _columns = self.get_columns_for(hash)
            query = '''INSERT INTO {table} ({keys}) VALUES ({values})'''.format(
                table=self.table_name,
                keys=', '.join(_columns),
                values=', '.join(['%s'] * len(_columns)),
            )
            values = [hash[key] for key in _columns]

        # Execute
        self.connection.execute(query, values)

        # Increment stats
        if row:
            self._output._special_stats[UPDATE] += 1
        else:
            self._output._special_stats[INSERT] += 1

        # If user required us to fetch some columns, let's query again to get their actual values.
        # NOTE(review): after an UPDATE the row fetched *before* the update is
        # reused, so fetched columns carry pre-update values - confirm this is
        # intended.
        if self.fetch_columns:
            if not row:
                row = self.find(hash)
            if not row:
                raise ValueError('Could not find matching row after load.')

            # items() rather than iteritems(): works on both Python 2 and 3.
            for alias, column in self.fetch_columns.items():
                hash[alias] = row[column]

        return hash

    def transform(self, hash, channel=STDIN):
        """Transform method. Stores the input in a buffer, and only unstack buffer content if some limit has been
        exceeded.

        TODO for now buffer limit is hardcoded as 1000, but we may use a few criterias to add intelligence to this:
        time since last commit, duration of last commit, buffer length ...
        """
        self._buffer.append(hash)

        if len(self._buffer) >= self._max_buffer_size:
            for _out in self.commit():
                yield _out

    def finalize(self):
        """Transform's finalize method.

        Empties the remaining lines in buffer by loading them into database and close database connection.
        """
        super(DatabaseLoad, self).finalize()

        # Release the connection even if the consumer stops iterating early
        # or the final commit raises.
        try:
            for _out in self.commit():
                yield _out
        finally:
            self.close_connection()

    def add_fetch_column(self, *columns, **aliased_columns):
        """Register columns to read back after load.

        Positional names are stored under their own name; keyword arguments
        map ``alias=column``.
        """
        self.fetch_columns.update(aliased_columns)
        for column in columns:
            self.fetch_columns[column] = column

    @cached_property
    def columns(self):
        """Column names of the target table (``self.table.columns.keys()``)."""
        return self.table.columns.keys()

    @cached_property
    def metadata(self):
        """SQLAlchemy metadata."""
        return MetaData()

    @cached_property
    def table(self):
        """SQLAlchemy table object, using metadata autoloading from database to avoid the need of column
        definitions."""
        return Table(self.table_name, self.metadata, autoload=True, autoload_with=self.engine)

    @property
    def now(self):
        """Current timestamp, used for created/updated at fields."""
        return now()