# -*- coding: utf-8 -*-

"""
@Remark: Custom view sets.
"""
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from rest_framework.decorators import action
from rest_framework.viewsets import ModelViewSet
from utils.filters import DataLevelPermissionsFilter
from utils.jsonResponse import SuccessResponse,ErrorResponse,DetailResponse
from utils.permission import CustomPermission
from django.http import Http404
from django.shortcuts import get_object_or_404 as _get_object_or_404
from django.core.exceptions import ValidationError
from django.db.models import QuerySet
from utils.exception import APIException
from django_filters.rest_framework import DjangoFilterBackend
from django_filters import utils
from rest_framework.filters import OrderingFilter, SearchFilter
from rest_framework.permissions import IsAuthenticated
from utils.export_excel2 import LyExportExcel


def get_object_or_404(queryset, *filter_args, **filter_kwargs):
    """
    Wrap Django's ``get_object_or_404`` shortcut.

    Lookup values of the wrong type make the ORM raise ``TypeError``,
    ``ValueError`` or ``ValidationError``; those are converted into the
    project's ``APIException``. ``Http404`` raised by the underlying shortcut
    is not caught here and propagates unchanged.

    :param queryset: forwarded as-is to Django's ``get_object_or_404``.
    :param filter_args: positional lookup arguments.
    :param filter_kwargs: keyword lookup arguments.
    :return: the single matching object.
    :raises APIException: when the lookup values do not match the field types.
    """
    try:
        return _get_object_or_404(queryset, *filter_args, **filter_kwargs)
    except (TypeError, ValueError, ValidationError):
        raise APIException(message='该对象不存在或者无访问权限')


class CustomDjangoFilterBackend(DjangoFilterBackend):
    """
    DjangoFilterBackend that still honours ``filter_fields`` / ``filter_class``.

    Newer django-filter releases (22.1 onwards) dropped the ``filter_fields``
    and ``filter_class`` view attributes used up to 21.1 in favour of
    ``filterset_fields`` and ``filterset_class``. This backend accepts both
    spellings; the old ones emit a deprecation notice.
    """

    def get_filterset_class(self, view, queryset=None):
        """
        Return the `FilterSet` class used to filter the queryset.

        :param view: the view being filtered; its ``filterset_class`` /
            ``filterset_fields`` (or legacy ``filter_class`` /
            ``filter_fields``) attributes are inspected.
        :param queryset: optional queryset, used to validate the FilterSet
            model and to build an automatic FilterSet from a field list.
        :return: a FilterSet class, or ``None`` when the view configures
            no filtering.
        """
        filterset_class = getattr(view, 'filterset_class', None)
        filterset_fields = getattr(view, 'filterset_fields', None)

        # Fall back to the legacy attribute names when the new ones are unset.
        # TODO: remove assertion in 2.1
        if filterset_class is None and hasattr(view, 'filter_class'):
            utils.deprecate(
                "`%s.filter_class` attribute should be renamed `filterset_class`."
                % view.__class__.__name__)
            filterset_class = getattr(view, 'filter_class', None)

        # TODO: remove assertion in 2.1
        if filterset_fields is None and hasattr(view, 'filter_fields'):
            utils.deprecate(
                "`%s.filter_fields` attribute should be renamed `filterset_fields`."
                % view.__class__.__name__)
            filterset_fields = getattr(view, 'filter_fields', None)

        if filterset_class:
            filterset_model = filterset_class._meta.model

            # FilterSets do not need to specify a Meta class
            if filterset_model and queryset is not None:
                assert issubclass(queryset.model, filterset_model), \
                    'FilterSet model %s does not match queryset model %s' % \
                    (filterset_model, queryset.model)

            return filterset_class

        if filterset_fields and queryset is not None:
            MetaBase = getattr(self.filterset_base, 'Meta', object)

            # Build a throw-away FilterSet bound to the queryset's model.
            class AutoFilterSet(self.filterset_base):
                class Meta(MetaBase):
                    model = queryset.model
                    fields = filterset_fields

            return AutoFilterSet

        return None


class CustomModelViewSet(ModelViewSet):
    """
    Project-wide ModelViewSet with a uniform response envelope.

    Create, update and export may each use a dedicated serializer:

    (1) create_serializer_class: serializer used when creating.
    (2) update_serializer_class: serializer used when updating.
    (3) export_serializer_class: serializer used when exporting; may be left
        empty, in which case ``serializer_class`` is used.
    (4) export_field_dict: columns to export, mapping field name to header,
        e.g. ``export_field_dict = {'name': '姓名', 'age': '年龄'}``.
    (5) export_download_mode: how the Excel file is delivered. ``temp``
        streams an in-memory file (nothing is saved on the server); ``url``
        saves the file and returns an http/https download link.
    (6) export_download_filename: optional Excel file name without extension,
        e.g. ``export_download_filename = "导出用户数据"``.
    """
    values_queryset = None
    ordering_fields = '__all__'
    create_serializer_class = None
    update_serializer_class = None
    export_serializer_class = None
    export_field_dict = {}
    export_download_mode = "temp"
    export_download_filename = ""
    # Alternative: filterset_fields = '__all__'
    filterset_fields = ()
    search_fields = ()
    # Always applied, in addition to (and before) ``filter_backends``.
    extra_filter_backends = [DataLevelPermissionsFilter]
    permission_classes = [CustomPermission,IsAuthenticated]
    # Views that want a little more performance can override filter_backends
    # and drop the filter classes the endpoint does not use.
    filter_backends = [CustomDjangoFilterBackend,SearchFilter,OrderingFilter]

    def filter_queryset(self, queryset):
        """
        Run ``queryset`` through ``extra_filter_backends`` then ``filter_backends``.

        :param queryset: the queryset to narrow down.
        :return: the queryset after every backend has been applied in order.
        """
        backends = list(self.extra_filter_backends) + list(self.filter_backends)
        for backend in backends:
            queryset = backend().filter_queryset(self.request, queryset, self)
        return queryset

    def get_queryset(self):
        """
        Return ``values_queryset`` when configured, else the default queryset.

        A QuerySet is never truth-tested here: ``bool(queryset)`` would run
        the whole query and store the rows on the class-level object that is
        shared between requests. Instead a fresh clone is returned via
        ``.all()`` so each request evaluates its own copy.

        :return: the queryset this view operates on.
        """
        queryset = getattr(self, 'values_queryset', None)
        if isinstance(queryset, QuerySet):
            return queryset.all()
        if queryset:
            # Non-QuerySet values keep the original truthiness-based fallback.
            return queryset
        return super().get_queryset()

    def get_serializer_class(self):
        """
        Pick ``<action>_serializer_class`` when the view defines a truthy one.

        :return: the action-specific serializer class, otherwise whatever
            the parent ``get_serializer_class`` returns.
        """
        action_serializer_name = f"{self.action}_serializer_class"
        action_serializer_class = getattr(self, action_serializer_name, None)
        if action_serializer_class:
            return action_serializer_class
        return super().get_serializer_class()

    @action(methods=['post'], detail=False)
    def export(self, request, *args, **kwargs):
        """
        Export the filtered records to Excel.

        When the request body carries a non-empty ``ids`` list, only those
        primary keys are exported.

        :return: the ``LyExportExcel.export_data`` result itself when
            ``export_download_mode`` is ``"temp"``; otherwise that result
            wrapped in a ``DetailResponse``.
        """
        queryset = self.filter_queryset(self.get_queryset())
        ids = request.data.get('ids', None)
        if ids:
            # ``pk`` rather than ``id`` so custom primary-key names work too.
            queryset = queryset.filter(pk__in=ids)

        export_serializer_class = self.export_serializer_class or self.serializer_class
        data = export_serializer_class(queryset, many=True, request=request).data

        # ``fileName`` is only passed when the view configured a name.
        excel_kwargs = {'request': request, 'downloadMode': self.export_download_mode}
        if self.export_download_filename:
            excel_kwargs['fileName'] = self.export_download_filename + ".xlsx"
        result = LyExportExcel(**excel_kwargs).export_data(self.export_field_dict, data)

        if self.export_download_mode == "temp":
            return result
        return DetailResponse(data=result, msg="导出成功")

    def create(self, request, *args, **kwargs):
        """
        Validate the request body and create a new object.

        :return: ``DetailResponse`` carrying the serialized object.
        """
        serializer = self.get_serializer(data=request.data, request=request)
        serializer.is_valid(raise_exception=True)
        self.perform_create(serializer)
        return DetailResponse(data=serializer.data, msg="新增成功")

    def list(self, request, *args, **kwargs):
        """
        List the filtered objects, paginated when a paginator yields a page.

        :return: the paginated response when pagination applies, otherwise
            a ``SuccessResponse`` with every serialized object.
        """
        queryset = self.filter_queryset(self.get_queryset())
        page = self.paginate_queryset(queryset)
        if page is not None:
            serializer = self.get_serializer(page, many=True, request=request)
            return self.get_paginated_response(serializer.data)
        serializer = self.get_serializer(queryset, many=True, request=request)
        return SuccessResponse(data=serializer.data, msg="获取成功")

    def retrieve(self, request, *args, **kwargs):
        """
        Return a single object looked up through ``get_object()``.

        :return: ``SuccessResponse`` carrying the serialized object.
        """
        instance = self.get_object()
        serializer = self.get_serializer(instance)
        return SuccessResponse(data=serializer.data, msg="获取成功")

    def update(self, request, *args, **kwargs):
        """
        Validate the request body and update an existing object.

        A ``partial`` keyword argument (default ``False``) is popped from
        ``kwargs`` and forwarded to the serializer.

        :return: ``DetailResponse`` carrying the serialized object.
        """
        partial = kwargs.pop('partial', False)
        instance = self.get_object()
        serializer = self.get_serializer(instance, data=request.data, request=request, partial=partial)
        serializer.is_valid(raise_exception=True)
        self.perform_update(serializer)

        if getattr(instance, '_prefetched_objects_cache', None):
            # If 'prefetch_related' has been applied to a queryset, we need to
            # forcibly invalidate the prefetch cache on the instance.
            instance._prefetched_objects_cache = {}
        return DetailResponse(data=serializer.data, msg="更新成功")

    def get_object_list(self):
        """
        Resolve a comma-separated URL lookup value into a queryset.

        Backs the batch-capable DELETE: ``/api/admin/user/1,2,3/`` selects
        the users whose lookup field is 1, 2 or 3.

        :return: the filtered queryset restricted to the requested values.
        """
        queryset = self.filter_queryset(self.get_queryset())
        lookup_url_kwarg = self.lookup_url_kwarg or self.lookup_field
        assert lookup_url_kwarg in self.kwargs, (
            'Expected view %s to be called with a URL keyword argument '
            'named "%s". Fix your URL conf, or set the `.lookup_field` '
            'attribute on the view correctly.' %
            (self.__class__.__name__, lookup_url_kwarg)
        )
        filter_kwargs = {f"{self.lookup_field}__in": self.kwargs[lookup_url_kwarg].split(',')}
        obj = queryset.filter(**filter_kwargs)
        # NOTE(review): a queryset, not a single model instance, is handed to
        # check_object_permissions -- confirm the permission classes expect it.
        self.check_object_permissions(self.request, obj)
        return obj

    def destroy(self, request, *args, **kwargs):
        """
        Delete one or several objects addressed by the URL lookup value.

        Overrides the default single-object delete so that a comma-separated
        value such as ``/api/admin/user/1,2,3/`` removes ids 1, 2 and 3.

        :return: ``DetailResponse`` with an empty data list.
        """
        instance = self.get_object_list()
        self.perform_destroy(instance)
        return DetailResponse(data=[], msg="删除成功")

    def perform_destroy(self, instance):
        """Delete ``instance`` by calling its ``delete()`` method."""
        instance.delete()

    # The previous single-id destroy() (get_object() + perform_destroy()) was
    # replaced by the batch-capable version above.

    # Request-body schema for the batch delete endpoint: an array of primary
    # keys. ``items`` must be a Schema object, not a bare type string.
    keys = openapi.Schema(description='主键列表', type=openapi.TYPE_ARRAY,
                          items=openapi.Schema(type=openapi.TYPE_STRING))

    @swagger_auto_schema(request_body=openapi.Schema(
        type=openapi.TYPE_OBJECT,
        required=['keys'],
        properties={'keys': keys}
    ), operation_summary='批量删除')
    @action(methods=['delete'], detail=False)
    def multiple_delete(self, request, *args, **kwargs):
        """
        Delete every object whose primary key is listed in the body's ``keys``.

        The queryset goes through ``filter_queryset`` first, so the
        ``extra_filter_backends`` restrict what may be deleted, exactly as
        they do for ``destroy`` and ``export``.

        :return: ``SuccessResponse`` after deleting, or ``ErrorResponse``
            when ``keys`` is missing or empty.
        """
        keys = request.data.get('keys', None)
        if not keys:
            return ErrorResponse(msg="未获取到keys字段")
        self.filter_queryset(self.get_queryset()).filter(pk__in=keys).delete()
        return SuccessResponse(data=[], msg="删除成功")